FP32 on the GPU: 3–3.6× and the End of the Consumer-GPU Penalty
This is the GPU half of the promise we made in 0.5.4: “retuning the CUDA SOS kernel for mixed precision so float32 gets the same fast path on GPU that it now has on CPU.” TorchFX 0.6.0 delivers it.
The problem: a double-only kernel in a float32 world
The CUDA biquad and SOS parallel-scan kernels were written entirely in double. The 3×3 state-transition matrices, the forcing function, every phase of the Blelloch scan — all double. The Python dispatch layer enforced it: any CUDA input was upcast to float64 before the kernel ran.
That is the wrong default for two reasons:

- Most audio is float32. Realtime callbacks, ML feature pipelines, and game audio all carry float32. Forcing them through float64 doubles the memory bandwidth for no accuracy gain at the filter orders TorchFX targets.
- Consumer GPUs hate FP64. An RTX 3070 (and the A40) has a 1:64 FP64:FP32 throughput ratio: FP64 runs at one sixty-fourth of FP32 peak. The kernel was leaving most of the card on the floor.
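For a sense of scale, here is the bandwidth arithmetic for the 60-second, 8-channel workload benchmarked below. It is a lower-bound sketch. It assumes one read of the input and one write of the output per pass at 48 kHz, and it ignores the scan's scratch buffers and the coefficients.

```python
# Lower-bound traffic for one filtering pass: read the input once, write the output once.
# Assumptions: 48 kHz, 60 s, 8 channels; scratch buffers and coefficients are ignored.
FS_HZ = 48_000
SECONDS = 60
CHANNELS = 8
BYTES_PER_SAMPLE = {"float32": 4, "float64": 8}


def io_megabytes(dtype: str) -> float:
    """Megabytes (1e6 bytes) moved by one read plus one write of the whole signal."""
    samples = FS_HZ * SECONDS * CHANNELS
    return 2 * samples * BYTES_PER_SAMPLE[dtype] / 1e6


for dtype in ("float32", "float64"):
    print(f"{dtype}: {io_megabytes(dtype):.2f} MB per pass")
# float32: 184.32 MB per pass
# float64: 368.64 MB per pass
```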
The symptom was stark: on a 60-second, 8-channel cascade the GPU was slower than the host CPU.
The fix: template on scalar_t, dispatch on the input
The kernels are now templated on scalar_t (the same pattern the CPU kernels already used). The Mat3x3 state matrix, the forcing kernel, all three Blelloch phases, and the sequential fallback are instantiated for both float and double. The two host entry points dispatch on the input tensor’s dtype via AT_DISPATCH_FLOATING_TYPES, and the block-aggregate scratch is allocated in the input dtype so the FP32 path stays FP32 end to end. The scalar coefficients arrive as double and are cast to scalar_t at launch.
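The CUDA source isn't shown here, but the idea behind the 3×3 matrices can be sketched in NumPy. The sketch assumes a single direct-form II transposed biquad with a[0] == 1. Its two-element state update s[n] = A s[n-1] + u[n] is affine, so packing A and the forcing vector u[n] into a 3×3 matrix turns the recurrence into a product of matrices. Because matrix multiplication is associative, a parallel scan can replace the sample-by-sample loop. For brevity the sketch uses a Hillis-Steele inclusive scan rather than the work-efficient Blelloch scan the kernel runs, and its dtype argument stands in for scalar_t. Both function names are hypothetical, not TorchFX API.

```python
import numpy as np


def biquad_sequential(b, a, x):
    """Reference: direct-form II transposed biquad, one sample at a time, in float64."""
    b0, b1, b2 = (float(v) for v in b)
    _, a1, a2 = (float(v) for v in a)  # assumes a[0] == 1
    s1 = s2 = 0.0
    y = np.empty(len(x), dtype=np.float64)
    for n, xn in enumerate(np.asarray(x, dtype=np.float64)):
        yn = b0 * xn + s1
        s1 = b1 * xn - a1 * yn + s2
        s2 = b2 * xn - a2 * yn
        y[n] = yn
    return y


def biquad_scan(b, a, x, dtype=np.float32):
    """Same biquad as a prefix product of 3x3 affine matrices, computed in `dtype`.

    State update s[n] = A @ s[n-1] + u[n], packed as [[A, u[n]], [0, 0, 1]].
    """
    b0, b1, b2 = np.asarray(b, dtype=dtype)  # coefficients cast once to the compute type
    _, a1, a2 = np.asarray(a, dtype=dtype)
    x = np.asarray(x, dtype=dtype)
    n = len(x)
    M = np.zeros((n, 3, 3), dtype=dtype)
    M[:, 0, 0] = -a1
    M[:, 0, 1] = 1
    M[:, 1, 0] = -a2
    M[:, 0, 2] = (b1 - a1 * b0) * x  # forcing term for s1
    M[:, 1, 2] = (b2 - a2 * b0) * x  # forcing term for s2
    M[:, 2, 2] = 1
    P = M.copy()
    step = 1
    while step < n:  # Hillis-Steele inclusive scan: P[i] = M[i] @ M[i-1] @ ... @ M[0]
        P[step:] = P[step:] @ P[:-step]
        step *= 2
    s1 = P[:, 0, 2]  # s1 after each sample, starting from a zero state
    y = b0 * x
    y[1:] += s1[:-1]  # y[n] = b0*x[n] + s1[n-1]
    return y


rng = np.random.default_rng(0)
x = rng.standard_normal(4096)
b, a = [0.2, 0.4, 0.2], [1.0, -0.5, 0.3]  # arbitrary stable biquad (pole radius ~0.55)
ref = biquad_sequential(b, a, x)
for dtype in (np.float32, np.float64):
    y = biquad_scan(b, a, x, dtype)
    print(dtype.__name__, y.dtype, float(np.max(np.abs(y - ref))))
```

In this sketch the float32 run stays in float32 from the coefficients to the output, which mirrors the "FP32 end to end" goal described above.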
The dispatch rule is now simple and symmetric with the CPU:
The native execution dtype follows the input.

- float32 in → FP32 kernels; float64 in → FP64 kernels. Pass the dtype you want.
- No silent conversions in either direction.
- Half precision (float16 / bfloat16) is rejected with a clear error, because the IIR feedback recurrence is not numerically safe there.
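In Python terms the rule amounts to the sketch below. It validates the dtype and then runs a computation whose every intermediate stays in the input's type, with no astype anywhere. This is a hypothetical illustration: filter_native and the toy one-pole filter are not TorchFX API. It is written in NumPy, which has float16 but no bfloat16, so only the float16 rejection is shown.

```python
import numpy as np

# The two types AT_DISPATCH_FLOATING_TYPES instantiates: float and double.
_NATIVE = (np.float32, np.float64)


def _one_pole(x: np.ndarray, c: float) -> np.ndarray:
    """Toy IIR y[n] = x[n] + c*y[n-1], computed entirely in x's dtype (the 'scalar_t')."""
    scalar_t = x.dtype.type
    c = scalar_t(c)  # coefficient arrives as a Python float (a double) and is cast once
    y = np.empty_like(x)
    acc = np.zeros(x.shape[:-1], dtype=x.dtype)
    for n in range(x.shape[-1]):
        acc = x[..., n] + c * acc
        y[..., n] = acc
    return y


def filter_native(x: np.ndarray, c: float = 0.9) -> np.ndarray:
    """Run the filter in the input's own dtype; never upcast or downcast."""
    if x.dtype == np.float16:
        raise TypeError("float16 input rejected: the IIR feedback recurrence is not "
                        "numerically safe in half precision; pass float32 or float64")
    if x.dtype.type not in _NATIVE:
        raise TypeError(f"unsupported dtype {x.dtype}; pass float32 or float64")
    return _one_pole(x, c)


for dt in (np.float32, np.float64):
    print(dt.__name__, filter_native(np.ones((2, 4), dtype=dt)).dtype)
```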
The numbers
8th-order Butterworth (4 SOS sections) @ 48 kHz, RTX 3070, median over 30 iterations (benchmarks/bench_fp32_speedup.py):
| Workload | GPU FP64 | GPU FP32 | Speedup |
|---|---|---|---|
| 30 s / 1 ch | 9.49 ms | 2.80 ms | 3.39× |
| 60 s / 1 ch | 18.31 ms | 6.00 ms | 3.05× |
| 60 s / 2 ch | 29.03 ms | 9.32 ms | 3.11× |
| 60 s / 4 ch | 49.22 ms | 14.41 ms | 3.42× |
| 60 s / 8 ch | 89.18 ms | 24.49 ms | 3.64× |
A consistent 3.0–3.6×. That is nowhere near the theoretical FP32:FP64 throughput ratio. The parallel scan is partly bandwidth- and launch-overhead-bound rather than purely FLOP-bound, so 3–4× is the honest, measured win. (On a datacenter card with a 1:2 ratio, where FP64 was never the bottleneck, the FP32 advantage would be smaller.)
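The table reports medians over 30 iterations. A minimal harness in that spirit might look like the sketch below. It is not benchmarks/bench_fp32_speedup.py. The warm-up count and the sync hook are assumptions. The sync hook matters on a GPU because CUDA launches are asynchronous.

```python
import statistics
import time
from typing import Callable, Optional


def median_ms(fn: Callable[[], object], iters: int = 30, warmup: int = 3,
              sync: Optional[Callable[[], None]] = None) -> float:
    """Median wall-clock time of fn() in milliseconds over `iters` timed runs.

    `sync` must block until queued device work has finished (for CUDA, e.g.
    torch.cuda.synchronize); without it, asynchronous kernel launches would be
    timed instead of the kernels themselves.
    """
    sync = sync or (lambda: None)
    for _ in range(warmup):  # untimed runs absorb one-off costs such as cold caches
        fn()
    sync()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        sync()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


def speedup(fp64_ms: float, fp32_ms: float) -> float:
    """FP32 speedup as reported in the table: FP64 median divided by FP32 median."""
    return fp64_ms / fp32_ms


print(f"{speedup(89.18, 24.49):.2f}x")  # 60 s / 8 ch row -> 3.64x
```

Taking the median rather than the mean keeps an occasional slow iteration, such as one hit by a context switch, from skewing the reported time.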
The inversion, resolved
The most satisfying row is 8 channels. Here is the full picture for that workload:
| Backend (60 s / 8 ch) | Time | Verdict |
|---|---|---|
| CPU FP64 | 34.0 ms | |
| CPU FP32 | 27.6 ms | |
| GPU FP64 | 89.2 ms | loses to its own CPU |
| GPU FP32 | 24.5 ms | beats the CPU |
In FP64 the consumer GPU genuinely lost to the OpenMP CPU kernel once the per-step working set widened across channels — an embarrassing inversion. FP32 erases it: the GPU is now the fastest backend everywhere on this card.
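The verdict column follows from a few ratios, and the short script below recomputes them from the 8-channel table. It shows why the inversion flips. Switching to FP32 buys the CPU about 1.2×, while the GPU, which was paying the heavy FP64 penalty, gains about 3.6×.

```python
# 60 s / 8 ch timings from the table above, in milliseconds.
TIMES_MS = {
    ("CPU", "FP64"): 34.0,
    ("CPU", "FP32"): 27.6,
    ("GPU", "FP64"): 89.2,
    ("GPU", "FP32"): 24.5,
}

# Head-to-head at each precision: which backend wins, and by how much.
for dtype in ("FP64", "FP32"):
    cpu, gpu = TIMES_MS[("CPU", dtype)], TIMES_MS[("GPU", dtype)]
    faster = "GPU" if gpu < cpu else "CPU"
    print(f"{dtype}: {faster} wins by {max(cpu, gpu) / min(cpu, gpu):.2f}x")

# How much each backend gains from dropping to FP32.
for device in ("CPU", "GPU"):
    gain = TIMES_MS[(device, "FP64")] / TIMES_MS[(device, "FP32")]
    print(f"{device}: FP64 -> FP32 speedup {gain:.2f}x")

# FP64: CPU wins by 2.62x
# FP32: GPU wins by 1.13x
# CPU: FP64 -> FP32 speedup 1.23x
# GPU: FP64 -> FP32 speedup 3.64x
```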
Is FP32 safe?
Lower precision is only a win if it’s still correct. 0.6.0 ships tests/test_fp32_precision.py, which validates both paths against scipy.signal.sosfilt:
- The FP64 path matches scipy to within double precision.
- The FP32 path matches the reference within a documented float32 bound (max-abs and RMS-relative error), swept across Butterworth and Chebyshev I at orders 2/4/8/16, on both CPU and CUDA.
The harness records the per-design error so the FP32-safe-vs-needs-FP64 boundary is tracked over time. For the well-conditioned audio designs TorchFX ships, FP32 tracks the FP64 reference to float32 precision with no surprises. If you have a pathological high-order, poles-near-the-unit-circle design, pass float64 and you get the precise path — your choice, per call.
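The two error measures can be sketched as follows. This is not tests/test_fp32_precision.py. To stay dependency-free, it uses a float64 run of the same recurrence as the reference instead of scipy.signal.sosfilt. It also uses a single arbitrary biquad instead of the Butterworth and Chebyshev I sweep, and its function names are hypothetical. No pass/fail threshold is included, since the documented float32 bound isn't quoted in this post.

```python
import numpy as np


def biquad_df2t(b, a, x, dtype):
    """Direct-form II transposed biquad with every intermediate in `dtype` (assumes a[0] == 1)."""
    b0, b1, b2 = np.asarray(b, dtype=dtype)
    _, a1, a2 = np.asarray(a, dtype=dtype)
    x = np.asarray(x, dtype=dtype)
    y = np.empty_like(x)
    s1 = s2 = dtype(0)
    for n in range(len(x)):
        yn = b0 * x[n] + s1
        s1 = b1 * x[n] - a1 * yn + s2
        s2 = b2 * x[n] - a2 * yn
        y[n] = yn
    return y


def error_metrics(test, ref):
    """Return (max-abs error, RMS error relative to the RMS of the reference)."""
    diff = np.asarray(test, dtype=np.float64) - np.asarray(ref, dtype=np.float64)
    max_abs = float(np.max(np.abs(diff)))
    rms_rel = float(np.sqrt(np.mean(diff**2)) / np.sqrt(np.mean(np.square(ref))))
    return max_abs, rms_rel


rng = np.random.default_rng(0)
x = rng.standard_normal(10_000)
b, a = [0.2, 0.4, 0.2], [1.0, -0.5, 0.3]  # arbitrary, well-conditioned biquad
ref = biquad_df2t(b, a, x, np.float64)
max_abs, rms_rel = error_metrics(biquad_df2t(b, a, x, np.float32), ref)
print(f"max-abs {max_abs:.1e}, RMS-relative {rms_rel:.1e}")
```

Recording both numbers per design is what makes the FP32-safe boundary trackable. Max-abs catches a single bad excursion, and the RMS-relative figure normalizes for the signal's level.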
Try it
```python
import torch
from torchfx import Wave
from torchfx.filter import LoButterworth

# float32 in -> FP32 GPU kernels (fast)
wave = Wave(torch.randn(8, 48000 * 60, dtype=torch.float32), fs=48000).to("cuda")
out = wave | LoButterworth(4000, order=8)

# float64 in -> FP64 GPU kernels (precise), same code
wave64 = Wave(torch.randn(8, 48000 * 60, dtype=torch.float64), fs=48000).to("cuda")
out64 = wave64 | LoButterworth(4000, order=8)
```

To reproduce the benchmark table:

```shell
python benchmarks/bench_fp32_speedup.py
```
Back to the 0.6.0 release notes.