Creating Custom Filters#

Learn how to build your own digital filters by extending TorchFX’s AbstractFilter base class. This tutorial covers filter design patterns, coefficient computation, and integration with the TorchFX pipeline system.

Overview#

Custom filters in TorchFX inherit from AbstractFilter, which provides the foundation for:

  • Pipeline integration: Automatic support for the | operator

  • Parallel combination: Built-in + operator for filter banks

  • PyTorch compatibility: Inherits from torch.nn.Module

  • Device management: Automatic GPU/CPU handling

  • Sample rate configuration: Automatic fs propagation from Wave

        classDiagram
    class Module {
        <<PyTorch>>
        +forward(x)*
        +to(device)
        +parameters()
    }

    class FX {
        <<abstract>>
        +forward(x) Tensor*
    }

    class AbstractFilter {
        <<abstract>>
        +compute_coefficients()*
        +__add__(other) ParallelFilterCombination
        +_has_computed_coeff bool
    }

    class IIR {
        <<abstract>>
        +fs int|None
        +a Tensor|None
        +b Tensor|None
        +move_coeff(device, dtype)
        +forward(x) Tensor
    }

    class CustomFilter {
        +__init__(...)
        +compute_coefficients()
        +forward(x) Tensor
    }

    Module <|-- FX
    FX <|-- AbstractFilter
    AbstractFilter <|-- IIR
    AbstractFilter <|-- CustomFilter

    note for Module "PyTorch base class"
    note for AbstractFilter "Adds filter-specific features"
    note for CustomFilter "Your implementation"
    

Required Methods#

Every custom filter must implement these three methods:

1. __init__ - Initialization#

Initialize filter parameters and call parent constructor:

from torchfx.filter import AbstractFilter

class CustomFilter(AbstractFilter):
    def __init__(self, param1: float, param2: int, fs: int | None = None):
        super().__init__()  # REQUIRED: Initialize parent class

        # Store filter parameters
        self.param1 = param1
        self.param2 = param2
        self.fs = fs

        # Initialize coefficient storage
        self.a = None
        self.b = None

Key points:

  • Always call super().__init__() first

  • Accept fs parameter (can be None)

  • Initialize coefficient attributes to None

  • Validate parameters if needed

2. compute_coefficients - Filter Design#

Compute filter coefficients based on parameters:

def compute_coefficients(self) -> None:
    """Compute filter coefficients."""
    # Verify fs is set
    if self.fs is None:
        raise ValueError("Sample rate must be set before computing coefficients")

    # Design filter (example using scipy)
    from scipy.signal import butter

    # Normalize frequency to Nyquist
    nyquist = 0.5 * self.fs
    normalized_freq = self.cutoff / nyquist

    # Compute coefficients
    self.b, self.a = butter(self.order, normalized_freq, btype='low')

Key points:

  • Check that fs is not None

  • Use SciPy or custom formulas

  • Set self.a and self.b attributes

  • Only computed once (cached automatically)

3. forward - Apply Filter#

Process audio through the filter:

import torch
from torch import Tensor
from torchaudio.functional import lfilter

@torch.no_grad()
def forward(self, x: Tensor) -> Tensor:
    """Apply filter to audio tensor."""
    # Compute coefficients if not already done
    if self.a is None or self.b is None:
        self.compute_coefficients()

    # Convert to tensors if needed (move to correct device)
    if not isinstance(self.a, Tensor):
        self.a = torch.as_tensor(self.a, device=x.device, dtype=x.dtype)
        self.b = torch.as_tensor(self.b, device=x.device, dtype=x.dtype)

    # Apply IIR filter
    return lfilter(x, self.a, self.b)

Key points:

  • Use @torch.no_grad() for efficiency

  • Lazy coefficient computation

  • Handle device/dtype conversion

  • Use torchaudio.functional.lfilter() for IIR filters

Complete Example: Custom Bandpass Filter#

Here’s a complete, working example of a custom bandpass filter:

import numpy as np
import torch
from torch import Tensor
from scipy.signal import butter
from torchaudio.functional import lfilter

from torchfx.filter import AbstractFilter

