TorchFX 0.5.0: Custom CUDA Kernels & Native C++ Extension

I’m excited to announce TorchFX 0.5.0, a performance-focused release that introduces custom CUDA kernels, a JIT-compiled C++ native extension, and major algorithmic improvements across the entire filter pipeline.

This release delivers on the Phase 3 optimization goals outlined in the 0.4.0 roadmap.

The Native Extension (torchfx._ops)

At the core of 0.5.0 is a new JIT-compiled C++/CUDA extension that loads automatically when you import TorchFX. The extension is compiled on first use via torch.utils.cpp_extension and cached for subsequent imports.

import torchfx
# [torchfx] native extension: YES
# [torchfx] CUDA available: True (NVIDIA RTX 6000)

Key design decisions:

  • Automatic fallback: If compilation fails (no compiler, no CUDA toolkit), TorchFX falls back to pure-PyTorch paths transparently. Your code doesn’t change.

  • CPU-only support: The C++ extension compiles and loads without the CUDA toolkit. You get native-speed IIR filtering on CPU even without a GPU.

  • Environment control: Set TORCHFX_NO_CUDA=1 to force CPU-only compilation if you want to skip CUDA entirely.
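
These behaviors follow a common load-or-fall-back pattern. Here is a minimal sketch of that pattern in pure Python — the names are hypothetical and this is not TorchFX's actual loader, which JIT-compiles via torch.utils.cpp_extension.load:

```python
# Sketch of a load-or-fall-back loader (hypothetical names, not TorchFX's code).
def _load_native_ops():
    try:
        # TorchFX would call torch.utils.cpp_extension.load(...) here to
        # JIT-compile the C++/CUDA sources; we simulate a failed build.
        raise RuntimeError("no compiler found")
    except Exception:
        return None  # signal the callers to use pure-Python fallbacks

_native = _load_native_ops()

def biquad(x, coeffs):
    """Dispatch to the native kernel when available, else the fallback."""
    if _native is not None:
        return _native.biquad(x, coeffs)
    return _biquad_fallback(x, coeffs)

def _biquad_fallback(x, coeffs):
    # Direct-form I biquad, sample by sample (stand-in for the real path).
    b0, b1, b2, a1, a2 = coeffs
    y, x1, x2, y1, y2 = [], 0.0, 0.0, 0.0, 0.0
    for s in x:
        out = b0 * s + b1 * x1 + b2 * x2 - a1 * y1 - a2 * y2
        x1, x2 = s, x1
        y1, y2 = out, y1
        y.append(out)
    return y
```

Because both paths expose the same function names, calling code never needs to know which backend is active.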

Compiler Requirements

To compile the native extension, you need GCC 9 or newer (or an equivalent C++17-compatible compiler). The CPU extension compiles with -O3 -ffast-math -march=native and OpenMP parallelization for multi-channel workloads. On most Linux systems with a recent toolchain this works out of the box.

CUDA Parallel Scan for IIR Filters

IIR (Infinite Impulse Response) filters have a fundamental challenge on GPUs: each output sample depends on previous outputs, creating a sequential dependency chain. The naive approach — one thread per channel, looping over samples — leaves 99% of the GPU idle.

TorchFX 0.5.0 solves this with a parallel prefix scan (Blelloch scan) that decomposes the IIR recurrence into parallel-friendly operations:

  • O(N) total work instead of O(N log N) from the previous Hillis-Steele approach

  • 24 KB shared memory per block, down from 48 KB, allowing higher occupancy

  • 128 channels batched per thread block for the sequential biquad kernel, improving GPU utilization on short signals
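
To see why the recurrence is scan-friendly: each IIR step is an affine map y[n] = a·y[n−1] + x[n], and affine maps compose associatively — exactly the property a Blelloch scan exploits. A first-order, pure-Python illustration (not the CUDA kernel itself):

```python
from itertools import accumulate

# Each step of y[n] = a*y[n-1] + x[n] is an affine map f_n(y) = a*y + x[n],
# represented as the pair (a, x[n]). Affine-map composition is associative,
# so a parallel prefix scan over `combine` can compute every y[n] at once.
def combine(f, g):
    # returns g ∘ f: apply f first, then g
    a1, b1 = f
    a2, b2 = g
    return (a2 * a1, a2 * b1 + b2)

def iir_first_order(a, x, y_init=0.0):
    maps = [(a, xn) for xn in x]
    # accumulate stands in for the scan; a GPU evaluates the same
    # reduction tree in O(log N) depth with O(N) total work.
    return [am * y_init + bm for am, bm in accumulate(maps, combine)]
```

A sequential loop produces identical results; the payoff is that the scan's reduction tree parallelizes, which the sample-by-sample loop cannot.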

The result is that a 300-second, 12-channel IIR cascade completes in 550 ms on GPU — compared to 5.4 seconds with SciPy and 1.1 seconds on CPU.

FFT-Based FIR Convolution

FIR filters now default to FFT convolution via the overlap-save method, adapted from Julius (MIT License). For kernel sizes >= 64 taps, this is up to 10x faster than direct convolution.

You can control the convolution mode per filter:

from torchfx.filter import DesignableFIR

# FFT convolution (default, fast for large kernels)
fir = DesignableFIR(num_taps=512, cutoff=4000, fs=44100, conv_mode="fft")

# Direct convolution (better for very small kernels)
fir = DesignableFIR(num_taps=16, cutoff=4000, fs=44100, conv_mode="direct")

# Automatic selection based on kernel size
fir = DesignableFIR(num_taps=128, cutoff=4000, fs=44100, conv_mode="auto")
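
Overlap-save itself is compact: filter fixed-size FFT blocks and discard the circularly aliased prefix of each block. A NumPy sketch of the idea (illustrative only — TorchFX's implementation is the one adapted from julius):

```python
import numpy as np

def overlap_save(x, h, block=256):
    """Minimal overlap-save FIR convolution (sketch, not TorchFX's kernel)."""
    M = len(h)
    step = block - (M - 1)               # new samples consumed per block
    H = np.fft.rfft(h, block)            # filter spectrum, computed once
    x_pad = np.concatenate([np.zeros(M - 1), x])
    out = []
    for start in range(0, len(x), step):
        seg = x_pad[start:start + block]
        if len(seg) < block:
            seg = np.pad(seg, (0, block - len(seg)))
        y = np.fft.irfft(np.fft.rfft(seg) * H, block)
        out.append(y[M - 1:])            # drop the circularly aliased prefix
    return np.concatenate(out)[:len(x)]
```

The filter spectrum H is computed once, so each block costs two FFTs and a pointwise multiply — O(N log B) overall versus O(N·M) for direct convolution, which is where the large-kernel speedup comes from.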

LogFilterBank

A new LogFilterBank class provides logarithmically spaced frequency-band decomposition, useful for spectral analysis, multiband processing, and feature extraction:

from torchfx.filter import LogFilterBank

bank = LogFilterBank(n_bands=32, f_low=20, f_high=20000, fs=44100)
bands = bank(wave.ys)  # [n_bands, channels, samples]
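
The release notes don't spell out the band layout, but logarithmic spacing means adjacent band edges keep a constant frequency ratio. A plausible sketch of how such edges are computed (hypothetical helper, not the LogFilterBank internals):

```python
def log_band_edges(n_bands, f_low, f_high):
    """n_bands + 1 logarithmically spaced band edges between f_low and f_high."""
    ratio = (f_high / f_low) ** (1.0 / n_bands)  # constant ratio between edges
    return [f_low * ratio ** i for i in range(n_bands + 1)]
```

With n_bands=32, f_low=20, f_high=20000 this yields 33 edges where each band spans the same musical interval — a natural fit for audio, since pitch perception is roughly logarithmic in frequency.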

Performance-Optimized Fallback Paths

Even without the native extension, 0.5.0 is dramatically faster than 0.4.0. The pure-PyTorch fallback paths have been completely rewritten:

  • Stateful biquad and IIR SOS: Replaced sample-by-sample Python loops with a vectorized zero-state/zero-input decomposition using lfilter. This gives a 100-500x speedup when the C++ extension is unavailable.

  • Eager SOS computation: The SOS matrix is now computed immediately after compute_coefficients() instead of lazily during forward(), avoiding repeated work.

  • Pre-computed constant tensors: SOS convolution kernels and delta tensors are cached to avoid per-call allocation.

  • Eliminated redundant device transfers: State tensor .to(device) calls are now guarded to skip when already on the correct device.
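
The zero-state/zero-input trick rests on linearity: the response to input x with initial state zi equals the zero-state response to x plus the zero-input response from zi, and each piece vectorizes well. A small SciPy check of that identity (illustrative; TorchFX applies the same decomposition in its vectorized lfilter path):

```python
import numpy as np
from scipy.signal import lfilter

b, a = [0.2, 0.3], [1.0, -0.5]
x = np.array([1.0, 2.0, 3.0, 4.0])
zi = np.array([0.7])  # nonzero initial filter state

y_full, _ = lfilter(b, a, x, zi=zi)                        # state + input
y_zero_state = lfilter(b, a, x)                            # input only
y_zero_input, _ = lfilter(b, a, np.zeros_like(x), zi=zi)   # state only

# Superposition: the two partial responses sum to the full response.
assert np.allclose(y_full, y_zero_state + y_zero_input)
```

Splitting the problem this way replaces the Python sample loop with two batched filter calls, which is where the 100-500x fallback speedup comes from.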

GPU Kernel Improvements

Beyond the parallel scan, several targeted GPU optimizations:

  • Removed synchronous CUDA calls from native kernels, improving throughput by avoiding unnecessary CPU-GPU synchronization points.

  • Scalar coefficient passing: Biquad coefficients (b0, b1, b2, a1, a2) are now passed as scalar arguments to CUDA kernels instead of being extracted from device tensors, eliminating a GPU-to-CPU synchronization that was causing a segfault on some configurations.

Benchmark Infrastructure

The benchmark suite has been migrated from standalone scripts to a pytest-benchmark suite under benchmarks/:

# Run all benchmarks
uv run pytest --benchmark-enable

# Run only IIR benchmarks
uv run pytest benchmarks/test_iir_bench.py --benchmark-enable

# Run only FIR benchmarks
uv run pytest benchmarks/test_fir_bench.py --benchmark-enable

Each benchmark compares five backends: TorchFX GPU (CUDA), TorchFX CPU, SciPy, Numba @njit (CPU), and Numba @cuda.jit (GPU), across signal durations from 1 to 300 seconds and varying channel counts / filter orders.

Bug Fixes

  • Fixed native extension being unreachable on CPU-only machines due to an overly strict torch.cuda.is_available() gate in _ops.py.

  • Fixed segfault in the CUDA biquad kernel caused by dereferencing a device pointer on the host.

Benchmark Results (RTX 6000)

Here’s a snapshot from our CI benchmarks on a Quadro RTX 6000 (24 GB):

| Backend          | 300 s / order-12 IIR | Relative     |
|------------------|----------------------|--------------|
| TorchFX GPU      | 550 ms               | 1.0x         |
| TorchFX CPU      | 1,086 ms             | 2.0x slower  |
| Numba @njit CPU  | TBD                  | –            |
| SciPy            | 5,428 ms             | 9.9x slower  |
| Numba @cuda.jit  | 12,957 ms            | 23.6x slower |

The GPU kernel maintains sub-millisecond standard deviation, making it suitable for latency-sensitive workloads.

Installation

pip install torchfx

The native extension compiles automatically on first import. Ensure you have:

  • GCC >= 9 (or equivalent C++17 compiler)

  • PyTorch >= 2.0 with matching CUDA toolkit (for GPU kernels)

  • setuptools (now a runtime dependency, required by torch.utils.cpp_extension)

For CPU-only builds:

TORCHFX_NO_CUDA=1 pip install torchfx

What’s Next

With the performance foundation in place, we’re turning our attention to:

  • Additional effects: compressor, phaser, pitch shift

  • Batch processing optimizations for the CLI pipeline

  • v1.0.0 release candidate with API stability guarantees

Benchmarks were run on a Quadro RTX 6000 (24 GB) with CUDA 12.1 and PyTorch 2.10.0. Performance may vary with hardware and software configuration, so always benchmark on your target system.