FX - Audio Effect Base Class#
The FX class is the abstract base class for all audio effects and filters in TorchFX. It defines the interface that all effects must implement and provides the foundation for building custom audio processors.
What is FX?#
FX inherits from torch.nn.Module, making all TorchFX effects compatible with PyTorch’s neural network ecosystem. This design choice enables:
Gradient Computation: Effects can be used in differentiable audio processing pipelines
Parameter Management: Automatic device handling and parameter tracking
Modularity: Effects can be combined with neural networks
Composability: Easy integration with torch.nn.Sequential and other PyTorch containers (see the sketch below)
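Because every effect is an nn.Module, effects drop straight into standard PyTorch containers. A minimal sketch using the built-in Gain and Normalize effects shown later on this page (the random tensor stands in for real audio):

import torch
import torchfx as fx

# Effects compose like any other nn.Module
chain = torch.nn.Sequential(
    fx.effect.Gain(gain=2.0, gain_type="amplitude"),
    fx.effect.Normalize(peak=0.9),
)

x = torch.randn(2, 44100)  # (channels, samples) placeholder audio
y = chain(x)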
classDiagram
    class Module {
        <<PyTorch>>
        +forward(x)*
        +to(device)
        +parameters()
        +train() / eval()
    }
    class FX {
        <<abstract>>
        +forward(x) Tensor*
        +__init__()*
    }
    class Filter {
        <<abstract>>
        +int fs
        +compute_coefficients()*
    }
    class Effect {
        +forward(x) Tensor
    }
    class IIRFilter {
        +Tensor b
        +Tensor a
        +compute_coefficients()
        +forward(x) Tensor
    }
    class Delay {
        +int delay_samples
        +float feedback
        +forward(x) Tensor
    }
    class Reverb {
        +int delay
        +float decay
        +forward(x) Tensor
    }
    Module <|-- FX
    FX <|-- Filter
    FX <|-- Effect
    Filter <|-- IIRFilter
    Effect <|-- Delay
    Effect <|-- Reverb
    note for Module "PyTorch base class\nprovides device management,\nparameter tracking"
    note for FX "TorchFX base class\ndefines audio effect interface"
The FX Interface#
Every effect must implement two key methods:
1. Constructor (__init__)#
Initializes the effect with its parameters:
from torchfx.effect import FX
import torch

class MyEffect(FX):
    def __init__(self, param1: float, param2: int):
        super().__init__()  # Always call parent constructor
        self.param1 = param1
        self.param2 = param2
        # Initialize any learnable parameters
        self.gain = torch.nn.Parameter(torch.tensor([1.0]))
Key points:
Always call super().__init__() first
Store effect parameters as attributes
Use torch.nn.Parameter for learnable parameters (if needed)
2. Forward Method (forward)#
Processes the audio signal:
from torchfx.effect import FX
from torch import Tensor
import torch

class MyEffect(FX):
    def __init__(self, gain: float):
        super().__init__()
        self.gain = gain

    @torch.no_grad()  # Disable gradients for efficiency (optional)
    def forward(self, x: Tensor) -> Tensor:
        """Apply effect to audio tensor.

        Parameters
        ----------
        x : Tensor
            Input audio of shape (..., time) or (channels, time)

        Returns
        -------
        Tensor
            Processed audio with the same shape as input
        """
        return x * self.gain
Key points:
Input shape: (channels, samples) or (..., samples)
Output shape: Should match the input shape (unless explicitly extending the signal)
Use @torch.no_grad() for efficiency if gradients aren't needed
Handle multi-channel audio appropriately (see the shape check below)
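As a quick check of the shape contract, the gain effect defined above preserves whatever layout it receives (random tensors stand in for real audio):

effect = MyEffect(gain=0.5)

stereo = torch.randn(2, 44100)    # (channels, samples)
batch = torch.randn(4, 2, 44100)  # any (..., samples) layout also works

assert effect(stereo).shape == stereo.shape
assert effect(batch).shape == batch.shape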
Built-in Effects#
TorchFX provides several built-in effects demonstrating different patterns:
Simple Effects#
Effects with straightforward signal processing:
import torchfx as fx
# Gain adjustment
wave = fx.Wave.from_file("audio.wav")
louder = wave | fx.effect.Gain(gain=2.0, gain_type="amplitude")
# Normalization
normalized = wave | fx.effect.Normalize(peak=0.9)
Time-Based Effects#
Effects that use delay and feedback:
import torchfx as fx
wave = fx.Wave.from_file("audio.wav")
# Simple reverb
reverb = fx.effect.Reverb(delay=4410, decay=0.5, mix=0.3)
wet = wave | reverb
# BPM-synced delay
delay = fx.effect.Delay(bpm=120, delay_time="1/8", feedback=0.4, mix=0.3)
delayed = wave | delay
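For intuition, the sample count behind delay_time="1/8" works out as follows, assuming the note value is interpreted relative to a quarter-note beat (an assumption about Delay's internals):

bpm, fs = 120, 44100
beat_s = 60 / bpm                     # quarter note = 0.5 s at 120 BPM
eighth_s = beat_s / 2                 # eighth note = 0.25 s
delay_samples = round(eighth_s * fs)  # 11025 samples at 44.1 kHz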
Strategy Pattern Effects#
Effects using the Strategy Pattern for flexible behavior:
import torchfx as fx
from torchfx.effect import (
    Normalize,
    RMSNormalizationStrategy,
    PercentileNormalizationStrategy,
)
wave = fx.Wave.from_file("audio.wav")
# Peak normalization (default)
norm1 = wave | Normalize(peak=1.0)
# RMS normalization
norm2 = wave | Normalize(peak=0.5, strategy=RMSNormalizationStrategy())
# Percentile normalization
norm3 = wave | Normalize(peak=0.9, strategy=PercentileNormalizationStrategy(percentile=99))
See also
Gamma et al. [1994] - Design Patterns book covering the Strategy Pattern
Filters vs Effects#
TorchFX distinguishes between filters and effects:
Filters#
Inherit from AbstractFilter (which inherits from FX):
Frequency-domain processing: IIR, FIR filters
Require sample rate: the fs attribute is mandatory
Compute coefficients: Must implement the compute_coefficients() method (a sketch follows the example below)
Parallel combination: Support the + operator for parallel filter banks
from torchfx.filter import iir
import torchfx as fx
wave = fx.Wave.from_file("audio.wav")
# IIR filters with automatic fs configuration
lowpass = wave | iir.LoButterworth(cutoff=1000, order=4)
highpass = wave | iir.HiButterworth(cutoff=200, order=2)
# Parallel combination (bandpass filter)
bandpass = wave | (iir.HiButterworth(200) + iir.LoButterworth(1000))
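To make the filter contract concrete, here is a minimal one-pole lowpass sketch. The import path for AbstractFilter and its constructor signature are assumptions (check the torchfx.filter API); MyOnePole is a hypothetical name, and torchaudio.functional.lfilter performs the actual recursion.

import math
import torch
import torchaudio
from torch import Tensor
from torchfx.filter import AbstractFilter  # import path is an assumption

class MyOnePole(AbstractFilter):
    """One-pole lowpass: y[n] = (1 - a) * x[n] + a * y[n-1]."""

    def __init__(self, cutoff: float, fs: int | None = None):
        super().__init__()
        self.cutoff = cutoff
        self.fs = fs  # may be set later by the pipeline
        self.b = None
        self.a = None

    def compute_coefficients(self) -> None:
        pole = math.exp(-2 * math.pi * self.cutoff / self.fs)
        # lfilter expects a_coeffs and b_coeffs of equal length
        self.b = torch.tensor([1 - pole, 0.0])
        self.a = torch.tensor([1.0, -pole])

    @torch.no_grad()
    def forward(self, x: Tensor) -> Tensor:
        if self.b is None:
            self.compute_coefficients()  # fs must be set by now
        return torchaudio.functional.lfilter(x, self.a, self.b)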
Effects#
Inherit directly from FX:
Time-domain processing: Delay, reverb, dynamics, etc.
Optional sample rate: May or may not need fs
Direct implementation: No coefficient computation required
Flexible parameters: Can use any processing strategy
import torchfx as fx
wave = fx.Wave.from_file("audio.wav")
# Effects don't require fs (unless BPM-synced)
gained = wave | fx.effect.Gain(2.0)
normalized = wave | fx.effect.Normalize(peak=0.8)
# BPM-synced effects auto-configure fs from Wave
delayed = wave | fx.effect.Delay(bpm=120, delay_time="1/4")
Creating Custom Effects#
Basic Custom Effect#
from torchfx.effect import FX
from torch import Tensor
import torch

class SimpleDistortion(FX):
    """Apply soft clipping distortion."""

    def __init__(self, drive: float = 2.0, mix: float = 0.5):
        """
        Parameters
        ----------
        drive : float
            Amount of distortion (>= 1.0). Higher values = more distortion.
        mix : float
            Wet/dry mix (0 = dry, 1 = wet).
        """
        super().__init__()
        assert drive >= 1.0, "Drive must be >= 1.0"
        assert 0 <= mix <= 1, "Mix must be in [0, 1]"
        self.drive = drive
        self.mix = mix

    @torch.no_grad()
    def forward(self, x: Tensor) -> Tensor:
        """Apply soft clipping distortion.

        Uses tanh for smooth saturation:
            y[n] = tanh(drive * x[n])
        """
        # Apply distortion
        distorted = torch.tanh(self.drive * x)
        # Wet/dry mix
        output = (1 - self.mix) * x + self.mix * distorted
        return output
Usage:
import torchfx as fx
wave = fx.Wave.from_file("guitar.wav")
distorted = wave | SimpleDistortion(drive=3.0, mix=0.7)
distorted.save("guitar_distorted.wav")
Multi-Channel Effect#
Handle stereo and multi-channel audio correctly:
from torchfx.effect import FX
from torch import Tensor
import torch

class StereoWidener(FX):
    """Widen stereo image using Mid/Side processing."""

    def __init__(self, width: float = 1.5):
        """
        Parameters
        ----------
        width : float
            Stereo width multiplier (1.0 = no change, >1.0 = wider, <1.0 = narrower).
        """
        super().__init__()
        assert width >= 0, "Width must be non-negative"
        self.width = width

    @torch.no_grad()
    def forward(self, x: Tensor) -> Tensor:
        """Apply stereo widening.

        Converts to Mid/Side, scales Side, converts back to L/R.
        """
        # Only works on stereo audio
        if x.shape[0] != 2:
            return x  # Return unchanged for non-stereo

        left = x[0]
        right = x[1]

        # Convert to Mid/Side
        mid = (left + right) / 2
        side = (left - right) / 2

        # Widen by scaling the Side component
        side = side * self.width

        # Convert back to L/R
        new_left = mid + side
        new_right = mid - side

        return torch.stack([new_left, new_right])
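Usage, mirroring the pattern shown for SimpleDistortion above (the file names are illustrative):

import torchfx as fx

wave = fx.Wave.from_file("mix.wav")
wider = wave | StereoWidener(width=1.8)
wider.save("mix_wide.wav")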
Effect with Strategy Pattern#
Use the Strategy Pattern for flexible behavior:
from torchfx.effect import FX
from torch import Tensor
import torch
import abc

class CompressionStrategy(abc.ABC):
    """Abstract base for compression algorithms."""

    @abc.abstractmethod
    def compress(self, x: Tensor, threshold: float, ratio: float) -> Tensor:
        pass

class HardKneeCompression(CompressionStrategy):
    """Hard-knee compression with sharp threshold."""

    def compress(self, x: Tensor, threshold: float, ratio: float) -> Tensor:
        abs_x = torch.abs(x)
        mask = abs_x > threshold
        # Compress values above the threshold
        compressed = torch.where(
            mask,
            threshold + (abs_x - threshold) / ratio,
            abs_x,
        )
        # Restore sign
        return torch.sign(x) * compressed

class SoftKneeCompression(CompressionStrategy):
    """Soft-knee compression with gradual transition."""

    def compress(self, x: Tensor, threshold: float, ratio: float) -> Tensor:
        # Implementation of soft-knee compression
        # (simplified for brevity)
        return x  # Placeholder

class Compressor(FX):
    """Dynamic range compressor with configurable strategy."""

    def __init__(
        self,
        threshold: float = 0.5,
        ratio: float = 4.0,
        strategy: CompressionStrategy | None = None,
    ):
        super().__init__()
        self.threshold = threshold
        self.ratio = ratio
        self.strategy = strategy or HardKneeCompression()

    @torch.no_grad()
    def forward(self, x: Tensor) -> Tensor:
        return self.strategy.compress(x, self.threshold, self.ratio)
Usage:
import torchfx as fx
wave = fx.Wave.from_file("vocals.wav")
# Hard knee compression (default)
compressed1 = wave | Compressor(threshold=0.5, ratio=4.0)
# Soft knee compression
compressed2 = wave | Compressor(
    threshold=0.5,
    ratio=4.0,
    strategy=SoftKneeCompression(),
)
Sample Rate Handling#
Many effects need the sample rate to function correctly. TorchFX provides automatic configuration:
Automatic Configuration#
When using the pipeline operator with Wave, the sample rate is automatically set:
import torchfx as fx
wave = fx.Wave.from_file("audio.wav") # fs = 44100
# Effect's fs is automatically set to 44100
delayed = wave | fx.effect.Delay(bpm=120, delay_time="1/8")
Manual Configuration#
For standalone use or custom effects:
from torchfx.effect import FX
from torch import Tensor
import torch

class MyTimedEffect(FX):
    """Effect that needs sample rate."""

    def __init__(self, duration_ms: float, fs: int | None = None):
        super().__init__()
        self.fs = fs
        self.duration_ms = duration_ms
        self._duration_samples = None

    @torch.no_grad()
    def forward(self, x: Tensor) -> Tensor:
        # Lazy calculation when fs becomes available
        if self._duration_samples is None:
            assert self.fs is not None, "Sample rate (fs) must be set"
            self._duration_samples = int(self.duration_ms * self.fs / 1000)
        # Use self._duration_samples for processing
        return x  # Placeholder
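Standalone usage then looks like this (values are illustrative):

effect = MyTimedEffect(duration_ms=50.0, fs=44100)
y = effect(torch.randn(2, 44100))  # _duration_samples resolves to 2205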
Gradient Support#
While most audio effects run with @torch.no_grad() for efficiency, you can enable gradients for differentiable audio processing:
from torchfx.effect import FX
from torch import Tensor
import torch

class LearnableGain(FX):
    """Gain effect with learnable parameter."""

    def __init__(self, initial_gain: float = 1.0):
        super().__init__()
        # Learnable parameter
        self.gain = torch.nn.Parameter(torch.tensor([initial_gain]))

    def forward(self, x: Tensor) -> Tensor:
        """Forward with gradient support."""
        return x * self.gain

# Usage in a differentiable pipeline (toy data for illustration)
effect = LearnableGain(initial_gain=0.5)
optimizer = torch.optim.Adam(effect.parameters(), lr=0.01)

input_audio = torch.randn(2, 44100)
target_audio = input_audio * 2.0  # target: learn to double the gain

# Training loop
for epoch in range(100):
    output = effect(input_audio)
    loss = torch.nn.functional.mse_loss(output, target_audio)  # any differentiable loss works
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
See also
Engel et al. [2020] - Differentiable Digital Signal Processing
Best Practices#
Parameter Validation#
Always validate parameters in the constructor:
class MyEffect(FX):
    def __init__(self, param: float):
        super().__init__()
        assert param > 0, "Parameter must be positive"
        assert param <= 1.0, "Parameter must be <= 1.0"
        self.param = param
Handle Edge Cases#
Consider boundary conditions:
@torch.no_grad()
def forward(self, x: Tensor) -> Tensor:
    # Check for empty input
    if x.numel() == 0:
        return x
    # Check for very short signals
    if x.shape[-1] < self.required_length:
        return x  # Or pad, or raise an error
    # Process normally (self.process stands in for the effect's real computation)
    processed = self.process(x)
    return processed
Preserve Tensor Properties#
Maintain dtype and device:
@torch.no_grad()
def forward(self, x: Tensor) -> Tensor:
    # Create new tensors on the same device with the same dtype
    buffer = torch.zeros_like(x)
    # Or explicitly specify
    buffer = torch.zeros(
        x.shape,
        dtype=x.dtype,
        device=x.device,
    )
    return buffer
Document Mathematical Formulation#
Include formulas in docstrings:
class MyEffect(FX):
    r"""Apply custom effect.

    The effect is computed as:

    .. math::
        y[n] = \alpha x[n] + (1 - \alpha) x[n-1]

    where:

    - x[n] is the input signal
    - y[n] is the output signal
    - \alpha is the blend factor
    """
Common Patterns#
Wet/Dry Mix#
Almost all effects benefit from a mix parameter:
@torch.no_grad()
def forward(self, x: Tensor) -> Tensor:
    # Process the signal
    processed = self.process(x)
    # Mix with the dry signal
    output = (1 - self.mix) * x + self.mix * processed
    return output
Extend Signal Length#
For delay-based effects:
@torch.no_grad()
def forward(self, x: Tensor) -> Tensor:
    original_length = x.shape[-1]
    extended_length = original_length + self.delay_samples
    # Create an extended buffer
    output = torch.zeros(
        *x.shape[:-1], extended_length,
        dtype=x.dtype, device=x.device,
    )
    # Copy the original signal
    output[..., :original_length] = x
    # Add the delayed signal
    output[..., self.delay_samples:] += x * self.feedback
    return output
Per-Channel Processing#
Use torch.nn.ModuleList for per-channel effects:
class PerChannelEffect(FX):
    def __init__(self, num_channels: int):
        super().__init__()
        # ChannelProcessor is a placeholder for any FX subclass
        self.processors = torch.nn.ModuleList([
            ChannelProcessor() for _ in range(num_channels)
        ])

    @torch.no_grad()
    def forward(self, x: Tensor) -> Tensor:
        # x shape: (channels, samples)
        outputs = []
        for ch in range(x.shape[0]):
            processed = self.processors[ch](x[ch:ch+1])
            outputs.append(processed)
        return torch.cat(outputs, dim=0)
External Resources#
PyTorch nn.Module Documentation - Understanding the base class
Abstract Base Classes in Python - Using ABC for interfaces
DAFX Book - Digital Audio Effects theory and algorithms
References#
Jesse Engel, Lamtharn (Hanoi) Hantrakul, Chenjie Gu, and Adam Roberts. DDSP: Differentiable Digital Signal Processing. In International Conference on Learning Representations. 2020. URL: https://openreview.net/forum?id=B1x1ma4tDr.
Erich Gamma, Richard Helm, Ralph Johnson, and John Vlissides. Design Patterns: Elements of Reusable Object-Oriented Software. Addison-Wesley Professional, 1994. ISBN 978-0201633610. Classic reference for software design patterns including Strategy pattern.