class CustomBandpass(AbstractFilter):
    """Custom bandpass filter using Butterworth design.

    Parameters
    ----------
    low_cutoff : float
        Lower cutoff frequency in Hz
    high_cutoff : float
        Upper cutoff frequency in Hz
    order : int, optional
        Filter order (default: 4)
    fs : int, optional
        Sample rate in Hz (can be set later)

    Examples
    --------
    >>> import torchfx as fx
    >>> wave = fx.Wave.from_file("audio.wav")
    >>> bandpass = CustomBandpass(low_cutoff=200, high_cutoff=2000, order=4)
    >>> filtered = wave | bandpass
    """

    def __init__(
        self,
        low_cutoff: float,
        high_cutoff: float,
        order: int = 4,
        fs: int | None = None,
    ):
        super().__init__()

        # Validate parameters
        assert low_cutoff > 0, "Low cutoff must be positive"
        assert high_cutoff > low_cutoff, "High cutoff must be > low cutoff"
        assert order > 0, "Order must be positive"

        # Store parameters
        self.low_cutoff = low_cutoff
        self.high_cutoff = high_cutoff
        self.order = order
        self.fs = fs

        # Initialize coefficients
        self.a = None
        self.b = None

    def compute_coefficients(self) -> None:
        """Compute Butterworth bandpass coefficients."""
        if self.fs is None:
            raise ValueError("Sample rate must be set before computing coefficients")

        # Normalize frequencies to Nyquist frequency
        nyquist = 0.5 * self.fs
        low_norm = self.low_cutoff / nyquist
        high_norm = self.high_cutoff / nyquist

        # Validate normalized frequencies
        if not (0 < low_norm < 1 and 0 < high_norm < 1):
            raise ValueError(
                f"Cutoff frequencies must be between 0 and Nyquist ({nyquist} Hz)"
            )

        # Design Butterworth bandpass filter
        self.b, self.a = butter(
            self.order,
            [low_norm, high_norm],
            btype='bandpass'
        )

    @torch.no_grad()
    def forward(self, x: Tensor) -> Tensor:
        """Apply bandpass filter to input tensor.

        Parameters
        ----------
        x : Tensor
            Input audio tensor of shape (channels, samples)

        Returns
        -------
        Tensor
            Filtered audio tensor
        """
        # Lazy coefficient computation
        if self.a is None or self.b is None:
            self.compute_coefficients()

        # Convert coefficients to tensors (match input device/dtype)
        if not isinstance(self.a, Tensor):
            self.a = torch.as_tensor(self.a, device=x.device, dtype=x.dtype)
            self.b = torch.as_tensor(self.b, device=x.device, dtype=x.dtype)

        # Apply filter
        return lfilter(x, self.a, self.b)

Using the Custom Filter#

import torchfx as fx

# Load audio
wave = fx.Wave.from_file("audio.wav")

# Create and apply filter (fs auto-configured from wave)
bandpass = CustomBandpass(low_cutoff=200, high_cutoff=2000, order=4)
filtered = wave | bandpass

# Save result
filtered.save("bandpass_filtered.wav")

# Can also chain with other filters
from torchfx.filter import iir

processed = (
    wave
    | iir.HiButterworth(cutoff=80, order=2)  # Remove rumble
    | bandpass                                # Bandpass 200-2000 Hz
    | fx.effect.Normalize()                   # Normalize
)

Filter Design Patterns#

Pattern 1: SciPy-Based Filters#

Use SciPy’s signal processing functions for coefficient design:

from scipy.signal import butter, cheby1, cheby2, ellip, iirnotch, iirpeak

class CustomLowpass(AbstractFilter):
    def compute_coefficients(self) -> None:
        nyquist = 0.5 * self.fs
        norm_freq = self.cutoff / nyquist

        # Choose a design function
        self.b, self.a = butter(self.order, norm_freq, btype='low')
        # self.b, self.a = cheby1(self.order, self.ripple, norm_freq, btype='low')
        # self.b, self.a = ellip(self.order, self.ripple, self.atten, norm_freq, btype='low')

Available SciPy filters:

  • scipy.signal.butter() - Butterworth (maximally flat passband)

  • scipy.signal.cheby1() - Chebyshev Type I (ripple in passband)

  • scipy.signal.cheby2() - Chebyshev Type II (ripple in stopband)

  • scipy.signal.ellip() - Elliptic (ripple in both bands)

  • scipy.signal.iirpeak() - Peaking filter

  • scipy.signal.iirnotch() - Notch filter

See also

SciPy Signal Processing - Full SciPy signal module documentation

Pattern 2: Biquad Formulas#

Use direct biquad formulas for second-order sections:

