# TorchFX 0.5.0: Custom CUDA Kernels & Native C++ Extension
I’m excited to announce TorchFX 0.5.0, a performance-focused release that introduces custom CUDA kernels, a JIT-compiled C++ native extension, and major algorithmic improvements across the entire filter pipeline.
This release delivers on the Phase 3 optimization goals outlined in the 0.4.0 roadmap.
## The Native Extension (`torchfx._ops`)
At the core of 0.5.0 is a new JIT-compiled C++/CUDA extension that loads automatically when you import TorchFX. The extension is compiled on first use via torch.utils.cpp_extension and cached for subsequent imports.
```python
import torchfx
# [torchfx] native extension: YES
# [torchfx] CUDA available: True (NVIDIA RTX 6000)
```
Key design decisions:
- **Automatic fallback:** If compilation fails (no compiler, no CUDA toolkit), TorchFX falls back to pure-PyTorch paths transparently. Your code doesn’t change.
- **CPU-only support:** The C++ extension compiles and loads without the CUDA toolkit. You get native-speed IIR filtering on CPU even without a GPU.
- **Environment control:** Set `TORCHFX_NO_CUDA=1` to force CPU-only compilation if you want to skip CUDA entirely.
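The load-or-fallback behavior can be sketched roughly as follows. This is a hypothetical illustration with made-up helper names, not TorchFX's actual `_ops.py` internals; the compile step here always fails, the way a missing compiler would:

```python
import os

def _compile_native():
    """Stand-in for torch.utils.cpp_extension.load(); here it always
    fails, the way a missing compiler or CUDA toolkit would."""
    raise RuntimeError("no C++ compiler found")

def load_ops():
    if os.environ.get("TORCHFX_NO_CUDA") == "1":
        pass  # a real loader would drop the .cu sources here
    try:
        return _compile_native()
    except Exception:
        return None  # transparent fallback: callers use pure-PyTorch paths

_ops = load_ops()

def apply_gain(samples, g):
    if _ops is not None:
        return _ops.apply_gain(samples, g)  # native fast path
    return [g * s for s in samples]         # pure fallback, same results

print(apply_gain([1.0, 2.0], 0.5))  # extension unavailable -> [0.5, 1.0]
```

Because callers only ever see the public function, the native and fallback paths stay interchangeable.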
## Compiler Requirements
To compile the native extension, you need GCC 9 or newer (or an equivalent C++17-compatible compiler). The CPU extension compiles with `-O3 -ffast-math -march=native` and OpenMP parallelization for multi-channel workloads. On most Linux systems with a recent toolchain this works out of the box.
## CUDA Parallel Scan for IIR Filters
IIR (Infinite Impulse Response) filters have a fundamental challenge on GPUs: each output sample depends on previous outputs, creating a sequential dependency chain. The naive approach — one thread per channel, looping over samples — leaves 99% of the GPU idle.
TorchFX 0.5.0 solves this with a parallel prefix scan (Blelloch scan) that decomposes the IIR recurrence into parallel-friendly operations:
- O(N) total work instead of O(N log N) from the previous Hillis-Steele approach
- 24 KB shared memory per block, down from 48 KB, allowing higher occupancy
- 128 channels batched per thread block for the sequential biquad kernel, improving GPU utilization on short signals
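To see why the recurrence is scannable at all, consider a first-order IIR `y[n] = a*y[n-1] + x[n]`: each step is the affine map `y -> a*y + x[n]`, and affine maps compose associatively, which is exactly what a Blelloch-style scan needs. A minimal pure-Python sketch of the idea (the real kernels operate on second-order sections and run the combine steps in parallel):

```python
def compose(f, g):
    """Apply f = (a1, b1) then g = (a2, b2); both encode y -> a*y + b."""
    a1, b1 = f
    a2, b2 = g
    return (a1 * a2, a2 * b1 + b2)

def iir_scan(a, x, y0=0.0):
    """First-order IIR via an inclusive prefix scan of affine maps."""
    out, acc = [], (1.0, 0.0)        # identity map: y -> y
    for xn in x:                     # sequential here; because compose is
        acc = compose(acc, (a, xn))  # associative, a GPU can do this in
        ca, cb = acc                 # O(log N) parallel steps
        out.append(ca * y0 + cb)
    return out

def iir_loop(a, x, y0=0.0):
    """Reference: the textbook recurrence y[n] = a*y[n-1] + x[n]."""
    out, y = [], y0
    for xn in x:
        y = a * y + xn
        out.append(y)
    return out

x = [1.0, 0.0, 0.0, 2.0]
assert iir_scan(0.5, x) == iir_loop(0.5, x)
```

The scan trades a longer dependency chain for a shallow tree of independent combines, which is what lets thousands of GPU threads cooperate on one channel.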
The result is that a 300-second, 12-channel IIR cascade completes in 550 ms on GPU — compared to 5.4 seconds with SciPy and 1.1 seconds on CPU.
## FFT-Based FIR Convolution
FIR filters now default to FFT convolution via the overlap-save method, adapted from Julius (MIT License). For kernel sizes >= 64 taps, this is up to 10x faster than direct convolution.
You can control the convolution mode per filter:
```python
from torchfx.filter import DesignableFIR

# FFT convolution (default, fast for large kernels)
fir = DesignableFIR(num_taps=512, cutoff=4000, fs=44100, conv_mode="fft")

# Direct convolution (better for very small kernels)
fir = DesignableFIR(num_taps=16, cutoff=4000, fs=44100, conv_mode="direct")

# Automatic selection based on kernel size
fir = DesignableFIR(num_taps=128, cutoff=4000, fs=44100, conv_mode="auto")
```
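The overlap-save idea itself is compact: filter fixed-size FFT blocks and discard the first `len(h) - 1` samples of each block, which are corrupted by circular wrap-around. A minimal NumPy sketch of the technique (illustrative only, not the actual TorchFX/julius implementation):

```python
import numpy as np

def fft_convolve_os(x, h, n_fft=256):
    """Overlap-save linear convolution ('full' length len(x) + len(h) - 1).
    Sketch only; assumes n_fft > len(h)."""
    m = len(h)
    hop = n_fft - m + 1                      # valid output samples per block
    H = np.fft.rfft(h, n_fft)                # filter spectrum, computed once
    xp = np.concatenate([np.zeros(m - 1), x, np.zeros(n_fft)])
    out = []
    for start in range(0, len(x) + m - 1, hop):
        seg = xp[start:start + n_fft]
        y = np.fft.irfft(np.fft.rfft(seg) * H, n_fft)
        out.append(y[m - 1:])                # drop circularly-wrapped prefix
    return np.concatenate(out)[:len(x) + m - 1]

rng = np.random.default_rng(0)
x, h = rng.standard_normal(1000), rng.standard_normal(64)
assert np.allclose(fft_convolve_os(x, h), np.convolve(x, h))
```

Each block costs O(n_fft log n_fft) regardless of tap count, which is where the ~10x win over direct convolution comes from for long kernels.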
## LogFilterBank
A new `LogFilterBank` class provides logarithmically spaced frequency-band decomposition, useful for spectral analysis, multiband processing, and feature extraction:
```python
from torchfx.filter import LogFilterBank

bank = LogFilterBank(n_bands=32, f_low=20, f_high=20000, fs=44100)
bands = bank(wave.ys)  # [n_bands, channels, samples]
```
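Logarithmic spacing just means the band edges form a geometric progression between `f_low` and `f_high`. A sketch of how such edges can be derived (hypothetical helper; the actual `LogFilterBank` internals may differ):

```python
import numpy as np

def log_band_edges(n_bands, f_low, f_high):
    """n_bands + 1 edges, equally spaced in log-frequency."""
    return np.geomspace(f_low, f_high, n_bands + 1)

edges = log_band_edges(32, 20.0, 20000.0)
# Adjacent edges share a constant ratio, so bandwidth grows with
# frequency: roughly octave-like, perceptually motivated spacing.
ratios = edges[1:] / edges[:-1]
assert np.allclose(ratios, ratios[0])
```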
## Performance-Optimized Fallback Paths
Even without the native extension, 0.5.0 is dramatically faster than 0.4.0. The pure-PyTorch fallback paths have been completely rewritten:
- **Stateful biquad and IIR SOS:** Replaced sample-by-sample Python loops with a vectorized zero-state/zero-input decomposition using `lfilter`. This gives a 100-500x speedup when the C++ extension is unavailable.
- **Eager SOS computation:** The SOS matrix is now computed immediately after `compute_coefficients()` instead of lazily during `forward()`, avoiding repeated work.
- **Pre-computed constant tensors:** SOS convolution kernels and delta tensors are cached to avoid per-call allocation.
- **Eliminated redundant device transfers:** State tensor `.to(device)` calls are now guarded to skip when the tensor is already on the correct device.
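The zero-state/zero-input trick rests on linearity: a stateful LTI filter's output equals the response to the input with zero initial state, plus the response to the initial state with zero input, and each part can be computed with one vectorized `lfilter` call instead of a Python loop. A small pure-NumPy demonstration of the identity (illustrative biquad, not TorchFX's implementation):

```python
import numpy as np

def biquad(x, b, a, zi):
    """Direct-form-II-transposed biquad; returns (output, final state)."""
    b0, b1, b2 = b
    a1, a2 = a                      # a0 normalized to 1
    z1, z2 = zi
    y = np.empty_like(x)
    for n, xn in enumerate(x):
        yn = b0 * xn + z1
        z1 = b1 * xn - a1 * yn + z2
        z2 = b2 * xn - a2 * yn
        y[n] = yn
    return y, (z1, z2)

b, a = (0.2, 0.4, 0.2), (-0.5, 0.1)
x = np.random.default_rng(0).standard_normal(64)
zi = (0.3, -0.7)

y_stateful, _ = biquad(x, b, a, zi)
y_zs, _ = biquad(x, b, a, (0.0, 0.0))          # zero-state response
y_zi, _ = biquad(np.zeros_like(x), b, a, zi)   # zero-input response
assert np.allclose(y_stateful, y_zs + y_zi)    # linearity: the parts sum
```

The zero-input part depends only on the (tiny) initial state, so it is cheap; the zero-state part is a plain batch filter over the whole signal.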
## GPU Kernel Improvements
Beyond the parallel scan, several targeted GPU optimizations:
- **Removed synchronous CUDA calls** from native kernels, improving throughput by avoiding unnecessary CPU-GPU synchronization points.
- **Scalar coefficient passing:** Biquad coefficients (`b0`, `b1`, `b2`, `a1`, `a2`) are now passed as scalar arguments to CUDA kernels instead of being extracted from device tensors, eliminating a GPU-to-CPU synchronization that was causing a segfault on some configurations.
## Benchmark Infrastructure
The benchmark suite has been migrated from standalone scripts to a pytest-benchmark suite under `benchmarks/`:
```bash
# Run all benchmarks
uv run pytest --benchmark-enable

# Run only IIR benchmarks
uv run pytest benchmarks/test_iir_bench.py --benchmark-enable

# Run only FIR benchmarks
uv run pytest benchmarks/test_fir_bench.py --benchmark-enable
```
Each benchmark compares five backends: TorchFX GPU (CUDA), TorchFX CPU, SciPy, Numba `@njit` (CPU), and Numba `@cuda.jit` (GPU), across signal durations from 1 to 300 seconds and varying channel counts / filter orders.
## Bug Fixes
- Fixed the native extension being unreachable on CPU-only machines due to an overly strict `torch.cuda.is_available()` gate in `_ops.py`.
- Fixed a segfault in the CUDA biquad kernel caused by dereferencing a device pointer on the host.
## Benchmark Results (RTX 6000)
Here’s a snapshot from our CI benchmarks on a Quadro RTX 6000 (24 GB):
| Backend | 300 s / order-12 IIR | Relative |
|---|---|---|
| TorchFX GPU | 550 ms | 1.0x |
| TorchFX CPU | 1,086 ms | 2.0x slower |
| Numba | TBD | – |
| SciPy | 5,428 ms | 9.9x slower |
| Numba | 12,957 ms | 23.6x slower |
The GPU kernel maintains sub-millisecond standard deviation, making it suitable for latency-sensitive workloads.
## Installation
```bash
pip install torchfx
```
The native extension compiles automatically on first import. Ensure you have:
- GCC >= 9 (or an equivalent C++17 compiler)
- PyTorch >= 2.0 with a matching CUDA toolkit (for GPU kernels)
- `setuptools` (now a runtime dependency, required by `torch.utils.cpp_extension`)
For CPU-only builds:
```bash
TORCHFX_NO_CUDA=1 pip install torchfx
```
## What’s Next
With the performance foundation in place, we’re turning our attention to:
- Additional effects: compressor, phaser, pitch shift
- Batch processing optimizations for the CLI pipeline
- A v1.0.0 release candidate with API stability guarantees
Benchmarks were run on a Quadro RTX 6000 (24 GB) with CUDA 12.1 and PyTorch 2.10.0. Performance varies with hardware and software configuration; always benchmark on your target system.