Creating Custom Filters#
Learn how to build your own digital filters by extending TorchFX’s AbstractFilter base class. This tutorial covers filter design patterns, coefficient computation, and integration with the TorchFX pipeline system.
Overview#
Custom filters in TorchFX inherit from AbstractFilter, which provides the foundation for:
- Pipeline integration: automatic support for the `|` operator
- Parallel combination: built-in `+` operator for filter banks
- PyTorch compatibility: inherits from `torch.nn.Module`
- Device management: automatic GPU/CPU handling
- Sample rate configuration: automatic `fs` propagation from `Wave`
classDiagram
class Module {
<<PyTorch>>
+forward(x)*
+to(device)
+parameters()
}
class FX {
<<abstract>>
+forward(x) Tensor*
}
class AbstractFilter {
<<abstract>>
+compute_coefficients()*
+__add__(other) ParallelFilterCombination
+_has_computed_coeff bool
}
class IIR {
<<abstract>>
+fs int|None
+a Tensor|None
+b Tensor|None
+move_coeff(device, dtype)
+forward(x) Tensor
}
class CustomFilter {
+__init__(...)
+compute_coefficients()
+forward(x) Tensor
}
Module <|-- FX
FX <|-- AbstractFilter
AbstractFilter <|-- IIR
AbstractFilter <|-- CustomFilter
note for Module "PyTorch base class"
note for AbstractFilter "Adds filter-specific features"
note for CustomFilter "Your implementation"
Required Methods#
Every custom filter must implement these three methods:
1. __init__ - Initialization#
Initialize filter parameters and call parent constructor:
from torchfx.filter import AbstractFilter
class CustomFilter(AbstractFilter):
    """Skeleton showing the minimal constructor of a custom filter."""

    def __init__(self, param1: float, param2: int, fs: int | None = None):
        # Must come first: initializes the AbstractFilter / torch.nn.Module bases.
        super().__init__()

        # Design parameters supplied by the user.
        self.param1, self.param2 = param1, param2

        # Sample rate; None means "to be filled in later" (e.g. from a Wave).
        self.fs = fs

        # Coefficients are designed lazily, so they start out empty.
        self.a = self.b = None
Key points:
- Always call `super().__init__()` first
- Accept an `fs` parameter (it can be `None`)
- Initialize the coefficient attributes to `None`
- Validate parameters if needed
2. compute_coefficients - Filter Design#
Compute filter coefficients based on parameters:
def compute_coefficients(self) -> None:
    """Design a Butterworth lowpass and store its ``b``/``a`` coefficients."""
    # Cutoffs are normalized against the sample rate, so it is mandatory.
    if self.fs is None:
        raise ValueError("Sample rate must be set before computing coefficients")

    from scipy.signal import butter

    # SciPy expects the cutoff as a fraction of the Nyquist frequency (fs / 2).
    nyquist_hz = 0.5 * self.fs
    wn = self.cutoff / nyquist_hz

    self.b, self.a = butter(self.order, wn, btype='low')
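Key points:
- Check that `fs` is not `None`
- Use SciPy or custom formulas
- Set the `self.a` and `self.b` attributes
- The coefficients are computed only once: `forward` calls `compute_coefficients()` only while `self.a` or `self.b` is still `None`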
3. forward - Apply Filter#
Process audio through the filter:
import torch
from torch import Tensor
from torchaudio.functional import lfilter
@torch.no_grad()
def forward(self, x: Tensor) -> Tensor:
    """Apply the filter to ``x`` and return the filtered tensor.

    The coefficients are designed on the first call. They are then kept on
    the same device and dtype as the input, so the filter keeps working
    after the audio is moved between CPU and GPU.
    """
    # Lazy design: compute_coefficients() needs self.fs, which may only be
    # known once the filter has been attached to a Wave.
    if self.a is None or self.b is None:
        self.compute_coefficients()

    # Convert on every call rather than only the first time. torch.as_tensor
    # returns the tensor itself when device and dtype already match, so this
    # is free in the common case. It also avoids a device-mismatch error when
    # x arrives on a different device than in a previous call.
    self.a = torch.as_tensor(self.a, device=x.device, dtype=x.dtype)
    self.b = torch.as_tensor(self.b, device=x.device, dtype=x.dtype)

    # torchaudio's lfilter takes the denominator (a) before the numerator (b).
    return lfilter(x, self.a, self.b)
Key points:
- Use `@torch.no_grad()` for efficiency
- Compute the coefficients lazily (on the first call)
- Handle device/dtype conversion
- Use `torchaudio.functional.lfilter()` for IIR filters
Complete Example: Custom Bandpass Filter#
Here’s a complete, working example of a custom bandpass filter:
import numpy as np
import torch
from torch import Tensor
from scipy.signal import butter
from torchaudio.functional import lfilter
from torchfx.filter import AbstractFilter
class CustomBandpass(AbstractFilter):
    """Custom bandpass filter using Butterworth design.

    Parameters
    ----------
    low_cutoff : float
        Lower cutoff frequency in Hz
    high_cutoff : float
        Upper cutoff frequency in Hz (must be greater than ``low_cutoff``)
    order : int, optional
        Butterworth prototype order (default: 4). SciPy's bandpass design of
        order N yields a filter of order 2N.
    fs : int, optional
        Sample rate in Hz (can be set later)

    Raises
    ------
    ValueError
        In ``__init__`` if a cutoff or the order is invalid; in
        ``compute_coefficients`` if ``fs`` is unset or a cutoff is not
        strictly between 0 and Nyquist.

    Examples
    --------
    >>> import torchfx as fx
    >>> wave = fx.Wave.from_file("audio.wav")
    >>> bandpass = CustomBandpass(low_cutoff=200, high_cutoff=2000, order=4)
    >>> filtered = wave | bandpass
    """

    def __init__(
        self,
        low_cutoff: float,
        high_cutoff: float,
        order: int = 4,
        fs: int | None = None,
    ):
        super().__init__()

        # Validate with explicit exceptions rather than `assert`, which is
        # stripped when Python runs with -O. The `not (...)` form also
        # rejects NaN, just as the original assertions did.
        if not (low_cutoff > 0):
            raise ValueError("Low cutoff must be positive")
        if not (high_cutoff > low_cutoff):
            raise ValueError("High cutoff must be > low cutoff")
        if not (order > 0):
            raise ValueError("Order must be positive")

        # Store parameters
        self.low_cutoff = low_cutoff
        self.high_cutoff = high_cutoff
        self.order = order
        self.fs = fs

        # Coefficients are designed lazily in forward()
        self.a = None
        self.b = None

    def compute_coefficients(self) -> None:
        """Compute Butterworth bandpass coefficients."""
        if self.fs is None:
            raise ValueError("Sample rate must be set before computing coefficients")

        # Normalize frequencies to the Nyquist frequency
        nyquist = 0.5 * self.fs
        low_norm = self.low_cutoff / nyquist
        high_norm = self.high_cutoff / nyquist

        # Both edges must lie strictly between 0 and Nyquist
        if not (0 < low_norm < 1 and 0 < high_norm < 1):
            raise ValueError(
                f"Cutoff frequencies must be between 0 and Nyquist ({nyquist} Hz)"
            )

        # Design Butterworth bandpass filter
        self.b, self.a = butter(
            self.order,
            [low_norm, high_norm],
            btype='bandpass'
        )

    @torch.no_grad()
    def forward(self, x: Tensor) -> Tensor:
        """Apply bandpass filter to input tensor.

        Parameters
        ----------
        x : Tensor
            Input audio tensor of shape (channels, samples)

        Returns
        -------
        Tensor
            Filtered audio tensor
        """
        # Lazy coefficient computation (fs may only be known at this point)
        if self.a is None or self.b is None:
            self.compute_coefficients()

        # Match the input's device/dtype on every call. torch.as_tensor is a
        # no-op when they already match, and it keeps working if the audio
        # moves to a different device after the first call.
        self.a = torch.as_tensor(self.a, device=x.device, dtype=x.dtype)
        self.b = torch.as_tensor(self.b, device=x.device, dtype=x.dtype)

        # Apply filter (torchaudio takes a before b)
        return lfilter(x, self.a, self.b)
Using the Custom Filter#
import torchfx as fx

# Load the input audio
wave = fx.Wave.from_file("audio.wav")

# The filter is created without fs; piping it from `wave` supplies the rate.
bandpass = CustomBandpass(low_cutoff=200, high_cutoff=2000, order=4)
filtered = wave | bandpass

# Write the result to disk
filtered.save("bandpass_filtered.wav")

# The same filter object can also sit inside a longer chain.
from torchfx.filter import iir

processed = (
    wave
    | iir.HiButterworth(cutoff=80, order=2)  # Remove rumble
    | bandpass                              # Bandpass 200-2000 Hz
    | fx.effect.Normalize()                 # Normalize
)
Filter Design Patterns#
Pattern 1: SciPy-Based Filters#
Use SciPy’s signal processing functions for coefficient design:
from scipy.signal import butter, cheby1, cheby2, ellip, iirnotch, iirpeak
class CustomLowpass(AbstractFilter):
    """Lowpass filter designed with a SciPy IIR design function.

    Assumes an ``__init__`` (omitted here) that stores ``cutoff``, ``order``
    and ``fs``, plus ``ripple``/``atten`` if the Chebyshev or elliptic
    alternatives below are used.
    """

    def compute_coefficients(self) -> None:
        # Without this guard, an unset fs fails with an opaque TypeError
        # on `0.5 * None`.
        if self.fs is None:
            raise ValueError("Sample rate must be set before computing coefficients")

        nyquist = 0.5 * self.fs
        norm_freq = self.cutoff / nyquist

        # Choose a design function (the alternatives are shown for comparison)
        self.b, self.a = butter(self.order, norm_freq, btype='low')
        # self.b, self.a = cheby1(self.order, self.ripple, norm_freq, btype='low')
        # self.b, self.a = ellip(self.order, self.ripple, self.atten, norm_freq, btype='low')
Available SciPy filters:
- `scipy.signal.butter()` - Butterworth (maximally flat passband)
- `scipy.signal.cheby1()` - Chebyshev Type I (ripple in passband)
- `scipy.signal.cheby2()` - Chebyshev Type II (ripple in stopband)
- `scipy.signal.ellip()` - Elliptic (ripple in both bands)
- `scipy.signal.iirpeak()` - Peak (resonant bandpass) filter
- `scipy.signal.iirnotch()` - Notch filter
See also
SciPy Signal Processing - Full SciPy signal module documentation
Pattern 2: Biquad Formulas#
Use direct biquad formulas for second-order sections:
class CustomPeakingEQ(AbstractFilter):
    """Peaking EQ built from closed-form biquad formulas.

    Boosts (``gain_db > 0``) or cuts (``gain_db < 0``) a band centred on
    ``freq`` Hz; ``q`` sets how narrow that band is.
    """

    def __init__(self, freq: float, gain_db: float, q: float = 1.0, fs: int | None = None):
        super().__init__()
        self.freq = freq
        self.gain_db = gain_db
        self.q = q
        self.fs = fs
        self.a = None
        self.b = None

    @property
    def _omega(self) -> float:
        """Centre frequency in radians per sample (requires ``fs``)."""
        return 2 * np.pi * self.freq / self.fs

    @property
    def _alpha(self) -> float:
        """Bandwidth term: sin(omega) / (2 * Q)."""
        return np.sin(self._omega) / (2 * self.q)

    def compute_coefficients(self) -> None:
        if self.fs is None:
            raise ValueError("Sample rate must be set")

        # Amplitude A = 10 ** (dB / 40), i.e. the square root of the linear gain.
        amp = 10 ** (self.gain_db / 40)
        w, alpha = self._omega, self._alpha
        k = -2 * np.cos(w)  # b1 and a1 are identical for a peaking EQ

        b0, b1, b2 = 1 + alpha * amp, k, 1 - alpha * amp
        a0, a1, a2 = 1 + alpha / amp, k, 1 - alpha / amp

        # Scale both polynomials so that the leading denominator term is 1.
        self.b = [b0 / a0, b1 / a0, b2 / a0]
        self.a = [1.0, a1 / a0, a2 / a0]
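Biquad transfer function:
\[H(z) = \frac{b_0 + b_1 z^{-1} + b_2 z^{-2}}{a_0 + a_1 z^{-1} + a_2 z^{-2}}\]
Normalized form (dividing by \(a_0\)):
\[H(z) = \frac{\frac{b_0}{a_0} + \frac{b_1}{a_0} z^{-1} + \frac{b_2}{a_0} z^{-2}}{1 + \frac{a_1}{a_0} z^{-1} + \frac{a_2}{a_0} z^{-2}}\]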
See also
Digital Biquad Filter on Wikipedia - Biquad filter theory
Pattern 3: Cascaded Filters#
Chain multiple filter stages by convolving coefficients:
class LinkwitzRiley(AbstractFilter):
    """Linkwitz-Riley crossover filter (two cascaded Butterworth filters).

    Parameters
    ----------
    cutoff : float
        Crossover frequency in Hz.
    order : int, optional
        Total Linkwitz-Riley order; must be a positive even number
        (default: 4, i.e. two cascaded 2nd-order Butterworth filters).
    btype : str, optional
        'low' or 'high' (SciPy's aliases such as 'lowpass'/'hp' are also
        accepted). Default: 'low'.
    fs : int, optional
        Sample rate in Hz.

    Raises
    ------
    ValueError
        If ``order`` is not a positive even number, ``btype`` is not a
        lowpass/highpass type, or ``fs`` is unset at design time.
    """

    # Single-cutoff btype spellings accepted by scipy.signal.butter.
    _VALID_BTYPES = ('l', 'low', 'lowpass', 'lp', 'h', 'high', 'highpass', 'hp')

    def __init__(self, cutoff: float, order: int = 4, btype: str = 'low', fs: int | None = None):
        super().__init__()
        # Explicit exceptions instead of `assert`, which `python -O` strips.
        if order <= 0 or order % 2 != 0:
            raise ValueError("Linkwitz-Riley order must be a positive even integer")
        # A crossover takes one cutoff; band types would need two and fail
        # later inside SciPy, so reject them up front.
        if btype not in self._VALID_BTYPES:
            raise ValueError(f"btype must be a lowpass or highpass type, got {btype!r}")
        self.cutoff = cutoff
        self.order = order
        self.btype = btype
        self.fs = fs
        self.a = None
        self.b = None

    def compute_coefficients(self) -> None:
        if self.fs is None:
            raise ValueError("Sample rate must be set")

        # Linkwitz-Riley of order N is a Butterworth of order N/2, applied twice
        butter_order = self.order // 2

        # Get base Butterworth coefficients
        b_butter, a_butter = butter(
            butter_order,
            self.cutoff / (0.5 * self.fs),
            btype=self.btype
        )

        # Cascade by convolving coefficients (multiplies the transfer functions)
        self.b = np.convolve(b_butter, b_butter)
        self.a = np.convolve(a_butter, a_butter)
Cascading filters: Convolving the numerator and denominator coefficient arrays multiplies the two transfer functions. This is equivalent to running the filters in series.
Pattern 4: Custom Coefficient Computation#
Implement your own filter design algorithms:
class CustomResonator(AbstractFilter):
    """Resonant 2nd-order lowpass whose Q is set by a 0..1 "resonance" knob."""

    def __init__(self, freq: float, resonance: float = 0.5, fs: int | None = None):
        super().__init__()
        self.freq = freq
        # Clip to [0, 0.99]. Q = 1 / (1 - resonance) would divide by zero at
        # resonance == 1; the 0.99 ceiling caps Q at about 100.
        self.resonance = np.clip(resonance, 0.0, 0.99)
        self.fs = fs
        self.a = None
        self.b = None

    def compute_coefficients(self) -> None:
        if self.fs is None:
            raise ValueError("Sample rate must be set")

        w = 2.0 * np.pi * self.freq / self.fs  # radians per sample
        q = 1.0 / (1.0 - self.resonance)
        alpha = np.sin(w) / (2.0 * q)
        cos_w = np.cos(w)

        # Lowpass biquad: the outer numerator taps are equal.
        half = (1.0 - cos_w) / 2.0
        num = [half, 1.0 - cos_w, half]
        den = [1.0 + alpha, -2.0 * cos_w, 1.0 - alpha]

        # Divide through by a0 so the leading denominator term is 1.
        a0 = den[0]
        self.b = [coef / a0 for coef in num]
        self.a = [1.0] + [coef / a0 for coef in den[1:]]
Pipeline Integration#
Custom filters automatically work with TorchFX’s pipeline system:
Series Combination (Pipe Operator)#
import torchfx as fx

wave = fx.Wave.from_file("audio.wav")

# Two custom filters applied one after the other (series).
custom_bp = CustomBandpass(200, 2000, order=4)
custom_res = CustomResonator(1000, resonance=0.7)

# `|` is left-associative: this is (wave | custom_bp) | custom_res.
processed = wave | custom_bp | custom_res
sequenceDiagram
participant Wave
participant CustomBandpass
participant CustomResonator
Wave->>CustomBandpass: wave | custom_bp
Note over CustomBandpass: fs auto-configured
CustomBandpass->>CustomBandpass: compute_coefficients()
CustomBandpass->>CustomBandpass: forward(wave.ys)
CustomBandpass->>Wave: Return new Wave
Wave->>CustomResonator: result | custom_res
Note over CustomResonator: fs auto-configured
CustomResonator->>CustomResonator: compute_coefficients()
CustomResonator->>CustomResonator: forward(result.ys)
CustomResonator->>Wave: Return final Wave
Parallel Combination (Addition Operator)#
# Parallel combination: the outputs of both bands are summed.
bandpass1 = CustomBandpass(low_cutoff=200, high_cutoff=500)
bandpass2 = CustomBandpass(low_cutoff=1000, high_cutoff=2000)

# `+` yields a ParallelFilterCombination that can be piped like any filter.
parallel = bandpass1 + bandpass2
processed = wave | parallel
The + operator is inherited from AbstractFilter and automatically creates a ParallelFilterCombination.
See also
Series and Parallel Filter Combinations - Detailed guide on combining filters
Device and Dtype Management#
Automatic Device Handling#
Filters automatically handle CPU/GPU transfer:
import torchfx as fx
import torch

wave = fx.Wave.from_file("audio.wav")

# Use the GPU when one is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    wave = wave.to("cuda")

# The filter runs on whichever device the audio is on.
filtered = wave | CustomBandpass(low_cutoff=200, high_cutoff=2000)

# Saving needs the data back on the CPU.
filtered.to("cpu").save("output.wav")
Helper Method for Coefficient Transfer#
Include a helper method for moving coefficients:
class CustomFilter(AbstractFilter):
    def move_coeff(self, device: torch.device, dtype: torch.dtype) -> None:
        """Move filter coefficients to ``device`` and cast them to ``dtype``.

        ``torch.as_tensor`` returns a tensor unchanged when it already has
        the requested device and dtype, so calling this on every forward
        pass is cheap.

        Raises
        ------
        ValueError
            If the coefficients have not been computed yet.
        """
        if self.a is None or self.b is None:
            raise ValueError("Coefficients must be computed before moving them")
        self.a = torch.as_tensor(self.a, device=device, dtype=dtype)
        self.b = torch.as_tensor(self.b, device=device, dtype=dtype)

    @torch.no_grad()
    def forward(self, x: Tensor) -> Tensor:
        if self.a is None or self.b is None:
            self.compute_coefficients()

        # Sync with the input on every call, not just the first one. This is
        # a no-op when device/dtype already match, and it keeps working when
        # the audio later moves to another device.
        self.move_coeff(x.device, x.dtype)

        return lfilter(x, self.a, self.b)
Testing and Validation#
Coefficient Stability Check#
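An IIR filter is stable only if all of its poles (the roots of the denominator polynomial `a`) lie strictly inside the unit circle. Check this after computing the coefficients: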
def compute_coefficients(self) -> None:
    # ... compute coefficients ...

    # Poles are the roots of the denominator polynomial `a`; any pole on or
    # outside the unit circle makes the recursion unstable.
    poles = np.roots(self.a)
    offending = poles[np.abs(poles) >= 1.0]
    if offending.size:
        raise ValueError(
            f"Filter is unstable! Poles outside unit circle: {offending}"
        )
Frequency Response Visualization#
Plot the filter’s frequency response:
def plot_response(self) -> None:
    """Plot magnitude (dB) and phase response.

    Works both before ``forward`` has run (coefficients are NumPy arrays or
    lists) and after it (coefficients may be tensors, possibly on a GPU).
    """
    from scipy.signal import freqz
    import matplotlib.pyplot as plt
    import numpy as np
    import torch

    if self.a is None or self.b is None:
        self.compute_coefficients()

    def _to_numpy(coeffs):
        # freqz needs host-side NumPy data; tensors may live on the GPU.
        if isinstance(coeffs, torch.Tensor):
            return coeffs.detach().cpu().numpy()
        return np.asarray(coeffs)

    # Compute frequency response (w in Hz because fs is passed)
    w, h = freqz(_to_numpy(self.b), _to_numpy(self.a), worN=2000, fs=self.fs)

    # Floor the magnitude so exact zeros of the response do not give log10(0) = -inf.
    magnitude_db = 20 * np.log10(np.maximum(np.abs(h), 1e-12))

    # Plot magnitude
    plt.figure(figsize=(10, 6))
    plt.subplot(2, 1, 1)
    plt.plot(w, magnitude_db)
    plt.title('Frequency Response')
    plt.ylabel('Magnitude [dB]')
    plt.xlabel('Frequency [Hz]')
    plt.grid(True)

    # Plot phase
    plt.subplot(2, 1, 2)
    plt.plot(w, np.angle(h))
    plt.ylabel('Phase [radians]')
    plt.xlabel('Frequency [Hz]')
    plt.grid(True)

    plt.tight_layout()
    plt.show()
# Usage (assumes plot_response has been added as a method of CustomBandpass)
bandpass = CustomBandpass(low_cutoff=200, high_cutoff=2000, fs=44100)
bandpass.plot_response()
Unit Test Example#
def test_custom_bandpass():
    """Check that CustomBandpass keeps an in-band tone and attenuates an out-of-band one."""
    import torch
    import torchfx as fx

    # Test signal: 440 Hz (inside 300-600 Hz) + 880 Hz (outside).
    # Sample times are k / fs. torch.linspace(0, duration, N) would include
    # the end point and space samples by duration / (N - 1), not 1 / fs.
    fs = 44100
    duration = 1.0
    n = int(fs * duration)
    # float64 keeps the high-order transfer-function coefficients well
    # conditioned at these low normalized frequencies.
    t = torch.arange(n, dtype=torch.float64) / fs
    # 0.5 per tone keeps the mix inside [-1, 1]; torchaudio's lfilter clamps
    # its output to that range by default.
    signal = (
        0.5 * torch.sin(2 * torch.pi * 440 * t) +
        0.5 * torch.sin(2 * torch.pi * 880 * t)
    )
    wave = fx.Wave(signal.unsqueeze(0), fs)

    # Create and apply filter (passes 440 Hz, attenuates 880 Hz)
    bandpass = CustomBandpass(low_cutoff=300, high_cutoff=600, fs=fs)
    filtered = wave | bandpass

    # Verify output shape
    assert filtered.ys.shape == wave.ys.shape

    # Verify filtering occurred
    assert not torch.allclose(filtered.ys, wave.ys)

    # Verify coefficients were computed and converted to tensors by forward()
    assert isinstance(bandpass.a, torch.Tensor)
    assert isinstance(bandpass.b, torch.Tensor)

    # Verify selectivity. With t = k / fs over exactly 1 s, FFT bin k is k Hz.
    spectrum = torch.fft.rfft(filtered.ys[0]).abs()
    assert spectrum[880] < 0.1 * spectrum[440], "880 Hz should be strongly attenuated"

    print("✓ All tests passed!")
Best Practices#
Parameter Validation#
Validate parameters in __init__:
def __init__(self, cutoff: float, q: float = 1.0, fs: int | None = None):
    super().__init__()

    # Validate with explicit exceptions, not `assert`. Assertions are removed
    # when Python runs with -O, which would silently skip validation. The
    # `not (...)` form also rejects NaN values.
    if not (cutoff > 0):
        raise ValueError("Cutoff frequency must be positive")
    if not (q > 0):
        raise ValueError("Q factor must be positive")
    if fs is not None:
        if not (fs > 0):
            raise ValueError("Sample rate must be positive")
        if not (cutoff < fs / 2):
            raise ValueError("Cutoff must be below Nyquist frequency")

    self.cutoff = cutoff
    self.q = q
    self.fs = fs
Use Properties for Computed Values#
@property
def _omega(self) -> float:
    """Cutoff expressed in radians per sample (requires ``fs``)."""
    return (2 * np.pi * self.cutoff) / self.fs


@property
def _q_factor(self) -> float:
    """Quality factor: cutoff (centre) frequency divided by bandwidth."""
    ratio = self.cutoff / self.bandwidth
    return ratio
Document Thoroughly#
Include comprehensive docstrings:
class CustomFilter(AbstractFilter):
    """One-line summary.

    Longer description explaining what the filter does,
    its characteristics, and when to use it.

    Parameters
    ----------
    param1 : type
        Description of param1
    param2 : type, optional
        Description of param2 (default: value)
    fs : int, optional
        Sample rate in Hz (default: None, auto-configured from Wave)

    Attributes
    ----------
    a : array_like or Tensor or None
        Denominator coefficients (None until computed)
    b : array_like or Tensor or None
        Numerator coefficients (None until computed)

    Raises
    ------
    ValueError
        Conditions under which the constructor or ``compute_coefficients``
        rejects its input.

    See Also
    --------
    RelatedFilter : Related filter implementation

    Notes
    -----
    Additional technical notes, references, or mathematical formulations.

    References
    ----------
    .. [1] Author, "Title," Journal, Year.

    Examples
    --------
    >>> import torchfx as fx
    >>> wave = fx.Wave.from_file("audio.wav")
    >>> filt = CustomFilter(param1=value, param2=value)
    >>> filtered = wave | filt
    """
Handle Edge Cases#
@torch.no_grad()
def forward(self, x: Tensor) -> Tensor:
    """Filter ``x``, handling empty and very short inputs gracefully."""
    import warnings

    # Handle empty input: nothing to filter, return it unchanged
    if x.numel() == 0:
        return x

    # Handle very short signals. The 3 * order threshold is a heuristic:
    # very short signals are dominated by the filter's start-up transient.
    min_length = self.order * 3
    if x.shape[-1] < min_length:
        warnings.warn(
            f"Signal length ({x.shape[-1]}) is shorter than recommended "
            f"({min_length} samples). Results may be unreliable."
        )

    # Normal processing: lazy design, match the input's device/dtype, filter.
    # Without this path the method would fall through and return None.
    if self.a is None or self.b is None:
        self.compute_coefficients()
    self.a = torch.as_tensor(self.a, device=x.device, dtype=x.dtype)
    self.b = torch.as_tensor(self.b, device=x.device, dtype=x.dtype)
    return lfilter(x, self.a, self.b)
Advanced Topics#
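Stateful Filters#
For real-time (block-by-block) processing, carry the filter state from one block to the next. This keeps the output free of discontinuities at block boundaries: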
class StatefulFilter(AbstractFilter):
    """Butterworth lowpass that carries its state across calls (block processing).

    ``torchaudio.functional.lfilter`` has no initial-state argument. This
    example therefore runs the recursion with ``scipy.signal.lfilter``, which
    accepts the initial state ``zi`` and returns the final state ``zf``. The
    final state is stored and fed into the next block, so processing a
    signal in consecutive chunks matches processing it in one call.

    Parameters
    ----------
    cutoff : float
        Cutoff frequency in Hz.
    fs : int, optional
        Sample rate in Hz.
    order : int, optional
        Butterworth order (default: 2).
    """

    def __init__(self, cutoff: float, fs: int | None = None, order: int = 2):
        super().__init__()
        self.cutoff = cutoff
        self.fs = fs
        self.order = order
        self.a = None
        self.b = None
        self.zi = None  # Filter state between calls; None means "start from rest"

    def compute_coefficients(self) -> None:
        """Design the Butterworth lowpass coefficients."""
        if self.fs is None:
            raise ValueError("Sample rate must be set before computing coefficients")
        from scipy.signal import butter

        self.b, self.a = butter(self.order, self.cutoff / (0.5 * self.fs), btype='low')

    def reset_state(self) -> None:
        """Reset the filter state to zero (as if preceded by silence)."""
        self.zi = None

    @torch.no_grad()
    def forward(self, x: Tensor) -> Tensor:
        """Filter one block of audio, continuing from the previous block's state."""
        import numpy as np
        from scipy.signal import lfilter as scipy_lfilter

        if self.a is None or self.b is None:
            self.compute_coefficients()

        # SciPy operates on host float64 arrays.
        x_np = x.detach().cpu().numpy().astype(np.float64)

        # One state vector of length max(len(a), len(b)) - 1 for every
        # signal along the last axis (e.g. one per channel).
        n_state = max(len(self.a), len(self.b)) - 1
        state_shape = x_np.shape[:-1] + (n_state,)
        if self.zi is None or self.zi.shape != state_shape:
            # First block, after reset_state(), or the channel layout changed.
            self.zi = np.zeros(state_shape)

        # scipy returns the final state, which becomes the next initial state.
        y, self.zi = scipy_lfilter(self.b, self.a, x_np, axis=-1, zi=self.zi)
        return torch.as_tensor(y, dtype=x.dtype, device=x.device)
Differentiable Filters#
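To learn filter parameters with gradient descent, compute the coefficients with differentiable PyTorch operations (not SciPy, and without `.item()`). Also, do not decorate `forward` with `@torch.no_grad()`: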
class LearnableFilter(AbstractFilter):
    """2nd-order lowpass whose cutoff frequency is a trainable parameter.

    The biquad coefficients are built from ``log_cutoff`` with torch
    operations only, so gradients flow from the filter output back to the
    cutoff. Designing with SciPy, or going through ``.item()``, would cut
    the autograd graph and leave ``log_cutoff`` without a gradient.

    Parameters
    ----------
    initial_cutoff : float, optional
        Starting cutoff frequency in Hz (default: 1000).
    fs : int, optional
        Sample rate in Hz.
    q : float, optional
        Fixed quality factor (default: 1/sqrt(2), a Butterworth response).
    """

    def __init__(
        self,
        initial_cutoff: float = 1000,
        fs: int | None = None,
        q: float = 0.7071067811865476,
    ):
        super().__init__()
        self.fs = fs
        self.q = q
        # Learnable parameter (log-scale keeps the cutoff positive and makes
        # gradient steps scale with frequency)
        self.log_cutoff = torch.nn.Parameter(
            torch.log(torch.tensor([float(initial_cutoff)]))
        )
        self.a = None
        self.b = None

    @property
    def cutoff(self) -> float:
        """Current cutoff frequency in Hz (detached; for display/logging only)."""
        return torch.exp(self.log_cutoff).item()

    def compute_coefficients(self) -> None:
        """Compute differentiable lowpass biquad coefficients from ``log_cutoff``."""
        if self.fs is None:
            raise ValueError("Sample rate must be set before computing coefficients")

        cutoff = torch.exp(self.log_cutoff)  # shape (1,), stays in the graph
        omega = 2 * torch.pi * cutoff / self.fs
        cos_omega = torch.cos(omega)
        alpha = torch.sin(omega) / (2 * self.q)

        b = torch.cat([(1 - cos_omega) / 2, 1 - cos_omega, (1 - cos_omega) / 2])
        a = torch.cat([1 + alpha, -2 * cos_omega, 1 - alpha])

        # Normalize so a[0] == 1; the division is differentiable.
        self.b = b / a[0]
        self.a = a / a[0]

    def forward(self, x: Tensor) -> Tensor:
        # DON'T use @torch.no_grad() - gradients must reach log_cutoff.
        # Recompute every pass because the optimizer may have changed the cutoff.
        self.compute_coefficients()
        a = self.a.to(device=x.device, dtype=x.dtype)
        b = self.b.to(device=x.device, dtype=x.dtype)
        # clamp=False: clamping the output to [-1, 1] would zero the gradient
        # for clipped samples.
        return lfilter(x, a, b, clamp=False)
External Resources#
SciPy Signal Processing - SciPy filter design functions
Digital Filter Design on Wikipedia - Filter theory
IIR Filter Design - Julius O. Smith’s filter design book
Audio EQ Cookbook - Biquad filter formulas