class CustomPeakingEQ(AbstractFilter):
    """Peaking EQ filter using biquad formulas."""

    def __init__(self, freq: float, gain_db: float, q: float = 1.0, fs: int | None = None):
        super().__init__()
        self.freq = freq
        self.gain_db = gain_db
        self.q = q
        self.fs = fs
        self.a = None
        self.b = None

    @property
    def _omega(self) -> float:
        """Angular frequency."""
        return 2 * np.pi * self.freq / self.fs

    @property
    def _alpha(self) -> float:
        """Biquad alpha parameter."""
        return np.sin(self._omega) / (2 * self.q)

    def compute_coefficients(self) -> None:
        if self.fs is None:
            raise ValueError("Sample rate must be set")

        A = 10 ** (self.gain_db / 40)  # Linear gain
        omega = self._omega
        alpha = self._alpha
        cos_omega = np.cos(omega)

        # Biquad coefficients for peaking EQ
        b0 = 1 + alpha * A
        b1 = -2 * cos_omega
        b2 = 1 - alpha * A

        a0 = 1 + alpha / A
        a1 = -2 * cos_omega
        a2 = 1 - alpha / A

        # Normalize by a0
        self.b = [b0 / a0, b1 / a0, b2 / a0]
        self.a = [1.0, a1 / a0, a2 / a0]

Biquad transfer function:

\[ H(z) = \frac{b_0 + b_1 z^{-1} + b_2 z^{-2}}{a_0 + a_1 z^{-1} + a_2 z^{-2}} \]

Normalized form (dividing by \(a_0\)):

