(multi-channel)= # Multi-Channel Processing Learn how to process multi-channel audio in TorchFX, from simple stereo to complex surround sound configurations. This guide covers tensor shape conventions, per-channel vs. cross-channel processing strategies, and production-ready patterns for building multi-channel effects. ## Prerequisites Before starting this guide, you should be familiar with: - {doc}`../core-concepts/wave` - Wave class and tensor handling - {doc}`../core-concepts/fx` - FX base class for effects - {doc}`../tutorials/custom-effects` - Creating custom effects - Basic PyTorch {class}`torch.nn.Module` and {class}`torch.nn.ModuleList` usage ## Overview Multi-channel audio is everywhere in modern audio production—from stereo music to 5.1/7.1 surround sound in film and immersive spatial audio formats. TorchFX provides flexible patterns for handling multi-channel audio with three main processing approaches: | Processing Type | Description | Example Use Case | |----------------|-------------|------------------| | **Broadcast** | Same effect applied to all channels | Global EQ, normalization | | **Per-channel** | Different processing for each channel | Independent channel EQ, stereo mastering | | **Cross-channel** | Channels interact with each other | Ping-pong delay, stereo widening | ```{mermaid} graph TB subgraph "Multi-Channel Processing Strategies" Input[Multi-Channel Input
channels, time] subgraph Broadcast[Broadcast Processing] BC[Same filter
applied to all channels] end subgraph PerChannel[Per-Channel Processing] PC1[Channel 1: Filter A] PC2[Channel 2: Filter B] PCN[Channel N: Filter C] end subgraph CrossChannel[Cross-Channel Processing] CC[Channels interact
e.g., L→R, R→L] end Input --> Broadcast Input --> PerChannel Input --> CrossChannel Broadcast --> Output1[Output: Same processing] PerChannel --> Output2[Output: Different per channel] CrossChannel --> Output3[Output: Channel interaction] end style Input fill:#e1f5ff style Broadcast fill:#e8f5e1 style PerChannel fill:#fff5e1 style CrossChannel fill:#f5e1ff ``` ## Tensor Shape Conventions TorchFX follows standard PyTorch audio conventions where **the last dimension represents time** and earlier dimensions represent channels and/or batches. ### Standard Audio Shapes | Shape | Description | Example Use Case | |-------|-------------|------------------| | `(time,)` | Mono audio | Single microphone recording | | `(channels, time)` | Multi-channel audio | Stereo (2), 5.1 surround (6), 7.1 surround (8) | | `(batch, channels, time)` | Batched multi-channel | Neural network training batches | | `(..., time)` | Arbitrary leading dimensions | Generic tensor processing | All effects inheriting from {class}`~torchfx.FX` accept tensors with these shapes. By default, effects broadcast operations across all dimensions except time, unless they implement channel-specific logic. ```{important} The **last dimension is always time**. This is critical for proper tensor handling in TorchFX and PyTorch audio libraries. ``` ### Shape Convention Examples ```python import torch import torchfx as fx # Mono audio: (time,) mono = torch.randn(44100) # 1 second at 44.1kHz print(f"Mono shape: {mono.shape}") # (44100,) # Stereo audio: (channels, time) stereo = torch.randn(2, 44100) # 2 channels, 1 second print(f"Stereo shape: {stereo.shape}") # (2, 44100) # 5.1 surround: (channels, time) # Order: Front L, Front R, Center, LFE, Rear L, Rear R surround_51 = torch.randn(6, 44100) print(f"5.1 shape: {surround_51.shape}") # (6, 44100) # 7.1 surround: (channels, time) surround_71 = torch.randn(8, 44100) print(f"7.1 shape: {surround_71.shape}") # (8, 44100) # Batched stereo for ML: (batch, channels, time) batch = torch.randn(32, 2, 44100) # 32 stereo samples print(f"Batched shape: {batch.shape}") # (32, 2, 44100) ``` ### Channel Processing Flow ```{mermaid} graph LR subgraph "Input Tensor Shapes" Mono["(time,)
Mono"] Stereo["(2, time)
Stereo"] Multi["(N, time)
N channels"] Batch["(B, N, time)
Batched"] end subgraph "FX Base Class" FX["FX.forward(x: Tensor)
All shapes supported"] end subgraph "Processing Strategies" Broadcast["Broadcast
Same to all"] PerCh["Per-Channel
Independent"] Cross["Cross-Channel
Interactive"] end Mono --> FX Stereo --> FX Multi --> FX Batch --> FX FX --> Broadcast FX --> PerCh FX --> Cross ``` ## Per-Channel Processing Patterns Per-channel processing applies different effects to each channel independently. This is the most common pattern for multi-channel effects. ### Pattern 1: Fixed Channel Count with ModuleList Use {class}`torch.nn.ModuleList` to store separate processing chains for each channel: ```python import torch from torch import Tensor, nn import torchaudio.transforms as T from torchfx import FX, Wave from torchfx.filter import iir class StereoProcessor(FX): """Apply different processing to left and right channels. Left channel: Remove rumble, gentle high-pass Right channel: Different EQ curve with volume reduction Parameters ---------- fs : int, optional Sample rate in Hz. Can be set via Wave pipeline. Examples -------- >>> wave = Wave.from_file("stereo.wav") >>> processor = StereoProcessor(fs=wave.fs) >>> processed = wave | processor """ def __init__(self, fs: int | None = None): super().__init__() self.fs = fs # Store per-channel processing chains self.channels = nn.ModuleList([ self.left_channel(), self.right_channel(), ]) def left_channel(self) -> nn.Sequential: """Processing chain for left channel.""" return nn.Sequential( iir.HiButterworth(cutoff=100, order=2, fs=self.fs), # Remove rumble iir.LoButterworth(cutoff=8000, order=4, fs=self.fs), # Gentle rolloff ) def right_channel(self) -> nn.Sequential: """Processing chain for right channel.""" return nn.Sequential( iir.HiButterworth(cutoff=150, order=2, fs=self.fs), # Higher HPF iir.LoButterworth(cutoff=10000, order=4, fs=self.fs), # Different LPF T.Vol(0.9), # Slight volume reduction ) def forward(self, x: Tensor) -> Tensor: """Apply per-channel processing. Parameters ---------- x : Tensor Input audio with shape (2, time) for stereo Returns ------- Tensor Processed audio with same shape as input """ if self.fs is None: raise ValueError("Sample rate (fs) must be set before processing") # Process each channel independently for i in range(len(self.channels)): x[i] = self.channels[i](x[i]) return x # Usage wave = Wave.from_file("stereo_music.wav") processor = StereoProcessor(fs=wave.fs) processed = wave | processor processed.save("stereo_processed.wav") ``` ```{mermaid} graph TB Input[Stereo Input
2, time] subgraph Processor[StereoProcessor] subgraph Left[Left Channel Chain] LH[HiButterworth
100 Hz] LL[LoButterworth
8000 Hz] end subgraph Right[Right Channel Chain] RH[HiButterworth
150 Hz] RL[LoButterworth
10000 Hz] RV[Vol 0.9] end end Output[Stereo Output
2, time] Input -->|x[0]| Left Input -->|x[1]| Right LH --> LL RH --> RL --> RV Left -->|processed[0]| Output Right -->|processed[1]| Output style Input fill:#e1f5ff style Output fill:#e1f5ff style Left fill:#fff5e1 style Right fill:#ffe1e1 ``` **Key implementation details**: 1. Use {class}`torch.nn.ModuleList` to properly register submodules 2. Accept `fs` parameter and pass it to filters requiring sample rate 3. Validate `fs` is set before processing 4. Process each channel independently in `forward()` ```{seealso} {doc}`../tutorials/custom-effects` - General patterns for creating custom effects ``` ### Pattern 2: Dynamic Channel Count Handle any number of channels dynamically by creating processing chains on the fly: ```python from torch import Tensor, nn from torchfx import FX, Wave from torchfx.filter import iir class FlexibleMultiChannel(FX): """Effect that adapts to any number of channels. Creates independent processing chains for each channel, with frequency ranges scaled based on channel index. Parameters ---------- fs : int, optional Sample rate in Hz Examples -------- >>> # Works with stereo >>> stereo_wave = Wave.from_file("stereo.wav") >>> fx = FlexibleMultiChannel(fs=stereo_wave.fs) >>> result = stereo_wave | fx >>> # Also works with 5.1 surround >>> surround_wave = Wave.from_file("surround_51.wav") >>> result = surround_wave | fx """ def __init__(self, fs: int | None = None): super().__init__() self.fs = fs self.channels = None # Created dynamically on first forward pass def _create_channels(self, num_channels: int): """Create processing chains for given number of channels. Each channel gets a bandpass filter with different frequency range. """ self.channels = nn.ModuleList([ nn.Sequential( iir.HiButterworth(cutoff=100 * (i + 1), order=2, fs=self.fs), iir.LoButterworth(cutoff=1000 * (i + 1), order=2, fs=self.fs), ) for i in range(num_channels) ]) def forward(self, x: Tensor) -> Tensor: """Process audio with dynamic channel adaptation.""" if self.fs is None: raise ValueError("Sample rate (fs) must be set") # Determine number of channels num_channels = x.shape[0] if x.ndim >= 2 else 1 # Create channels on first forward pass if self.channels is None: self._create_channels(num_channels) # Process each channel if x.ndim >= 2: for i in range(num_channels): x[i] = self.channels[i](x[i]) else: # Handle mono input x = self.channels[0](x) return x # Usage with different channel counts # Stereo stereo = Wave.from_file("stereo.wav") fx_stereo = FlexibleMultiChannel(fs=stereo.fs) result_stereo = stereo | fx_stereo # 5.1 Surround surround = Wave.from_file("surround_51.wav") fx_surround = FlexibleMultiChannel(fs=surround.fs) result_surround = surround | fx_surround ``` ### Pattern 3: Complete Production Example Here's a complete, production-ready multi-channel effect based on the TorchFX examples: ```python import torch from torch import Tensor, nn import torchaudio.transforms as T from torchfx import FX, Wave from torchfx.filter import HiButterworth, LoButterworth class ComplexEffect(FX): """Multi-channel effect with different processing per channel. This effect demonstrates a complete production pattern for multi-channel processing with independent channel chains. Channel 1: Bandpass 1000-2000 Hz (mid-range focus) Channel 2: Bandpass 2000-4000 Hz with 50% volume (presence range) Parameters ---------- num_channels : int Number of channels to process (typically 2 for stereo) fs : int, optional Sample rate in Hz. Can be set automatically via Wave pipeline. 
Examples -------- >>> # Basic usage with Wave pipeline >>> wave = Wave.from_file("stereo.wav") >>> fx = ComplexEffect(num_channels=2, fs=wave.fs) >>> processed = wave | fx >>> processed.save("output.wav") >>> # With GPU acceleration >>> wave = Wave.from_file("stereo.wav").to("cuda") >>> fx = ComplexEffect(num_channels=2, fs=wave.fs).to("cuda") >>> processed = wave | fx >>> processed.to("cpu").save("output.wav") """ def __init__(self, num_channels: int, fs: int | None = None): super().__init__() self.num_channels = num_channels self.fs = fs # Per-channel processing chains stored in ModuleList self.ch = nn.ModuleList([ self.channel1(), self.channel2(), ]) def channel1(self) -> nn.Sequential: """Processing chain for channel 1. Creates a bandpass filter focusing on mid-range frequencies. """ return nn.Sequential( HiButterworth(1000, fs=self.fs), # High-pass at 1 kHz LoButterworth(2000, fs=self.fs), # Low-pass at 2 kHz ) def channel2(self) -> nn.Sequential: """Processing chain for channel 2. Creates a bandpass filter focusing on presence range with volume reduction. """ return nn.Sequential( HiButterworth(2000, fs=self.fs), # High-pass at 2 kHz LoButterworth(4000, fs=self.fs), # Low-pass at 4 kHz T.Vol(0.5), # Reduce volume by 50% ) def forward(self, x: Tensor) -> Tensor: """Apply per-channel processing. Parameters ---------- x : Tensor Input audio with shape (num_channels, time) Returns ------- Tensor Processed audio with same shape as input Raises ------ ValueError If sample rate (fs) has not been set """ if self.fs is None: raise ValueError("Sampling frequency (fs) must be set") # Process each channel independently for i in range(self.num_channels): x[i] = self.ch[i](x[i]) return x # Complete usage example if __name__ == "__main__": # Load stereo audio wave = Wave.from_file("input_stereo.wav") print(f"Loaded: {wave.ys.shape}, fs={wave.fs}") # Create and apply effect fx = ComplexEffect(num_channels=2, fs=wave.fs) result = wave | fx # Save result result.save("output_processed.wav") print("Processing complete!") ``` ```{tip} Use the {class}`torch.nn.ModuleList` pattern even if all channels have the same processing. This keeps your code flexible and makes it easy to customize individual channels later. ``` ## Cross-Channel Processing Patterns Cross-channel processing enables channels to interact, creating spatial effects and channel-aware processing. 
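Before looking at the built-in strategies, here is a minimal sketch of what "channels interacting" means in practice: a stereo crossfeed that blends a fraction of each channel into the other. The `CrossFeed` class and its `amount` parameter are hypothetical illustrations for this guide, not part of TorchFX.

```python
import torch
from torch import Tensor

from torchfx import FX


class CrossFeed(FX):
    """Illustrative cross-channel sketch: blend part of each channel into the other.

    ``CrossFeed`` and ``amount`` are examples for this guide, not TorchFX built-ins.
    """

    def __init__(self, amount: float = 0.3):
        super().__init__()
        if not 0.0 <= amount <= 1.0:
            raise ValueError("amount must be between 0 and 1")
        self.amount = amount

    def forward(self, x: Tensor) -> Tensor:
        # Only stereo input interacts; other shapes pass through unchanged.
        if x.ndim < 2 or x.shape[0] != 2:
            return x
        left, right = x[0], x[1]
        new_left = (1 - self.amount) * left + self.amount * right
        new_right = (1 - self.amount) * right + self.amount * left
        return torch.stack([new_left, new_right])
```

With the pipeline operator, `wave | CrossFeed(amount=0.3)` would blend 30% of the opposite channel into each side — the simplest form of the channel interaction that the strategies below formalize.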
### Built-in Strategy: Ping-Pong Delay The {class}`~torchfx.effect.PingPongDelayStrategy` creates alternating delays between stereo channels: ```python import torchfx as fx from torchfx.effect import Delay, PingPongDelayStrategy # Load stereo audio wave = fx.Wave.from_file("stereo.wav") # Create ping-pong delay effect delay = Delay( bpm=120, delay_time="1/8", # 8th note delay feedback=0.5, # 50% feedback mix=0.4, # 40% wet signal strategy=PingPongDelayStrategy() ) # Apply effect result = wave | delay result.save("pingpong_delayed.wav") ``` **Ping-pong delay pattern**: ```{mermaid} sequenceDiagram participant L as Left Channel participant R as Right Channel Note over L,R: Original Signal L->>L: Original left audio R->>R: Original right audio Note over L,R: Tap 1 (100% amplitude) L->>R: Left delays into Right Note over L,R: Tap 2 (feedback^1) R->>L: Right delays into Left Note over L,R: Tap 3 (feedback^2) L->>R: Left delays into Right Note over L,R: Tap 4 (feedback^3) R->>L: Right delays into Left Note over L,R: Result: Ping-pong stereo pattern ``` ```{note} If the input is not stereo (not exactly 2 channels), {class}`~torchfx.effect.PingPongDelayStrategy` automatically falls back to {class}`~torchfx.effect.MonoDelayStrategy`. ``` ### Custom Cross-Channel Effect: Stereo Widener Create a custom cross-channel effect using Mid/Side processing: ```python from torch import Tensor import torch from torchfx import FX, Wave class StereoWidener(FX): """Widen stereo image using Mid/Side processing. Converts stereo L/R to Mid/Side, scales the Side component, then converts back to L/R. Parameters ---------- width : float Stereo width multiplier: - 1.0 = no change (original stereo) - >1.0 = wider stereo image - <1.0 = narrower stereo image - 0.0 = pure mono (no stereo) Examples -------- >>> # Widen stereo image by 50% >>> wave = Wave.from_file("stereo.wav") >>> widener = StereoWidener(width=1.5) >>> wider = wave | widener >>> # Narrow to 50% stereo width >>> narrower = wave | StereoWidener(width=0.5) >>> # Convert to mono >>> mono = wave | StereoWidener(width=0.0) """ def __init__(self, width: float = 1.5): super().__init__() if width < 0: raise ValueError("Width must be non-negative") self.width = width def forward(self, x: Tensor) -> Tensor: """Apply stereo widening. Only processes stereo (2-channel) audio. Non-stereo audio is returned unchanged. """ # Only works on stereo audio if x.ndim < 2 or x.shape[0] != 2: return x # Return unchanged for non-stereo # Extract left and right channels left = x[0] right = x[1] # Convert to Mid/Side mid = (left + right) / 2 # Sum of L+R (mono content) side = (left - right) / 2 # Difference L-R (stereo width) # Widen by scaling Side component side_widened = side * self.width # Convert back to L/R new_left = mid + side_widened new_right = mid - side_widened return torch.stack([new_left, new_right]) # Usage examples wave = Wave.from_file("stereo_mix.wav") # Widen stereo image widener = StereoWidener(width=1.5) wider = wave | widener wider.save("widened_mix.wav") # Narrow stereo image narrower = StereoWidener(width=0.5) narrow = wave | narrower narrow.save("narrowed_mix.wav") ``` ```{mermaid} graph TB Input[Stereo Input
Left, Right] subgraph "Mid/Side Conversion" Mid["Mid = (L + R) / 2
(Mono content)"] Side["Side = (L - R) / 2
(Stereo width)"] end subgraph "Stereo Widening" Scale["Side × width
(Adjust stereo width)"] end subgraph "L/R Conversion" NewL["New Left = Mid + Side×width"] NewR["New Right = Mid - Side×width"] end Output[Stereo Output
Wider/Narrower] Input --> Mid Input --> Side Side --> Scale Mid --> NewL Mid --> NewR Scale --> NewL Scale --> NewR NewL --> Output NewR --> Output style Input fill:#e1f5ff style Output fill:#e1f5ff style Scale fill:#fff5e1 ``` ## Built-in Multi-Channel Strategies TorchFX provides several built-in strategies for multi-channel processing. ### Per-Channel Normalization The {class}`~torchfx.effect.PerChannelNormalizationStrategy` normalizes each channel independently to its own peak value: ```python import torch import torchfx as fx from torchfx.effect import Normalize, PerChannelNormalizationStrategy # Create stereo with imbalanced channels left_loud = torch.randn(44100) * 0.8 # Peak ~0.8 right_quiet = torch.randn(44100) * 0.3 # Peak ~0.3 stereo = torch.stack([left_loud, right_quiet]) wave = fx.Wave(stereo, fs=44100) # Global normalization: both channels scaled by same factor global_norm = wave | Normalize(peak=1.0) # Result: left ~1.0, right ~0.375 (preserves balance) # Per-channel normalization: each channel scaled independently strategy = PerChannelNormalizationStrategy() perchannel_norm = wave | Normalize(peak=1.0, strategy=strategy) # Result: left ~1.0, right ~1.0 (changes balance) ``` **Comparison**: ```{mermaid} graph TB Input["Stereo Input
Left peak: 0.8
Right peak: 0.3"] subgraph Global[Global Normalization] GMax["Find global max: 0.8"] GScale["Scale both by: 1.0/0.8 = 1.25"] GResult["Left: 1.0, Right: 0.375"] end subgraph PerChannel[Per-Channel Normalization] PCMax1["Left max: 0.8"] PCMax2["Right max: 0.3"] PCScale["Scale left: 1.0/0.8
Scale right: 1.0/0.3"] PCResult["Left: 1.0, Right: 1.0"] end Input --> Global Input --> PerChannel GMax --> GScale --> GResult PCMax1 --> PCScale PCMax2 --> PCScale PCScale --> PCResult style Input fill:#e1f5ff style Global fill:#fff5e1 style PerChannel fill:#e8f5e1 ``` **Implementation details**: - For 2D tensors `(channels, time)`: computes max per channel along `dim=1` with `keepdim=True` - For 3D tensors `(batch, channels, time)`: computes max per channel along `dim=2` with `keepdim=True` - Uses `keepdim=True` to maintain broadcasting compatibility ### Mono Delay Strategy {class}`~torchfx.effect.MonoDelayStrategy` applies identical delay processing to all channels independently: ```python import torchfx as fx from torchfx.effect import Delay, MonoDelayStrategy wave = fx.Wave.from_file("stereo.wav") # Mono delay: same delay pattern on both channels delay = Delay( bpm=120, delay_time="1/4", feedback=0.4, mix=0.3, strategy=MonoDelayStrategy() # Default, can be omitted ) delayed = wave | delay ``` ## Dimension-Agnostic Processing For effects that should work with any tensor shape, implement dimension detection and handling: ```python from torch import Tensor from torchfx import FX class DimensionAgnosticEffect(FX): """Effect that handles 1D, 2D, and 3D+ tensors. Automatically detects tensor dimensionality and routes to appropriate processing method. Examples -------- >>> # Works with mono >>> mono = torch.randn(44100) >>> fx = DimensionAgnosticEffect() >>> result = fx(mono) >>> # Works with stereo >>> stereo = torch.randn(2, 44100) >>> result = fx(stereo) >>> # Works with batched multi-channel >>> batch = torch.randn(8, 2, 44100) >>> result = fx(batch) """ def forward(self, x: Tensor) -> Tensor: """Process tensor with automatic dimension detection.""" if x.ndim == 1: # Mono: (time,) return self._process_mono(x) elif x.ndim == 2: # Multi-channel: (channels, time) return self._process_multi_channel(x) elif x.ndim == 3: # Batched: (batch, channels, time) return self._process_batched(x) else: # Higher dimensions: flatten, process, reshape original_shape = x.shape flattened = x.view(-1, x.size(-1)) # Flatten to (N, time) processed = self._process_multi_channel(flattened) return processed.view(original_shape) def _process_mono(self, x: Tensor) -> Tensor: """Process single channel.""" # Example: reduce volume by 50% return x * 0.5 def _process_multi_channel(self, x: Tensor) -> Tensor: """Process each channel independently.""" for i in range(x.shape[0]): x[i] = self._process_mono(x[i]) return x def _process_batched(self, x: Tensor) -> Tensor: """Process batched data.""" for b in range(x.shape[0]): x[b] = self._process_multi_channel(x[b]) return x ``` ```{mermaid} graph TD Forward["forward(waveform)"] CheckDim{"waveform.ndim?"} Mono["ndim == 1
(time,)
Process as mono"] MultiCh["ndim == 2
(channels, time)
Process per channel"] Batched["ndim == 3
(batch, channels, time)
Process with batch dim"] Higher["ndim > 3
(..., time)
Flatten → process → reshape"] Forward --> CheckDim CheckDim -->|"1"| Mono CheckDim -->|"2"| MultiCh CheckDim -->|"3"| Batched CheckDim -->|">3"| Higher style CheckDim fill:#fff5e1 style Mono fill:#e1f5ff style MultiCh fill:#e8f5e1 style Batched fill:#f5e1ff style Higher fill:#ffe1e1 ``` ## Surround Sound Configurations TorchFX supports standard surround sound channel configurations. ### 5.1 Surround Sound 5.1 surround has 6 channels with standard ordering: ```python import torch from torchfx import Wave # 5.1 channel order: FL, FR, FC, LFE, BL, BR # FL = Front Left, FR = Front Right, FC = Front Center # LFE = Low Frequency Effects (subwoofer) # BL = Back Left, BR = Back Right surround_51 = torch.randn(6, 44100) # 6 channels, 1 second wave_51 = Wave(surround_51, fs=44100) # Access individual channels front_left = surround_51[0] front_right = surround_51[1] center = surround_51[2] lfe = surround_51[3] back_left = surround_51[4] back_right = surround_51[5] ``` ### 7.1 Surround Sound 7.1 surround has 8 channels: ```python # 7.1 channel order: FL, FR, FC, LFE, BL, BR, SL, SR # SL = Side Left, SR = Side Right surround_71 = torch.randn(8, 44100) # 8 channels wave_71 = Wave(surround_71, fs=44100) ``` ### Custom Surround Effect Process surround sound with channel-specific effects: ```python from torch import Tensor, nn from torchfx import FX, Wave from torchfx.filter import iir import torchaudio.transforms as T class SurroundProcessor(FX): """Process 5.1 surround sound with channel-specific effects. - Front channels: Full range processing - Center: Voice-optimized (bandpass) - LFE: Low-pass only (subwoofer) - Rear channels: Ambient processing Parameters ---------- fs : int Sample rate in Hz """ def __init__(self, fs: int): super().__init__() self.fs = fs # Channel order: FL, FR, FC, LFE, BL, BR self.channels = nn.ModuleList([ self.front_lr(), # 0: Front Left self.front_lr(), # 1: Front Right self.center(), # 2: Front Center self.lfe(), # 3: LFE (subwoofer) self.rear(), # 4: Back Left self.rear(), # 5: Back Right ]) def front_lr(self) -> nn.Sequential: """Full-range processing for front L/R.""" return nn.Sequential( iir.HiButterworth(cutoff=80, order=2, fs=self.fs), iir.LoButterworth(cutoff=18000, order=4, fs=self.fs), ) def center(self) -> nn.Sequential: """Voice-optimized for center channel.""" return nn.Sequential( iir.HiButterworth(cutoff=200, order=2, fs=self.fs), # Remove rumble iir.LoButterworth(cutoff=8000, order=4, fs=self.fs), # Voice range ) def lfe(self) -> nn.Sequential: """Subwoofer channel (low-pass only).""" return nn.Sequential( iir.LoButterworth(cutoff=120, order=8, fs=self.fs), # Sharp LPF ) def rear(self) -> nn.Sequential: """Ambient processing for rear channels.""" return nn.Sequential( iir.HiButterworth(cutoff=100, order=2, fs=self.fs), iir.LoButterworth(cutoff=12000, order=4, fs=self.fs), T.Vol(0.8), # Slightly quieter for ambience ) def forward(self, x: Tensor) -> Tensor: """Process 5.1 surround audio.""" if x.shape[0] != 6: raise ValueError(f"Expected 6 channels for 5.1, got {x.shape[0]}") for i in range(6): x[i] = self.channels[i](x[i]) return x # Usage surround_wave = Wave.from_file("movie_51.wav") processor = SurroundProcessor(fs=surround_wave.fs) processed = surround_wave | processor processed.save("processed_51.wav") ``` ## Integration with Wave Pipeline The {class}`~torchfx.Wave` class automatically configures effects with the `fs` (sample rate) parameter when using the pipeline operator (`|`). 
### Automatic Sample Rate Configuration ```python # Create effect without fs fx = ComplexEffect(num_channels=2, fs=None) # Wave automatically sets fs via __update_config wave = Wave.from_file("stereo.wav") result = wave | fx # fx.fs is automatically set to wave.fs ``` The {class}`~torchfx.Wave` class calls `__update_config` on effects that have an `fs` attribute, automatically setting the sample rate before processing. ```{seealso} {doc}`../core-concepts/wave` - Wave class architecture and automatic configuration ``` ## Best Practices ### Use ModuleList for Channel Chains ```python # ✅ GOOD: Proper module registration self.channels = nn.ModuleList([ self.create_chain(0), self.create_chain(1), ]) # ❌ BAD: Regular list won't register modules self.channels = [ self.create_chain(0), self.create_chain(1), ] ``` Using {class}`torch.nn.ModuleList` ensures: - Proper parameter registration for GPU transfer - Correct gradient tracking for trainable parameters - Integration with PyTorch's module system ### Validate Sample Rate ```python # ✅ GOOD: Clear error messages def forward(self, x: Tensor) -> Tensor: if self.fs is None: raise ValueError("Sample rate (fs) must be set before processing") # Process audio... # ❌ BAD: No validation, may fail later def forward(self, x: Tensor) -> Tensor: # self.fs might be None! return self.filter(x) ``` ### Handle Variable Channel Counts ```python # ✅ GOOD: Flexible channel handling def forward(self, x: Tensor) -> Tensor: num_channels = x.shape[0] if x.ndim >= 2 else 1 if self.channels is None or len(self.channels) != num_channels: self._create_channels(num_channels) # Process channels... # ❌ BAD: Hardcoded channel count def forward(self, x: Tensor) -> Tensor: x[0] = self.process_left(x[0]) x[1] = self.process_right(x[1]) # Fails for mono or surround ``` ### Preserve Tensor Properties ```python # ✅ GOOD: Preserve device and dtype output = torch.zeros_like(input_tensor) # ❌ BAD: May create tensor on wrong device output = torch.zeros(input_tensor.shape) ``` ### Use keepdim for Broadcasting ```python # ✅ CORRECT: keepdim=True allows broadcasting max_per_channel = torch.max(torch.abs(waveform), dim=1, keepdim=True).values normalized = waveform / max_per_channel * peak # ❌ WRONG: Shape mismatch max_per_channel = torch.max(torch.abs(waveform), dim=1).values normalized = waveform / max_per_channel * peak # Broadcasting error! ``` ## Common Pitfalls ### Pitfall 1: In-Place Modifications ```python # ❌ WRONG: In-place modification can cause issues def forward(self, x: Tensor) -> Tensor: for i in range(x.shape[0]): x[i] = self.process(x[i]) # Modifies input return x # ✅ CORRECT: Create output tensor def forward(self, x: Tensor) -> Tensor: output = torch.zeros_like(x) for i in range(x.shape[0]): output[i] = self.process(x[i]) return output ``` ### Pitfall 2: Incorrect Channel Validation ```python # ❌ WRONG: Doesn't handle batched input if x.shape[0] != 2: raise ValueError("Expected stereo") # ✅ CORRECT: Check channel dimension properly if x.ndim < 2 or x.shape[-2] != 2: raise ValueError(f"Expected stereo, got shape {x.shape}") ``` ### Pitfall 3: Forgetting Output Length ```python # ❌ WRONG: May truncate delay tails def forward(self, x: Tensor) -> Tensor: output = torch.zeros_like(x) # Same length as input # Process with delay... output might be too short! # ✅ CORRECT: Account for extended output def forward(self, x: Tensor) -> Tensor: output_length = x.size(-1) + self.delay_samples output = torch.zeros(x.size(0), output_length, ...) 
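    # The elided arguments ("...") above should at minimum preserve the input's
    # dtype and device (see "Preserve Tensor Properties"), e.g. (illustrative,
    # not part of the original example):
    # output = torch.zeros(x.size(0), output_length, dtype=x.dtype, device=x.device)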
# Process with proper length ``` ## PyTorch Integration Multi-channel effects work seamlessly with PyTorch's data pipeline. ### Using with DataLoader ```python import torch from torch.utils.data import Dataset, DataLoader from torchfx import Wave class AudioDataset(Dataset): """Dataset with multi-channel audio augmentation.""" def __init__(self, file_paths, transform=None): self.file_paths = file_paths self.transform = transform def __len__(self): return len(self.file_paths) def __getitem__(self, idx): # Load audio wave = Wave.from_file(self.file_paths[idx]) # Apply multi-channel transform if self.transform: wave = wave | self.transform return wave.ys, wave.fs # Create dataset with multi-channel effect dataset = AudioDataset( file_paths=["song1.wav", "song2.wav", "song3.wav"], transform=ComplexEffect(num_channels=2, fs=44100) ) # Use with DataLoader dataloader = DataLoader(dataset, batch_size=4, shuffle=True) for batch_audio, batch_fs in dataloader: # batch_audio shape: (batch, channels, time) print(f"Batch shape: {batch_audio.shape}") ``` ### GPU Acceleration Multi-channel effects automatically support GPU: ```python import torch from torchfx import Wave wave = Wave.from_file("stereo.wav") # Move to GPU if torch.cuda.is_available(): wave = wave.to("cuda") # Effect runs on GPU fx = ComplexEffect(num_channels=2, fs=wave.fs).to("cuda") processed = wave | fx # Move back to CPU for saving processed.to("cpu").save("output.wav") ``` ```{seealso} {doc}`gpu-acceleration` - Complete guide to GPU acceleration in TorchFX ``` ## Summary Multi-channel processing in TorchFX follows these key principles: | Principle | Implementation | |-----------|----------------| | **Tensor conventions** | Last dimension is time, earlier dimensions are channels/batches | | **Default behavior** | Broadcast operations unless channel-specific logic is implemented | | **Strategy pattern** | Use strategy classes for pluggable channel behavior | | **ModuleList** | Use {class}`torch.nn.ModuleList` for per-channel processing chains | | **Dimension handling** | Detect and handle 1D, 2D, 3D+ tensors appropriately | | **Cross-channel effects** | Explicitly access channels via indexing for interaction patterns | | **Sample rate** | Store `fs` attribute for automatic configuration via {class}`~torchfx.Wave` pipeline | **Architecture Overview**: ```{mermaid} graph TB subgraph "Multi-Channel Architecture" Input[Input Audio
channels, time] subgraph FX[FX Base Class] Init[__init__
Set fs, num_channels] Chains[nn.ModuleList
Per-channel chains] Forward[forward
Process channels] end subgraph Strategies Broadcast[Broadcast Strategy
Same to all] PerChannel[Per-Channel Strategy
Independent] CrossChannel[Cross-Channel Strategy
Interactive] end Output[Output Audio
channels, time] end Input --> FX Init --> Chains Chains --> Forward Forward --> Strategies Strategies --> Output style Input fill:#e1f5ff style Output fill:#e1f5ff style FX fill:#fff5e1 style Strategies fill:#e8f5e1 ``` ## Related Concepts - {doc}`../core-concepts/wave` - Wave class and tensor handling - {doc}`../core-concepts/fx` - FX base class for effects - {doc}`../tutorials/custom-effects` - Creating custom effects - {doc}`gpu-acceleration` - GPU acceleration for multi-channel audio - {doc}`../tutorials/ml-batch-processing` - Batch processing for ML workflows ## External Resources - [PyTorch Audio Documentation](https://pytorch.org/audio/stable/index.html) - torchaudio tensor conventions - [Surround Sound on Wikipedia](https://en.wikipedia.org/wiki/Surround_sound) - Multi-channel audio formats - [Mid/Side Processing](https://en.wikipedia.org/wiki/Stereophonic_sound#M/S_technique:_mid/side_stereophony) - Stereo imaging technique - [ITU-R BS.775](https://www.itu.int/rec/R-REC-BS.775/) - Multi-channel audio standard ## References ```{bibliography} :filter: docname in docnames :style: alpha ```