(pytorch-integration)= # PyTorch Integration Learn how TorchFX seamlessly integrates with PyTorch's neural network ecosystem. This guide covers using TorchFX modules with {class}`torch.nn.Sequential`, creating custom modules, gradient computation, and mixing with torchaudio transforms. ## Prerequisites Before starting this tutorial, you should be familiar with: - {doc}`../core-concepts/wave` - Wave class fundamentals - {doc}`../core-concepts/pipeline-operator` - Pipeline operator basics - {doc}`../core-concepts/fx` - FX base class architecture - [PyTorch nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) - Module fundamentals - Basic PyTorch concepts (tensors, device management, forward passes) ## Overview All TorchFX audio effects and filters are implemented as subclasses of {class}`torch.nn.Module`, making them fully compatible with PyTorch's neural network ecosystem. This design enables: - **Seamless composition** with other PyTorch modules - **Integration** into neural network training pipelines - **Automatic gradient computation** for differentiable operations - **Device management** through PyTorch's standard API - **Compatibility** with PyTorch utilities like {class}`torch.nn.Sequential` and {class}`torch.nn.ModuleList` ```{tip} TorchFX modules behave exactly like standard PyTorch modules, so you can use them anywhere you'd use a PyTorch layer or transform. ``` ## Core Integration Architecture TorchFX's integration with PyTorch is built on inheritance from {class}`torch.nn.Module`. Understanding this architecture helps you leverage the full power of both libraries. ```{mermaid} graph TB subgraph "PyTorch Framework" Module["torch.nn.Module
Base class for all neural network modules"] end subgraph "torchfx Core" FX["FX (Abstract Base)
Inherits from torch.nn.Module
src/torchfx/effect.py"] subgraph "Effects" Gain["Gain"] Normalize["Normalize"] Reverb["Reverb"] Delay["Delay"] end subgraph "Filters" AbstractFilter["AbstractFilter"] IIR["IIR Filters
(Butterworth, Chebyshev)"] FIR["FIR Filters
(DesignableFIR)"] end Wave["Wave
Audio data container
Implements pipe operator"] end subgraph "PyTorch Containers" Sequential["nn.Sequential
Sequential composition"] ModuleList["nn.ModuleList
Module container"] Custom["Custom nn.Module
User-defined classes"] end Module -->|"inherits"| FX FX -->|"base for"| Gain FX -->|"base for"| Normalize FX -->|"base for"| Reverb FX -->|"base for"| Delay FX -->|"base for"| AbstractFilter AbstractFilter -->|"base for"| IIR AbstractFilter -->|"base for"| FIR Wave -->|"pipe operator accepts"| Module Wave -->|"works with"| Sequential Wave -->|"works with"| ModuleList Wave -->|"works with"| Custom Custom -->|"can contain"| FX Sequential -->|"can contain"| FX ModuleList -->|"can contain"| FX style Module fill:#e1f5ff style FX fill:#fff5e1 style Wave fill:#e1ffe1 ``` **TorchFX Module Hierarchy** - All TorchFX effects and filters inherit from {class}`torch.nn.Module`, enabling seamless integration with the PyTorch ecosystem. ## Module Inheritance Benefits Because TorchFX modules inherit from {class}`torch.nn.Module`, they automatically gain all PyTorch module capabilities: | Feature | Benefit | Example | |---------|---------|---------| | **Parameter Registration** | Filter coefficients tracked by PyTorch | `list(module.parameters())` | | **Device Management** | `.to(device)` moves all tensors | `filter.to("cuda")` | | **State Dict** | Serialization and deserialization | `torch.save(module.state_dict(), "model.pt")` | | **Training Mode** | `.train()` and `.eval()` support | `module.eval()` | | **Nested Modules** | Automatic recursive operations | `module.to("cuda")` moves all children | | **Hooks** | Register forward/backward hooks | `module.register_forward_hook(hook_fn)` | ```{seealso} {doc}`gpu-acceleration` - Using device management for GPU acceleration ``` ## Wave Pipe Operator with nn.Module The {class}`~torchfx.Wave` class implements the {term}`pipeline operator` (`|`) to accept **any** {class}`torch.nn.Module`, not just TorchFX-specific effects and filters. This design choice enables integration with the entire PyTorch ecosystem. ### Implementation Details The pipe operator performs the following steps: ```{mermaid} sequenceDiagram participant User participant Wave participant Module as "nn.Module" participant FX as "FX Instance?" User->>Wave: wave | module Wave->>Wave: __or__(module) Wave->>Module: Check isinstance(module, nn.Module) alt Not an nn.Module Wave-->>User: TypeError end Wave->>FX: Check isinstance(module, FX) alt Is FX instance Wave->>Module: Set module.fs = wave.fs Wave->>Module: Call compute_coefficients() alt Is Sequential/ModuleList Wave->>Module: Recursively configure FX children end end Wave->>Module: Call module.forward(wave.ys) Module-->>Wave: Return processed tensor Wave->>Wave: Create new Wave(result, fs) Wave-->>User: Return new Wave ``` **Pipeline Operator Flow** - Shows how the pipe operator processes any {class}`torch.nn.Module`. ### Type Validation and Configuration | Step | Action | Code Reference | |------|--------|----------------| | 1. Type validation | Checks if right operand is {class}`torch.nn.Module` | `src/torchfx/wave.py:163-164` | | 2. FX configuration | Updates `fs` and computes coefficients for FX instances | `src/torchfx/wave.py:166-172` | | 3. Sequential handling | Recursively configures FX instances in Sequential/ModuleList | `src/torchfx/wave.py:169-172` | | 4. 
Forward pass | Applies module's `forward` method to audio tensor | `src/torchfx/wave.py:174` | ### Usage Patterns The pipe operator works with any module that implements a `forward` method accepting and returning tensors: ```python import torch import torch.nn as nn import torchfx as fx wave = fx.Wave.from_file("audio.wav") # Works with TorchFX filters filtered = wave | fx.filter.LoButterworth(1000) # Works with torch.nn.Sequential chained = wave | nn.Sequential( fx.filter.HiButterworth(100, order=2), fx.filter.LoButterworth(5000, order=4) ) # Works with any custom nn.Module class CustomGain(nn.Module): def __init__(self, gain: float): super().__init__() self.gain = gain def forward(self, x: torch.Tensor) -> torch.Tensor: return x * self.gain custom = wave | CustomGain(gain=0.5) ``` ```{note} The module's `forward` method should accept a tensor of shape `(channels, time)` and return a tensor of the same or compatible shape. ``` ## Using nn.Sequential TorchFX filters and effects can be composed using {class}`torch.nn.Sequential`, providing an alternative to the pipeline operator for creating processing chains. ### Sequential Composition Pattern {class}`torch.nn.Sequential` creates a container that applies modules in order: ```{mermaid} graph LR Input["Input Tensor
(C, T)"] Sequential["nn.Sequential"] subgraph "Sequential Container" F1["HiChebyshev1(20)"] F2["HiChebyshev1(60)"] F3["HiChebyshev1(65)"] F4["LoButterworth(5000)"] F5["LoButterworth(4900)"] F6["LoButterworth(4850)"] end Output["Output Tensor
(C, T)"] Input --> Sequential Sequential --> F1 F1 --> F2 F2 --> F3 F3 --> F4 F4 --> F5 F5 --> F6 F6 --> Output style Input fill:#e1f5ff style Output fill:#e1f5ff style F1 fill:#fff5e1 style F2 fill:#fff5e1 style F3 fill:#fff5e1 style F4 fill:#fff5e1 style F5 fill:#fff5e1 style F6 fill:#fff5e1 ``` **Sequential Processing Flow** - Audio flows through each module in the container sequentially. ### Three Equivalent Approaches TorchFX provides three equivalent ways to chain filters, each with different trade-offs: ```python import torch.nn as nn import torchfx as fx from torchfx.filter import HiChebyshev1, LoButterworth wave = fx.Wave.from_file("audio.wav") # Approach 1: Custom Module class FilterChain(nn.Module): def __init__(self, fs): super().__init__() self.f1 = HiChebyshev1(20, fs=fs) self.f2 = HiChebyshev1(60, fs=fs) self.f3 = HiChebyshev1(65, fs=fs) self.f4 = LoButterworth(5000, fs=fs) self.f5 = LoButterworth(4900, fs=fs) self.f6 = LoButterworth(4850, fs=fs) def forward(self, x): x = self.f1(x) x = self.f2(x) x = self.f3(x) x = self.f4(x) x = self.f5(x) x = self.f6(x) return x custom_chain = FilterChain(wave.fs) result1 = wave | custom_chain # Approach 2: nn.Sequential seq_chain = nn.Sequential( HiChebyshev1(20, fs=wave.fs), HiChebyshev1(60, fs=wave.fs), HiChebyshev1(65, fs=wave.fs), LoButterworth(5000, fs=wave.fs), LoButterworth(4900, fs=wave.fs), LoButterworth(4850, fs=wave.fs), ) result2 = wave | seq_chain # Approach 3: Pipe Operator result3 = ( wave | HiChebyshev1(20) | HiChebyshev1(60) | HiChebyshev1(65) | LoButterworth(5000) | LoButterworth(4900) | LoButterworth(4850) ) ``` ### Comparison of Approaches | Approach | Pros | Cons | Best For | |----------|------|------|----------| | **Custom Module** | Named attributes, can add logic, best for reuse | More verbose | Reusable components, complex logic | | **nn.Sequential** | Standard PyTorch pattern, works with PyTorch tools | Must specify `fs` manually | PyTorch integration, model composition | | **Pipe Operator** | Most concise, auto `fs` configuration | Less familiar to PyTorch users | Quick prototyping, scripts | ```{tip} Use the pipe operator for exploration and scripts. Use {class}`torch.nn.Sequential` or custom modules when building larger systems or when you need to integrate with PyTorch training pipelines. ``` ## Creating Custom Neural Network Modules TorchFX filters and effects can be embedded in custom {class}`torch.nn.Module` classes to create reusable processing blocks with custom logic. 
### Pattern 1: Audio Preprocessing Module Create a preprocessing module for machine learning pipelines: ```python import torch import torch.nn as nn import torchfx as fx class AudioPreprocessor(nn.Module): """Preprocessing module for audio classification.""" def __init__(self, sample_rate: int = 44100): super().__init__() # Filtering layers self.rumble_filter = fx.filter.HiButterworth(cutoff=80, order=2, fs=sample_rate) self.noise_filter = fx.filter.LoButterworth(cutoff=12000, order=4, fs=sample_rate) # Normalization self.normalizer = fx.effect.Normalize(peak=0.9) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Parameters ---------- x : Tensor Audio tensor of shape (batch, channels, time) Returns ------- Tensor Preprocessed audio of shape (batch, channels, time) """ # Process each sample in batch batch_size = x.size(0) results = [] for i in range(batch_size): # Extract single sample sample = x[i] # (channels, time) # Apply filters sample = self.rumble_filter(sample) sample = self.noise_filter(sample) sample = self.normalizer(sample) results.append(sample) return torch.stack(results) # Usage preprocessor = AudioPreprocessor(sample_rate=44100) preprocessor.to("cuda") # Move to GPU # In training loop audio_batch = torch.randn(32, 2, 44100).to("cuda") # (batch, channels, time) preprocessed = preprocessor(audio_batch) ``` ### Pattern 2: Multi-Stage Effects Chain Create a reusable effects chain module: ```python import torch.nn as nn import torchfx as fx class VocalProcessor(nn.Module): """Professional vocal processing chain.""" def __init__(self, sample_rate: int = 44100): super().__init__() # Stage 1: Cleanup self.rumble_removal = fx.filter.HiButterworth(cutoff=80, order=2, fs=sample_rate) self.air_filter = fx.filter.LoButterworth(cutoff=15000, order=4, fs=sample_rate) # Stage 2: Tonal shaping (using Sequential) self.eq_chain = nn.Sequential( fx.filter.PeakingEQ(freq=200, gain_db=-2, q=0.7, fs=sample_rate), # Reduce mud fx.filter.PeakingEQ(freq=3000, gain_db=3, q=1.0, fs=sample_rate), # Presence fx.filter.PeakingEQ(freq=10000, gain_db=2, q=0.7, fs=sample_rate), # Brightness ) # Stage 3: Dynamics and final polish self.compressor = fx.effect.Compressor(threshold=0.5, ratio=4.0) self.limiter = fx.effect.Normalize(peak=0.95) def forward(self, x: torch.Tensor) -> torch.Tensor: # Stage 1 x = self.rumble_removal(x) x = self.air_filter(x) # Stage 2 x = self.eq_chain(x) # Stage 3 x = self.compressor(x) x = self.limiter(x) return x # Usage with Wave wave = fx.Wave.from_file("vocal.wav") processor = VocalProcessor(wave.fs) processed = wave | processor processed.save("processed_vocal.wav") ``` ### Pattern 3: Parameterized Filter Bank Create a module with learnable or configurable parameters: ```python import torch import torch.nn as nn import torchfx as fx class MultiFrequencyFilter(nn.Module): """Parallel filter bank with multiple cutoff frequencies.""" def __init__(self, cutoff_freqs: list[float], sample_rate: int = 44100): super().__init__() self.filters = nn.ModuleList([ fx.filter.LoButterworth(cutoff=freq, order=4, fs=sample_rate) for freq in cutoff_freqs ]) # Learnable weights for each filter (optional) self.weights = nn.Parameter(torch.ones(len(cutoff_freqs))) def forward(self, x: torch.Tensor) -> torch.Tensor: # Apply all filters and combine with learned weights outputs = [] for i, filter in enumerate(self.filters): filtered = filter(x) weighted = filtered * self.weights[i] outputs.append(weighted) # Sum all weighted outputs return sum(outputs) # Usage filter_bank = MultiFrequencyFilter( 
    cutoff_freqs=[500, 1000, 2000, 4000],
    sample_rate=44100
)

wave = fx.Wave.from_file("audio.wav")
result = wave | filter_bank
```

### Key Module Characteristics

| Property | Behavior |
|----------|----------|
| **Module registration** | Filters assigned as attributes are automatically registered by PyTorch |
| **Coefficient tracking** | Filter coefficients are tracked by the module, but are **not** trainable parameters (see the gradient section below) |
| **Device management** | Calling `.to(device)` on the parent module moves all child filters |
| **State dict** | Filter states are included in `state_dict()` for serialization |
| **Nested modules** | Modules can contain other modules arbitrarily deep |

```{seealso}
[PyTorch Module Documentation](https://pytorch.org/docs/stable/notes/modules.html) - Official guide to building custom modules
```

## Gradient Computation and Differentiability

TorchFX operations maintain gradient flow where applicable, enabling their use in differentiable audio processing and neural network training.

### Differentiability Status

```{mermaid}
graph TD
    Operations["TorchFX Operations"]

    subgraph "Differentiable Operations"
        Filters["Filter Forward Pass
IIR and FIR convolutions
Gradients flow through"] Gain["Gain Effect
Amplitude scaling
Gradients flow through"] Transform["Wave.transform
Functional operations
Conditional on function"] end subgraph "Non-Differentiable Operations" FileIO["File I/O
from_file, save
No gradients"] CoeffComp["Coefficient Computation
compute_coefficients()
Uses SciPy, no gradients"] Design["Filter Design
Cutoff/order to coefficients
No gradients"] end Operations --> Filters Operations --> Gain Operations --> Transform Operations --> FileIO Operations --> CoeffComp Operations --> Design Filters -->|"backward()"| Backprop["Backpropagation"] Gain -->|"backward()"| Backprop Transform -->|"if func differentiable"| Backprop style Filters fill:#e1ffe1 style Gain fill:#e1ffe1 style Transform fill:#e1ffe1 style FileIO fill:#ffe1e1 style CoeffComp fill:#ffe1e1 style Design fill:#ffe1e1 ``` **Differentiability Map** - Shows which TorchFX operations support gradient computation. ### Gradient Flow Through Processing Chain When using TorchFX modules in a training pipeline: 1. **Forward Pass**: Audio tensors flow through filter convolutions and effects 2. **Backward Pass**: Gradients flow back through differentiable operations 3. **Parameter Updates**: Upstream parameters receive gradients; filter coefficients remain fixed ```{mermaid} sequenceDiagram participant Upstream as "Upstream Layer
(learnable)" participant Filter as "TorchFX Filter
(fixed coefficients)" participant Loss as "Loss Function" Note over Upstream,Loss: Forward Pass Upstream->>Filter: x (requires_grad=True) Filter->>Filter: Apply filter convolution Filter->>Loss: filtered_x (grad_fn=<...>) Loss->>Loss: Compute loss Note over Upstream,Loss: Backward Pass Loss->>Loss: loss.backward() Loss->>Filter: Gradients for filtered_x Filter->>Filter: Compute gradients wrt input Filter->>Upstream: Gradients for x Upstream->>Upstream: Update parameters Note over Filter: Filter coefficients
do NOT receive gradients ``` **Gradient Flow in Training** - Gradients flow through TorchFX filters but don't update filter coefficients. ### Important Notes on Gradients ```{warning} **Filter coefficients are NOT learnable parameters.** Filter coefficients are computed from design parameters (cutoff frequency, order) using non-differentiable SciPy functions. The coefficients themselves do not receive gradients during backpropagation. If you need learnable filtering, consider using learnable FIR filters where the filter taps are {class}`torch.nn.Parameter` objects. ``` ### Example: Differentiable Audio Augmentation ```python import torch import torch.nn as nn import torchfx as fx class AudioClassifier(nn.Module): """Example classifier with TorchFX augmentation.""" def __init__(self, num_classes: int = 10): super().__init__() # Fixed augmentation filters self.augment = nn.Sequential( fx.filter.HiButterworth(cutoff=80, order=2, fs=44100), fx.filter.LoButterworth(cutoff=12000, order=4, fs=44100), ) # Learnable classification layers self.conv1 = nn.Conv1d(2, 64, kernel_size=3) self.conv2 = nn.Conv1d(64, 128, kernel_size=3) self.fc = nn.Linear(128, num_classes) def forward(self, x: torch.Tensor) -> torch.Tensor: # x: (batch, channels, time) # Augmentation (gradients flow through) batch_results = [] for i in range(x.size(0)): augmented = self.augment(x[i]) batch_results.append(augmented) x = torch.stack(batch_results) # Classification (learnable) x = self.conv1(x) x = torch.relu(x) x = self.conv2(x) x = torch.relu(x) x = x.mean(dim=-1) # Global average pooling x = self.fc(x) return x # Training example model = AudioClassifier(num_classes=10) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # Forward pass audio_batch = torch.randn(32, 2, 44100, requires_grad=True) labels = torch.randint(0, 10, (32,)) # Compute loss outputs = model(audio_batch) loss = nn.CrossEntropyLoss()(outputs, labels) # Backward pass (gradients flow through filters to conv layers) loss.backward() # Update ONLY the learnable parameters (conv1, conv2, fc) optimizer.step() # Filter coefficients remain unchanged ``` ```{tip} TorchFX filters work great as **fixed augmentation layers** in neural network pipelines, providing consistent audio preprocessing that gradients can flow through. ``` ## Mixing with torchaudio.transforms TorchFX modules can be mixed with [torchaudio](https://pytorch.org/audio/) transforms in the same processing pipeline, leveraging the best of both libraries. 
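This works because torchaudio transforms are themselves {class}`torch.nn.Module` subclasses, so they pass the `isinstance` check the pipe operator performs (see the implementation details above). A quick sanity check (a sketch; assumes torchaudio is installed):

```python
import torch.nn as nn
import torchaudio.transforms as T

vol = T.Vol(gain=0.5)
resample = T.Resample(orig_freq=44100, new_freq=16000)

# Both satisfy the type check the pipe operator applies to its right operand
print(isinstance(vol, nn.Module))       # True
print(isinstance(resample, nn.Module))  # True
```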
### Integration Pattern

```python
import torch
import torchfx as fx
import torchaudio.transforms as T

wave = fx.Wave.from_file("audio.wav")

# Mix TorchFX filters with torchaudio transforms
processed = (
    wave
    | fx.filter.HiButterworth(100, order=2)   # TorchFX high-pass (removes rumble)
    | fx.filter.LoButterworth(2000, order=2)  # TorchFX low-pass
    | T.Vol(gain=0.5)                         # torchaudio volume
    | fx.effect.Normalize(peak=0.9)           # TorchFX effect
)

processed.save("mixed_processing.wav")
```

### Compatible torchaudio Transforms

| Transform | Use Case | Integration |
|-----------|----------|-------------|
| `T.Vol` | Volume adjustment | Direct pipe operator |
| `T.Resample` | Sample rate conversion | Direct pipe operator |
| `T.Fade` | Fade in/out | Direct pipe operator |
| `T.FrequencyMasking` | Spectrogram augmentation | Requires spectrogram conversion |
| `T.TimeMasking` | Spectrogram augmentation | Requires spectrogram conversion |
| `T.MelScale` | Mel-frequency processing | Requires spectrogram conversion |

### Example: Complete Audio Pipeline

```python
import torch
import torch.nn as nn
import torchfx as fx
import torchaudio.transforms as T

class AudioPipeline(nn.Module):
    """Complete audio processing pipeline mixing TorchFX and torchaudio."""

    def __init__(self, sample_rate: int = 44100):
        super().__init__()
        # TorchFX preprocessing: remove low-frequency rumble
        self.highpass = fx.filter.HiButterworth(cutoff=100, order=2, fs=sample_rate)

        # torchaudio transforms
        self.resample = T.Resample(orig_freq=sample_rate, new_freq=16000)
        self.volume = T.Vol(gain=0.8, gain_type="amplitude")

        # TorchFX final processing
        self.normalize = fx.effect.Normalize(peak=0.95)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.highpass(x)
        x = self.resample(x)
        x = self.volume(x)
        x = self.normalize(x)
        return x

# Usage
pipeline = AudioPipeline(sample_rate=44100)
wave = fx.Wave.from_file("audio.wav")

# Apply pipeline (note: the resample step changes the sample rate;
# see Example 3 below for creating a Wave with the new fs)
result = wave | pipeline
```

### Module Compatibility Requirements

Any {class}`torch.nn.Module` is compatible with TorchFX's pipe operator if it:

1. **Implements** a `forward` method
2. **Accepts** a tensor as input with shape `(channels, time)` for audio
3. **Returns** a tensor of the same or compatible shape
4. **Operates** on the same device as the input tensor

```{note}
Some torchaudio transforms expect mono audio or specific channel configurations. Check the transform documentation and adjust channel count if needed using {meth}`~torchfx.Wave.to_mono()` or other methods.
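For example, a mono-only transform could be applied as `fx.Wave.from_file("stereo.wav").to_mono() | transform` (illustrative; `transform` stands for any single-channel module).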
``` ## Complete Working Examples ### Example 1: Audio Classification with TorchFX Augmentation ```python import torch import torch.nn as nn import torchfx as fx from torch.utils.data import Dataset, DataLoader class AudioDataset(Dataset): """Simple audio dataset.""" def __init__(self, file_paths: list[str], labels: list[int]): self.file_paths = file_paths self.labels = labels def __len__(self): return len(self.file_paths) def __getitem__(self, idx): wave = fx.Wave.from_file(self.file_paths[idx]) # Ensure fixed length (e.g., 3 seconds at 44100 Hz) target_length = 3 * 44100 if wave.ys.size(-1) > target_length: wave.ys = wave.ys[..., :target_length] else: # Pad if too short padding = target_length - wave.ys.size(-1) wave.ys = torch.nn.functional.pad(wave.ys, (0, padding)) return wave.ys, self.labels[idx] class AudioClassifierWithAugmentation(nn.Module): """Classifier with TorchFX augmentation.""" def __init__(self, num_classes: int = 10): super().__init__() # Data augmentation (fixed filters) self.augmentation = nn.Sequential( fx.filter.HiButterworth(cutoff=80, order=2, fs=44100), fx.filter.LoButterworth(cutoff=15000, order=4, fs=44100), fx.effect.Normalize(peak=0.9), ) # Feature extraction self.conv1 = nn.Conv1d(2, 32, kernel_size=3, stride=2) self.conv2 = nn.Conv1d(32, 64, kernel_size=3, stride=2) self.pool = nn.AdaptiveAvgPool1d(1) # Classification self.fc = nn.Linear(64, num_classes) def forward(self, x: torch.Tensor) -> torch.Tensor: # x: (batch, channels, time) # Apply augmentation to each sample batch_size = x.size(0) augmented = [] for i in range(batch_size): aug = self.augmentation(x[i]) augmented.append(aug) x = torch.stack(augmented) # Feature extraction x = torch.relu(self.conv1(x)) x = torch.relu(self.conv2(x)) x = self.pool(x).squeeze(-1) # Classification x = self.fc(x) return x # Training setup device = "cuda" if torch.cuda.is_available() else "cpu" # Create dataset (example) file_paths = ["audio1.wav", "audio2.wav", "audio3.wav"] labels = [0, 1, 2] dataset = AudioDataset(file_paths, labels) dataloader = DataLoader(dataset, batch_size=4, shuffle=True) # Initialize model model = AudioClassifierWithAugmentation(num_classes=3).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) criterion = nn.CrossEntropyLoss() # Training loop (one epoch) model.train() for batch_audio, batch_labels in dataloader: batch_audio = batch_audio.to(device) batch_labels = batch_labels.to(device) # Forward pass outputs = model(batch_audio) loss = criterion(outputs, batch_labels) # Backward pass optimizer.zero_grad() loss.backward() optimizer.step() print(f"Loss: {loss.item():.4f}") ``` ### Example 2: Custom Effect Module with State ```python import torch import torch.nn as nn import torchfx as fx class StatefulReverb(nn.Module): """Custom reverb with learnable parameters.""" def __init__(self, max_delay: int = 44100): super().__init__() # Learnable delay time (in samples) self.delay_time = nn.Parameter(torch.tensor(22050.0)) # Learnable decay factor self.decay = nn.Parameter(torch.tensor(0.5)) self.max_delay = max_delay def forward(self, x: torch.Tensor) -> torch.Tensor: # Clamp delay to valid range delay = torch.clamp(self.delay_time, 0, self.max_delay).long() # Simple delay-based reverb output = x.clone() if delay > 0: # Add delayed signal with decay padded = torch.nn.functional.pad(x, (delay, 0)) delayed = padded[..., :-delay] if delay > 0 else padded output = x + delayed * self.decay return output class ProcessorWithLearnableReverb(nn.Module): """Combines fixed filters with learnable 
reverb.""" def __init__(self, sample_rate: int = 44100): super().__init__() # Fixed preprocessing self.preprocess = nn.Sequential( fx.filter.HiButterworth(cutoff=80, order=2, fs=sample_rate), fx.filter.LoButterworth(cutoff=12000, order=4, fs=sample_rate), ) # Learnable reverb self.reverb = StatefulReverb(max_delay=sample_rate) # Fixed normalization self.normalize = fx.effect.Normalize(peak=0.9) def forward(self, x: torch.Tensor) -> torch.Tensor: # Fixed filtering x = self.preprocess(x) # Learnable reverb x = self.reverb(x) # Fixed normalization x = self.normalize(x) return x # Usage processor = ProcessorWithLearnableReverb(sample_rate=44100) wave = fx.Wave.from_file("audio.wav") # Process (can be used in training loop) processed = wave | processor # Access learnable parameters print(f"Delay time: {processor.reverb.delay_time.item():.0f} samples") print(f"Decay factor: {processor.reverb.decay.item():.3f}") ``` ### Example 3: Mixing TorchFX with torchaudio in Sequential ```python import torch.nn as nn import torchfx as fx import torchaudio.transforms as T class HybridAudioProcessor(nn.Module): """Processor combining TorchFX and torchaudio transforms.""" def __init__(self, sample_rate: int = 44100, target_sr: int = 16000): super().__init__() self.pipeline = nn.Sequential( # TorchFX: Remove low-frequency rumble fx.filter.HiButterworth(cutoff=80, order=2, fs=sample_rate), # TorchFX: Remove high-frequency noise fx.filter.LoButterworth(cutoff=15000, order=4, fs=sample_rate), # torchaudio: Resample to lower rate T.Resample(orig_freq=sample_rate, new_freq=target_sr), # torchaudio: Adjust volume T.Vol(gain=0.8, gain_type="amplitude"), # TorchFX: Final normalization fx.effect.Normalize(peak=0.95), ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.pipeline(x) # Usage processor = HybridAudioProcessor(sample_rate=44100, target_sr=16000) wave = fx.Wave.from_file("audio.wav") # Process result = wave | processor # Note: Output has different sample rate! # You need to create a new Wave with the correct fs result_wave = fx.Wave(result.ys if isinstance(result, fx.Wave) else processor(wave.ys), fs=16000) result_wave.save("processed.wav") ``` ## Architecture Diagrams ### TorchFX in a Training Pipeline ```{mermaid} graph TB subgraph "Data Loading" Files["Audio Files"] --> Loader["DataLoader
(PyTorch)"] end subgraph "Model Architecture" Loader --> Aug["Augmentation Layer
(TorchFX filters - fixed)"] Aug --> Conv1["Conv1D Layer
(learnable)"] Conv1 --> Conv2["Conv1D Layer
(learnable)"] Conv2 --> Pool["Pooling Layer"] Pool --> FC["Fully Connected
(learnable)"] end subgraph "Training Loop" FC --> Loss["Loss Function"] Loss --> Backward["Backward Pass"] Backward --> Optimizer["Optimizer
(updates Conv + FC)"] end Optimizer -.->|"gradients flow
but don't update"| Aug Optimizer -->|"updates parameters"| Conv1 Optimizer -->|"updates parameters"| Conv2 Optimizer -->|"updates parameters"| FC style Aug fill:#fff5e1 style Conv1 fill:#e1ffe1 style Conv2 fill:#e1ffe1 style FC fill:#e1ffe1 style Loss fill:#ffe1e1 ``` **Training Pipeline Architecture** - TorchFX filters serve as fixed augmentation layers while gradients flow through to update learnable layers. ### Module Composition Patterns ```{mermaid} graph TB subgraph "Pattern 1: Sequential Chain" S1["nn.Sequential"] --> SF1["TorchFX Filter 1"] SF1 --> SF2["TorchFX Filter 2"] SF2 --> SF3["TorchFX Effect"] end subgraph "Pattern 2: Custom Module" C1["Custom Module"] --> CM1["def __init__"] CM1 --> CF1["self.filter1 = ..."] CM1 --> CF2["self.filter2 = ..."] C1 --> CM2["def forward"] CM2 --> CL1["x = self.filter1(x)"] CL1 --> CL2["x = self.filter2(x)"] end subgraph "Pattern 3: Hybrid Module" H1["Hybrid Module"] --> HT1["TorchFX Filters"] H1 --> HT2["torchaudio Transforms"] H1 --> HT3["Custom Logic"] HT1 --> HF["forward()"] HT2 --> HF HT3 --> HF end style S1 fill:#e1f5ff style C1 fill:#e1f5ff style H1 fill:#e1f5ff ``` **Module Composition Patterns** - Three common ways to structure TorchFX modules in larger systems. ## Best Practices ### Use nn.Sequential for Simple Chains ```python # ✅ GOOD: Clear, standard PyTorch pattern preprocessing = nn.Sequential( fx.filter.HiButterworth(cutoff=80, order=2, fs=44100), fx.filter.LoButterworth(cutoff=12000, order=4, fs=44100), fx.effect.Normalize(peak=0.9), ) # ❌ LESS GOOD: Custom module for simple chain class Preprocessing(nn.Module): def __init__(self): super().__init__() self.f1 = fx.filter.HiButterworth(cutoff=80, order=2, fs=44100) self.f2 = fx.filter.LoButterworth(cutoff=12000, order=4, fs=44100) self.norm = fx.effect.Normalize(peak=0.9) def forward(self, x): return self.norm(self.f2(self.f1(x))) ``` ### Set Sample Rate Explicitly for Reusable Modules ```python # ✅ GOOD: Sample rate specified at module creation class ReusableFilter(nn.Module): def __init__(self, sample_rate: int = 44100): super().__init__() self.filter = fx.filter.LoButterworth(cutoff=1000, order=4, fs=sample_rate) # ❌ BAD: Relying on pipe operator to set fs class ReusableFilter(nn.Module): def __init__(self): super().__init__() self.filter = fx.filter.LoButterworth(cutoff=1000, order=4) # fs=None! ``` ### Move Entire Modules to Device Together ```python # ✅ GOOD: Move entire module at once processor = VocalProcessor(sample_rate=44100) processor.to("cuda") # Moves all child modules # ❌ BAD: Moving individual components processor = VocalProcessor(sample_rate=44100) processor.filter1.to("cuda") processor.filter2.to("cuda") processor.effect1.to("cuda") # Easy to miss one! ``` ### Use ModuleList for Dynamic Filter Collections ```python # ✅ GOOD: ModuleList for dynamic collections class MultiFilterBank(nn.Module): def __init__(self, cutoffs: list[float], fs: int = 44100): super().__init__() self.filters = nn.ModuleList([ fx.filter.LoButterworth(cutoff=f, order=4, fs=fs) for f in cutoffs ]) def forward(self, x): return sum(f(x) for f in self.filters) # ❌ BAD: Regular list (filters won't be registered!) class MultiFilterBank(nn.Module): def __init__(self, cutoffs: list[float], fs: int = 44100): super().__init__() self.filters = [ # ⚠️ Regular list! 
            fx.filter.LoButterworth(cutoff=f, order=4, fs=fs)
            for f in cutoffs
        ]
```

## Common Pitfalls

### Pitfall 1: Forgetting to Process Batches Correctly

```python
# ❌ WRONG: TorchFX filters expect (channels, time), not (batch, channels, time)
class BadProcessor(nn.Module):
    def __init__(self, fs: int = 44100):
        super().__init__()
        self.filter = fx.filter.LoButterworth(cutoff=1000, fs=fs)

    def forward(self, x):
        # x is (batch, channels, time)
        return self.filter(x)  # Error! Filter expects (channels, time)

# ✅ CORRECT: Process each sample in the batch
class GoodProcessor(nn.Module):
    def __init__(self, fs: int = 44100):
        super().__init__()
        self.filter = fx.filter.LoButterworth(cutoff=1000, fs=fs)

    def forward(self, x):
        # x is (batch, channels, time)
        results = [self.filter(x[i]) for i in range(x.size(0))]
        return torch.stack(results)
```

### Pitfall 2: Mixing Sample Rates

```python
# ❌ WRONG: Sample rate mismatch
audio_44k = fx.Wave.from_file("audio_44100.wav")  # 44100 Hz
filter_16k = fx.filter.LoButterworth(cutoff=1000, fs=16000)  # Wrong fs!
result = audio_44k | filter_16k  # Incorrect filtering!

# ✅ CORRECT: Match sample rates
audio_44k = fx.Wave.from_file("audio_44100.wav")
filter_44k = fx.filter.LoButterworth(cutoff=1000, fs=44100)
result = audio_44k | filter_44k
```

### Pitfall 3: Expecting Filter Coefficients to Update

```python
# ❌ WRONG: Expecting filter coefficients to learn
model = nn.Sequential(
    fx.filter.LoButterworth(cutoff=1000, fs=44100),
    SomeLearnableLayer(),  # placeholder for any trainable layer
)
# Training will NOT update the filter's cutoff frequency or coefficients!

# ✅ CORRECT: Use fixed filters or create learnable FIR filters
# Option 1: Use as fixed preprocessing
preprocessing = fx.filter.LoButterworth(cutoff=1000, fs=44100)

# Option 2: Create a learnable FIR filter (use an odd num_taps so that
# padding=num_taps // 2 preserves the signal length)
class LearnableFIR(nn.Module):
    def __init__(self, num_taps: int):
        super().__init__()
        self.taps = nn.Parameter(torch.randn(num_taps))

    def forward(self, x):
        # x: (channels, time). Treat channels as the batch dimension so the
        # same learnable kernel filters every channel independently.
        y = torch.nn.functional.conv1d(
            x.unsqueeze(1),            # (channels, 1, time)
            self.taps.view(1, 1, -1),  # (out_channels=1, in_channels=1, taps)
            padding=self.taps.numel() // 2,
        )
        return y.squeeze(1)
```

## Related Concepts

- {doc}`gpu-acceleration` - Using TorchFX with CUDA
- {doc}`../tutorials/series-parallel-filters` - Combining filters in complex networks
- {doc}`../core-concepts/fx` - Understanding the FX base class
- {doc}`../tutorials/custom-effects` - Creating custom effects as nn.Modules

## External Resources

- [PyTorch nn.Module Tutorial](https://pytorch.org/tutorials/beginner/examples_nn/two_layer_net_module.html) - Building custom modules
- [torchaudio Documentation](https://pytorch.org/audio/stable/index.html) - Audio processing in PyTorch
- [PyTorch Device Management](https://pytorch.org/docs/stable/notes/cuda.html) - Working with GPUs
- [Automatic Differentiation](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html) - Understanding gradients in PyTorch

## Summary

Key takeaways for PyTorch integration:

1. **Module Inheritance**: All TorchFX effects and filters inherit from {class}`torch.nn.Module`, enabling full PyTorch compatibility
2. **Pipe Operator**: The `|` operator accepts any {class}`torch.nn.Module`, not just TorchFX components
3. **Sequential Composition**: Use {class}`torch.nn.Sequential` for standard PyTorch-style chains
4. **Custom Modules**: Embed TorchFX filters in custom modules for reusable processing blocks
5. **Gradient Flow**: Gradients flow through TorchFX operations, but filter coefficients are fixed (non-learnable)
6. **Library Mixing**: Seamlessly combine TorchFX with torchaudio transforms and custom modules

TorchFX's deep integration with PyTorch makes it a natural fit for machine learning pipelines, data augmentation, and differentiable audio processing workflows.