\[ H(z) = \frac{b_0' + b_1' z^{-1} + b_2' z^{-2}}{1 + a_1' z^{-1} + a_2' z^{-2}} \]

See also

Digital Biquad Filter on Wikipedia - Biquad filter theory

Pattern 3: Cascaded Filters#

Chain multiple filter stages by convolving coefficients:

class LinkwitzRiley(AbstractFilter):
    """Linkwitz-Riley crossover filter (cascaded Butterworth)."""

    def __init__(self, cutoff: float, order: int = 4, btype: str = 'low', fs: int | None = None):
        super().__init__()
        assert order % 2 == 0, "Linkwitz-Riley order must be even"

        self.cutoff = cutoff
        self.order = order
        self.btype = btype
        self.fs = fs
        self.a = None
        self.b = None

    def compute_coefficients(self) -> None:
        if self.fs is None:
            raise ValueError("Sample rate must be set")

        # Linkwitz-Riley is two cascaded Butterworth filters
        butter_order = self.order // 2

        # Get base Butterworth coefficients
        b_butter, a_butter = butter(
            butter_order,
            self.cutoff / (0.5 * self.fs),
            btype=self.btype
        )

        # Cascade by convolving coefficients
        self.b = np.convolve(b_butter, b_butter)
        self.a = np.convolve(a_butter, a_butter)

Cascading filters: Convolving filter coefficients is equivalent to cascading filters in series.

Pattern 4: Custom Coefficient Computation#

Implement your own filter design algorithms:

class CustomResonator(AbstractFilter):
    """Resonant filter with custom coefficient computation."""

    def __init__(self, freq: float, resonance: float = 0.5, fs: int | None = None):
        super().__init__()
        self.freq = freq
        self.resonance = np.clip(resonance, 0.0, 0.99)  # Stability constraint
        self.fs = fs
        self.a = None
        self.b = None

    def compute_coefficients(self) -> None:
        if self.fs is None:
            raise ValueError("Sample rate must be set")

        # Normalized frequency
        omega = 2.0 * np.pi * self.freq / self.fs

        # Quality factor from resonance parameter
        Q = 1.0 / (1.0 - self.resonance)

        # Compute coefficients using resonator formulas
        alpha = np.sin(omega) / (2.0 * Q)
        cos_omega = np.cos(omega)

        # Resonant lowpass coefficients
        b0 = (1.0 - cos_omega) / 2.0
        b1 = 1.0 - cos_omega
        b2 = (1.0 - cos_omega) / 2.0

        a0 = 1.0 + alpha
        a1 = -2.0 * cos_omega
        a2 = 1.0 - alpha

        # Normalize by a0
        self.b = [b0 / a0, b1 / a0, b2 / a0]
        self.a = [1.0, a1 / a0, a2 / a0]

Pipeline Integration#

Custom filters automatically work with TorchFX’s pipeline system:

Series Combination (Pipe Operator)#

import torchfx as fx

wave = fx.Wave.from_file("audio.wav")

# Chain custom filters in series
custom_bp = CustomBandpass(200, 2000, order=4)
custom_res = CustomResonator(1000, resonance=0.7)

processed = wave | custom_bp | custom_res
        sequenceDiagram
    participant Wave
    participant CustomBandpass
    participant CustomResonator

    Wave->>CustomBandpass: wave | custom_bp
    Note over CustomBandpass: fs auto-configured
    CustomBandpass->>CustomBandpass: compute_coefficients()
    CustomBandpass->>CustomBandpass: forward(wave.ys)
    CustomBandpass->>Wave: Return new Wave

    Wave->>CustomResonator: result | custom_res
    Note over CustomResonator: fs auto-configured
    CustomResonator->>CustomResonator: compute_coefficients()
    CustomResonator->>CustomResonator: forward(result.ys)
    CustomResonator->>Wave: Return final Wave
    

Parallel Combination (Addition Operator)#

# Parallel combination (sum outputs)
bandpass1 = CustomBandpass(200, 500)
bandpass2 = CustomBandpass(1000, 2000)

parallel = bandpass1 + bandpass2  # Creates ParallelFilterCombination
processed = wave | parallel

The + operator is inherited from AbstractFilter and automatically creates a ParallelFilterCombination.

See also

Series and Parallel Filter Combinations - Detailed guide on combining filters

Device and Dtype Management#

Automatic Device Handling#

Filters automatically handle CPU/GPU transfer:

import torchfx as fx
import torch

wave = fx.Wave.from_file("audio.wav")

# Move to GPU
if torch.cuda.is_available():
    wave = wave.to("cuda")

    # Filter runs on GPU automatically
    filtered = wave | CustomBandpass(200, 2000)

    # Move back to CPU for saving
    filtered.to("cpu").save("output.wav")

Helper Method for Coefficient Transfer#

Include a helper method for moving coefficients:

class CustomFilter(AbstractFilter):
    def move_coeff(self, device: torch.device, dtype: torch.dtype) -> None:
        """Move filter coefficients to specified device and dtype."""
        self.a = torch.as_tensor(self.a, device=device, dtype=dtype)
        self.b = torch.as_tensor(self.b, device=device, dtype=dtype)

    @torch.no_grad()
    def forward(self, x: Tensor) -> Tensor:
        if self.a is None or self.b is None:
            self.compute_coefficients()

        # Use helper method for device transfer
        if not isinstance(self.a, Tensor):
            self.move_coeff(x.device, x.dtype)

        return lfilter(x, self.a, self.b)

Testing and Validation#

Coefficient Stability Check#

For IIR filters, ensure poles are inside the unit circle:

def compute_coefficients(self) -> None:
    # ... compute coefficients ...

    # Validate stability
    roots = np.roots(self.a)
    if np.any(np.abs(roots) >= 1.0):
        raise ValueError(
            f"Filter is unstable! Poles outside unit circle: {roots[np.abs(roots) >= 1.0]}"
        )

Frequency Response Visualization#

Plot the filter’s frequency response:

def plot_response(self) -> None:
    """Plot magnitude and phase response."""
    from scipy.signal import freqz
    import matplotlib.pyplot as plt

    if self.a is None or self.b is None:
        self.compute_coefficients()

    # Compute frequency response
    w, h = freqz(self.b, self.a, worN=2000, fs=self.fs)

    # Plot magnitude
    plt.figure(figsize=(10, 6))
    plt.subplot(2, 1, 1)
    plt.plot(w, 20 * np.log10(np.abs(h)))
    plt.title('Frequency Response')
    plt.ylabel('Magnitude [dB]')
    plt.xlabel('Frequency [Hz]')
    plt.grid(True)

    # Plot phase
    plt.subplot(2, 1, 2)
    plt.plot(w, np.angle(h))
    plt.ylabel('Phase [radians]')
    plt.xlabel('Frequency [Hz]')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

# Usage
bandpass = CustomBandpass(200, 2000, fs=44100)
bandpass.plot_response()

Unit Test Example#

def test_custom_bandpass():
    """Test custom bandpass filter."""
    import torch
    import torchfx as fx

    # Create test signal (440 Hz + 880 Hz tones)
    fs = 44100
    duration = 1.0
    t = torch.linspace(0, duration, int(fs * duration))
    signal = (
        torch.sin(2 * torch.pi * 440 * t) +
        torch.sin(2 * torch.pi * 880 * t)
    )

    wave = fx.Wave(signal.unsqueeze(0), fs)

    # Create and apply filter (passes 440 Hz, attenuates 880 Hz)
    bandpass = CustomBandpass(low_cutoff=300, high_cutoff=600, fs=fs)
    filtered = wave | bandpass

    # Verify output shape
    assert filtered.ys.shape == wave.ys.shape

    # Verify filtering occurred
    assert not torch.allclose(filtered.ys, wave.ys)

    # Verify coefficients computed
    assert bandpass.a is not None
    assert bandpass.b is not None
    assert isinstance(bandpass.a, Tensor)
    assert isinstance(bandpass.b, Tensor)

    print("✓ All tests passed!")

Best Practices#

Parameter Validation#

Validate parameters in __init__:

def __init__(self, cutoff: float, q: float = 1.0, fs: int | None = None):
    super().__init__()

    # Validate parameters
    assert cutoff > 0, "Cutoff frequency must be positive"
    assert q > 0, "Q factor must be positive"
    if fs is not None:
        assert fs > 0, "Sample rate must be positive"
        assert cutoff < fs / 2, "Cutoff must be below Nyquist frequency"

    self.cutoff = cutoff
    self.q = q
    self.fs = fs

Use Properties for Computed Values#

@property
def _omega(self) -> float:
    """Normalized angular frequency."""
    return 2 * np.pi * self.cutoff / self.fs

@property
def _q_factor(self) -> float:
    """Quality factor from bandwidth."""
    return self.cutoff / self.bandwidth

Document Thoroughly#

Include comprehensive docstrings:

class CustomFilter(AbstractFilter):
    """One-line summary.

    Longer description explaining what the filter does,
    its characteristics, and when to use it.

    Parameters
    ----------
    param1 : type
        Description of param1
    param2 : type, optional
        Description of param2 (default: value)
    fs : int, optional
        Sample rate in Hz (default: None, auto-configured from Wave)

    Attributes
    ----------
    a : Tensor | None
        Denominator coefficients
    b : Tensor | None
        Numerator coefficients

    Examples
    --------
    >>> import torchfx as fx
    >>> wave = fx.Wave.from_file("audio.wav")
    >>> filt = CustomFilter(param1=value, param2=value)
    >>> filtered = wave | filt

    Notes
    -----
    Additional technical notes, references, or mathematical formulations.

    References
    ----------
    .. [1] Author, "Title," Journal, Year.

    See Also
    --------
    RelatedFilter : Related filter implementation
    """

Handle Edge Cases#

@torch.no_grad()
def forward(self, x: Tensor) -> Tensor:
    # Handle empty input
    if x.numel() == 0:
        return x

    # Handle very short signals
    min_length = self.order * 3
    if x.shape[-1] < min_length:
        import warnings
        warnings.warn(
            f"Signal length ({x.shape[-1]}) is shorter than recommended "
            f"({min_length} samples). Results may be unreliable."
        )

    # Normal processing
    # ...

Advanced Topics#

State-ful Filters#

For real-time processing, maintain filter state:

class StatefulFilter(AbstractFilter):
    """Filter that maintains state for real-time processing."""

    def __init__(self, cutoff: float, fs: int | None = None):
        super().__init__()
        self.cutoff = cutoff
        self.fs = fs
        self.a = None
        self.b = None
        self.zi = None  # Initial conditions

    def reset_state(self) -> None:
        """Reset filter state to zero."""
        if self.a is not None and self.b is not None:
            from scipy.signal import lfilter_zi
            self.zi = lfilter_zi(self.b, self.a)

    @torch.no_grad()
    def forward(self, x: Tensor) -> Tensor:
        # ... coefficient computation ...

        # Use lfilter with initial conditions
        from torchaudio.functional import lfilter
        result = lfilter(x, self.a, self.b, zi=self.zi)

        # Update state for next call
        # (zi updated by lfilter in-place)

        return result

Differentiable Filters#

Enable gradients for learned filter parameters:

class LearnableFilter(AbstractFilter):
    """Filter with learnable cutoff frequency."""

    def __init__(self, initial_cutoff: float = 1000, fs: int | None = None):
        super().__init__()
        self.fs = fs

        # Learnable parameter (log-scale for numerical stability)
        self.log_cutoff = torch.nn.Parameter(
            torch.log(torch.tensor([initial_cutoff]))
        )

    @property
    def cutoff(self) -> float:
        """Current cutoff frequency."""
        return torch.exp(self.log_cutoff).item()

    def forward(self, x: Tensor) -> Tensor:
        # Recompute coefficients each forward pass (cutoff may have changed)
        self.compute_coefficients()

        # DON'T use @torch.no_grad() - we need gradients!
        return lfilter(x, self.a, self.b)

External Resources#

References#