# Multi-Channel Audio Processing
Learn how to create audio effects that process multiple channels independently or with channel interaction. This tutorial covers tensor shape conventions, per-channel processing patterns, and stereo-specific effects.
## Overview
Multi-channel audio is ubiquitous in modern audio production, from simple stereo (2 channels) to surround sound (5.1, 7.1) and beyond. TorchFX provides flexible patterns for handling multi-channel audio:
- **Independent processing**: Apply different effects to each channel
- **Broadcast processing**: Apply the same effect to all channels
- **Interactive processing**: Channels affect each other (e.g., ping-pong delay)
- **Channel-aware strategies**: Normalization, delay, and custom algorithms
```{mermaid}
graph TB
    subgraph "Multi-Channel Processing Patterns"
        Input[Multi-Channel<br/>Input]
        subgraph Independent[Independent Processing]
            Ch1[Channel 1:<br/>Filter A]
            Ch2[Channel 2:<br/>Filter B]
            ChN[Channel N:<br/>Filter C]
        end
        subgraph Broadcast[Broadcast Processing]
            Same[Same Filter<br/>All Channels]
        end
        subgraph Interactive[Interactive Processing]
            PingPong[Ping-Pong<br/>L→R, R→L]
        end
        Input --> Independent
        Input --> Broadcast
        Input --> Interactive
    end
    style Input fill:#e1f5ff
    style Independent fill:#fff5e1
    style Broadcast fill:#e8f5e1
    style Interactive fill:#f5e1ff
```
## Tensor Shape Conventions
TorchFX follows PyTorch audio conventions:
| Shape | Description | Use Case |
|-------|-------------|----------|
| `(time,)` | Mono audio | Single microphone recording |
| `(channels, time)` | Multi-channel | Stereo, surround sound |
| `(batch, channels, time)` | Batched multi-channel | ML training batches |
| `(..., time)` | Arbitrary dimensions | General tensor processing |
**Key principle**: The **last dimension is always time**; any earlier dimensions index channels and batches.
```python
import torch
import torchfx as fx
# Mono audio: (time,)
mono = torch.randn(44100) # 1 second at 44.1kHz
wave_mono = fx.Wave(mono.unsqueeze(0), fs=44100) # Add channel dimension → (1, 44100)
# Stereo audio: (channels, time)
stereo = torch.randn(2, 44100) # 2 channels, 1 second
wave_stereo = fx.Wave(stereo, fs=44100)
# 5.1 surround: (channels, time)
surround = torch.randn(6, 44100) # 6 channels
wave_surround = fx.Wave(surround, fs=44100)
# Batched stereo: (batch, channels, time)
batch = torch.randn(8, 2, 44100) # 8 samples of stereo audio
```
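
Because time is always the last axis, reductions written against `dim=-1` work unchanged for every layout in the table above. The following is a minimal sketch in plain PyTorch; `peak_per_channel` is a hypothetical helper name used for illustration, not part of TorchFX.

```python
import torch


def peak_per_channel(x: torch.Tensor) -> torch.Tensor:
    """Absolute peak along the time axis (the last dimension)."""
    # keepdim=True keeps a trailing axis of size 1, so the result
    # broadcasts back against the input
    return x.abs().amax(dim=-1, keepdim=True)


mono = torch.randn(44100)         # (time,)
stereo = torch.randn(2, 44100)    # (channels, time)
batch = torch.randn(8, 2, 44100)  # (batch, channels, time)

print(peak_per_channel(mono).shape)    # torch.Size([1])
print(peak_per_channel(stereo).shape)  # torch.Size([2, 1])
print(peak_per_channel(batch).shape)   # torch.Size([8, 2, 1])
```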
```{seealso}
{doc}`/guides/core-concepts/wave` - Wave class and tensor handling
```
## Built-in Multi-Channel Strategies
### Per-Channel Normalization
The {class}`~torchfx.effect.PerChannelNormalizationStrategy` normalizes each channel independently to its own peak:
```python
import torchfx as fx
from torchfx.effect import Normalize, PerChannelNormalizationStrategy
# Load stereo audio
wave = fx.Wave.from_file("stereo.wav") # (2, time)
# Standard normalization (uses global peak across all channels)
global_norm = wave | Normalize(peak=1.0)
# Per-channel normalization (each channel normalized to its own peak)
strategy = PerChannelNormalizationStrategy()
perchannel_norm = wave | Normalize(peak=1.0, strategy=strategy)
```
**Behavior comparison**:
```python
import torch

# Example with imbalanced channels. For Gaussian noise the scale factor sets
# the standard deviation, so the actual peaks are several times larger.
left_loud = torch.randn(44100) * 0.8    # Std ~0.8
right_quiet = torch.randn(44100) * 0.3  # Std ~0.3
stereo = torch.stack([left_loud, right_quiet])
wave = fx.Wave(stereo, fs=44100)
# Global normalization: both scaled by same factor (based on loudest channel)
global_norm = wave | Normalize(peak=1.0)
# Result: left peak 1.0, right peak roughly 0.375 (the 0.3 / 0.8 level ratio is kept)
# Per-channel normalization: each scaled independently
perchannel_norm = wave | Normalize(peak=1.0, strategy=PerChannelNormalizationStrategy())
# Result: left peak 1.0, right peak 1.0
```
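
The difference between the two modes can also be written out directly in PyTorch. This is only a sketch of the arithmetic described above, not the TorchFX implementation; `normalize_global` and `normalize_per_channel` are hypothetical names, and the input uses hand-picked samples so the peaks are known exactly.

```python
import torch


def normalize_global(x: torch.Tensor, peak: float = 1.0) -> torch.Tensor:
    """Scale every channel by one factor derived from the overall peak."""
    return x / x.abs().max() * peak


def normalize_per_channel(x: torch.Tensor, peak: float = 1.0) -> torch.Tensor:
    """Scale each channel by its own peak along the time axis."""
    return x / x.abs().amax(dim=-1, keepdim=True) * peak


# Two channels with known peaks: 0.8 (left) and 0.3 (right)
stereo = torch.tensor([[0.8, -0.4, 0.2],
                       [0.3, -0.15, 0.075]])

global_peaks = normalize_global(stereo).abs().amax(dim=-1)
channel_peaks = normalize_per_channel(stereo).abs().amax(dim=-1)
print(global_peaks)   # left 1.0, right 0.375
print(channel_peaks)  # left 1.0, right 1.0
```

A silent channel would cause a division by zero in this sketch, so a real implementation would need a guard for that case.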
### Delay Strategies
The {class}`~torchfx.effect.Delay` effect supports two multi-channel strategies:
#### MonoDelayStrategy (Default)
Applies the same delay to all channels independently:
```python
import torchfx as fx
from torchfx.effect import Delay, MonoDelayStrategy
wave = fx.Wave.from_file("stereo.wav")
# Mono strategy: identical delay on both channels
delay = Delay(
    bpm=120,
    delay_time="1/4",
    feedback=0.4,
    mix=0.3,
    strategy=MonoDelayStrategy(),  # Default, can be omitted
)
delayed = wave | delay
```
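
To get a feel for the delay lengths involved, the sketch below converts a tempo and a note value into seconds and samples. It assumes that one beat is a quarter note and that `"1/4"` denotes a quarter note; the document does not spell out how `Delay` interprets `delay_time`, so treat this mapping and the helper names `note_to_seconds` and `note_to_samples` as illustrative assumptions.

```python
from fractions import Fraction


def note_to_seconds(bpm: float, note: str) -> float:
    """Duration of a note value such as "1/4" at the given tempo."""
    # Assumption: one beat is a quarter note, so a whole note lasts 4 beats
    seconds_per_beat = 60.0 / bpm
    return float(Fraction(note) * 4) * seconds_per_beat


def note_to_samples(bpm: float, note: str, fs: int) -> int:
    """Same duration expressed in samples at sample rate fs."""
    return round(note_to_seconds(bpm, note) * fs)


print(note_to_seconds(120, "1/4"))         # 0.5
print(note_to_samples(120, "1/8", 44100))  # 11025
```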
#### PingPongDelayStrategy
Creates alternating delays between left and right stereo channels:
```python
import torchfx as fx
from torchfx.effect import Delay, PingPongDelayStrategy
wave = fx.Wave.from_file("stereo.wav") # Must be stereo (2 channels)
# Ping-pong delay: alternates between L→R and R→L
delay = Delay(
    bpm=120,
    delay_time="1/8",
    feedback=0.5,
    mix=0.4,
    strategy=PingPongDelayStrategy(),
)
delayed = wave | delay
```
**Ping-pong pattern**:
- **Tap 1**: Left channel → delays to → Right channel
- **Tap 2**: Right channel → delays to → Left channel
- **Tap 3**: Left channel → delays to → Right channel
- And so on...
```{mermaid}
sequenceDiagram
participant L as Left Channel
participant R as Right Channel
Note over L,R: Original Signal
L->>L: Original left
R->>R: Original right
Note over L,R: Tap 1 (100% amplitude)
L->>R: Left delays to Right
Note over L,R: Tap 2 (feedback^1)
R->>L: Right delays to Left
Note over L,R: Tap 3 (feedback^2)
L->>R: Left delays to Right
Note over L,R: Result: Ping-pong pattern
```
**Fallback**: If audio is not stereo, {class}`~torchfx.effect.PingPongDelayStrategy` automatically falls back to {class}`~torchfx.effect.MonoDelayStrategy`.
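
The tap pattern listed above can be sketched in plain PyTorch as follows. This is a literal reading of that list (odd taps copy left into right, even taps copy right into left, tap `k` scaled by `feedback ** (k - 1)`), not the library's code; `ping_pong_taps` is a hypothetical name and the wet/dry `mix` parameter is left out for brevity.

```python
import torch


def ping_pong_taps(x: torch.Tensor, delay_samples: int,
                   feedback: float = 0.5, taps: int = 3) -> torch.Tensor:
    """Add alternating cross-channel echoes to a stereo tensor (2, time)."""
    if x.ndim != 2 or x.shape[0] != 2:
        raise ValueError(f"Expected shape (2, time), got {tuple(x.shape)}")
    n = x.shape[-1]
    # Extend the output so the last echo is not cut off
    out = torch.zeros(2, n + taps * delay_samples, dtype=x.dtype, device=x.device)
    out[:, :n] = x  # dry signal
    for k in range(1, taps + 1):
        src = (k - 1) % 2  # odd taps read left (0), even taps read right (1)
        dst = 1 - src      # and write to the opposite channel
        start = k * delay_samples
        out[dst, start:start + n] += feedback ** (k - 1) * x[src]
    return out


x = torch.tensor([[1.0, 0.0, 0.0, 0.0],
                  [0.0, 2.0, 0.0, 0.0]])
y = ping_pong_taps(x, delay_samples=2, feedback=0.5, taps=3)
print(y.shape)  # torch.Size([2, 10])
```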
## Creating Per-Channel Effects
### Pattern 1: Independent Channel Processing
Use {class}`torch.nn.ModuleList` to store per-channel processing chains:
```python
import torch
from torch import Tensor, nn
from torchfx import FX, Wave
from torchfx.filter import iir
import torchaudio.transforms as T


class StereoProcessor(FX):
    """Apply different processing to left and right channels."""

    def __init__(self, fs: int | None = None):
        super().__init__()
        self.fs = fs
        # Store per-channel chains in ModuleList
        self.channels = nn.ModuleList([
            self.left_channel(),
            self.right_channel(),
        ])

    def left_channel(self) -> nn.Sequential:
        """Processing chain for left channel."""
        return nn.Sequential(
            iir.HiButterworth(cutoff=100, order=2, fs=self.fs),   # Remove rumble
            iir.LoButterworth(cutoff=8000, order=4, fs=self.fs),  # Remove high freq
        )

    def right_channel(self) -> nn.Sequential:
        """Processing chain for right channel."""
        return nn.Sequential(
            iir.HiButterworth(cutoff=150, order=2, fs=self.fs),    # Different HPF
            iir.LoButterworth(cutoff=10000, order=4, fs=self.fs),  # Different LPF
            T.Vol(0.9),  # Slight volume reduction
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply per-channel processing."""
        if self.fs is None:
            raise ValueError("Sample rate must be set")
        # Process each channel with its own chain, writing into a copy so the
        # caller's tensor is not modified in place
        out = x.clone()
        for i, chain in enumerate(self.channels):
            out[i] = chain(x[i])
        return out


# Usage
wave = Wave.from_file("stereo.wav")
processor = StereoProcessor(fs=wave.fs)
processed = wave | processor
processed.save("processed_stereo.wav")
```
**Key points**:
- Use {class}`torch.nn.ModuleList` to register submodules
- Pass `fs` to filters that need sample rate
- Process each channel independently in `forward()`
- Each channel can have completely different processing
```{mermaid}
graph TB
    Input[Stereo Input<br/>2, time]
    subgraph Processor[StereoProcessor]
        subgraph Left[Left Channel Chain]
            LH[HiButterworth<br/>100 Hz]
            LL[LoButterworth<br/>8000 Hz]
        end
        subgraph Right[Right Channel Chain]
            RH[HiButterworth<br/>150 Hz]
            RL[LoButterworth<br/>10000 Hz]
            RV[Vol 0.9]
        end
    end
    Output[Stereo Output<br/>2, time]
    Input -->|x[0]| Left
    Input -->|x[1]| Right
    LH --> LL
    RH --> RL --> RV
    Left -->|processed[0]| Output
    Right -->|processed[1]| Output
    style Input fill:#e1f5ff
    style Output fill:#e1f5ff
    style Left fill:#fff5e1
    style Right fill:#ffe1e1
```
### Pattern 2: Dynamic Channel Count
Handle any number of channels dynamically:
```python
from torch import Tensor, nn
from torchfx import FX
from torchfx.filter import iir


class FlexibleMultiChannel(FX):
    """Effect that adapts to any number of channels."""

    def __init__(self, fs: int | None = None):
        super().__init__()
        self.fs = fs
        self.channels = None  # Created dynamically

    def _create_channels(self, num_channels: int):
        """Create processing chains for given number of channels."""
        self.channels = nn.ModuleList([
            nn.Sequential(
                iir.HiButterworth(cutoff=100 * (i + 1), order=2, fs=self.fs),
                iir.LoButterworth(cutoff=1000 * (i + 1), order=2, fs=self.fs),
            )
            for i in range(num_channels)
        ])

    def forward(self, x: Tensor) -> Tensor:
        num_channels = x.shape[0] if x.ndim >= 2 else 1
        # Create channels on first forward pass
        if self.channels is None:
            self._create_channels(num_channels)
        # Mono input: a single chain handles the whole tensor
        if x.ndim < 2:
            return self.channels[0](x)
        # Process each channel into a copy, leaving the input untouched
        out = x.clone()
        for i in range(num_channels):
            out[i] = self.channels[i](x[i])
        return out
```
### Pattern 3: Complete Example - ComplexEffect
Here's a complete example adapted from the TorchFX examples. It defines exactly two per-channel chains, so `num_channels` must not exceed 2:
```python
import torch
from torch import Tensor, nn
import torchaudio.transforms as T
from torchfx import FX, Wave
from torchfx.filter import iir


class ComplexEffect(FX):
    """Multi-channel effect with different processing per channel.

    Channel 1: Bandpass 1000-2000 Hz
    Channel 2: Bandpass 2000-4000 Hz with volume reduction

    Parameters
    ----------
    num_channels : int
        Number of channels to process
    fs : int, optional
        Sample rate in Hz

    Examples
    --------
    >>> wave = Wave.from_file("stereo.wav")
    >>> fx = ComplexEffect(num_channels=2, fs=wave.fs)
    >>> processed = wave | fx
    """

    def __init__(self, num_channels: int, fs: int | None = None):
        super().__init__()
        self.num_channels = num_channels
        self.fs = fs
        # Per-channel processing chains
        self.ch = nn.ModuleList([
            self.channel1(),
            self.channel2(),
        ])

    def channel1(self) -> nn.Sequential:
        """Processing chain for channel 1."""
        return nn.Sequential(
            iir.HiButterworth(1000, fs=self.fs),  # High-pass at 1000 Hz
            iir.LoButterworth(2000, fs=self.fs),  # Low-pass at 2000 Hz
        )

    def channel2(self) -> nn.Sequential:
        """Processing chain for channel 2."""
        return nn.Sequential(
            iir.HiButterworth(2000, fs=self.fs),  # High-pass at 2000 Hz
            iir.LoButterworth(4000, fs=self.fs),  # Low-pass at 4000 Hz
            T.Vol(0.5),  # Reduce volume by 50%
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply per-channel processing."""
        if self.fs is None:
            raise ValueError("Sampling frequency (fs) must be set")
        # Process each channel independently, writing into a copy so the
        # caller's tensor is not modified in place
        out = x.clone()
        for i in range(self.num_channels):
            out[i] = self.ch[i](x[i])
        return out


# Complete usage example
if __name__ == "__main__":
    # Load stereo audio
    wave = Wave.from_file("input.wav")
    # Create and apply effect
    fx = ComplexEffect(num_channels=2, fs=wave.fs)
    result = wave | fx
    # Save result
    result.save("output.wav")
```
## Dimension-Agnostic Processing
For effects that should work with any tensor shape, detect and handle dimensions:
```python
import torch
from torch import Tensor
from torchfx import FX


class DimensionAgnosticEffect(FX):
    """Effect that handles 1D, 2D, and 3D+ tensors."""

    def forward(self, x: Tensor) -> Tensor:
        if x.ndim == 1:
            # Mono: (time,)
            return self._process_mono(x)
        elif x.ndim == 2:
            # Multi-channel: (channels, time)
            return self._process_multi_channel(x)
        elif x.ndim == 3:
            # Batched: (batch, channels, time)
            return self._process_batched(x)
        else:
            # Higher dimensions: flatten, process, reshape
            original_shape = x.shape
            # reshape (unlike view) also accepts non-contiguous input
            flattened = x.reshape(-1, x.size(-1))  # Flatten to (N, time)
            processed = self._process_multi_channel(flattened)
            return processed.reshape(original_shape)

    def _process_mono(self, x: Tensor) -> Tensor:
        # Process single channel
        return x * 0.5  # Example: reduce volume

    def _process_multi_channel(self, x: Tensor) -> Tensor:
        # Process each channel and stack the results (no in-place writes)
        return torch.stack([self._process_mono(ch) for ch in x])

    def _process_batched(self, x: Tensor) -> Tensor:
        # Process each batch item and stack the results
        return torch.stack([self._process_multi_channel(item) for item in x])
```
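
When the per-channel operation already accepts a `(N, time)` tensor, the branching above can collapse into a single flatten-and-restore step. The sketch below shows that idea in plain PyTorch; `apply_over_leading_dims` is a hypothetical helper, and it assumes `fn` keeps the number of rows unchanged.

```python
import torch


def apply_over_leading_dims(x: torch.Tensor, fn) -> torch.Tensor:
    """Apply fn, which expects (N, time), to a tensor of shape (..., time)."""
    lead_shape = x.shape[:-1]
    # Collapse every leading dimension into one; a 1D input becomes (1, time)
    flat = x.reshape(-1, x.shape[-1])
    out = fn(flat)
    # Restore the leading dimensions; the time length may have changed
    return out.reshape(*lead_shape, out.shape[-1])


def halve(t: torch.Tensor) -> torch.Tensor:
    return t * 0.5


for shape in [(5,), (2, 5), (3, 2, 5), (4, 3, 2, 5)]:
    x = torch.ones(shape)
    y = apply_over_leading_dims(x, halve)
    print(tuple(y.shape))  # same shape as the input
```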
## Channel Interaction Patterns
### Cross-Channel Effects
For effects where channels affect each other:
```python
from torch import Tensor
import torch
from torchfx import FX, Wave


class StereoWidener(FX):
    """Widen stereo image using Mid/Side processing."""

    def __init__(self, width: float = 1.5):
        """
        Parameters
        ----------
        width : float
            Stereo width multiplier (1.0 = no change, >1.0 = wider, <1.0 = narrower)
        """
        super().__init__()
        if width < 0:
            raise ValueError("Width must be non-negative")
        self.width = width

    def forward(self, x: Tensor) -> Tensor:
        """Apply stereo widening."""
        # Only works on stereo (2-channel) audio
        if x.ndim < 2 or x.shape[0] != 2:
            return x  # Return unchanged for non-stereo
        left = x[0]
        right = x[1]
        # Convert to Mid/Side
        mid = (left + right) / 2
        side = (left - right) / 2
        # Widen by scaling Side component
        side_widened = side * self.width
        # Convert back to L/R
        new_left = mid + side_widened
        new_right = mid - side_widened
        return torch.stack([new_left, new_right])


# Usage
wave = Wave.from_file("stereo.wav")
widener = StereoWidener(width=1.5)
wider = wave | widener
```
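
The Mid/Side arithmetic is easy to check on a small tensor. The sketch below repeats the same formulas as a standalone function (`widen` is a hypothetical name, independent of TorchFX) and confirms the two boundary cases: a width of 1.0 reproduces the input, and a width of 0.0 collapses both channels to the mid signal.

```python
import torch


def widen(x: torch.Tensor, width: float) -> torch.Tensor:
    """Mid/Side widening for a stereo tensor of shape (2, time)."""
    left, right = x[0], x[1]
    mid = (left + right) / 2
    side = (left - right) / 2
    return torch.stack([mid + width * side, mid - width * side])


x = torch.tensor([[1.0, 0.5, -0.25],
                  [0.0, 0.5, 0.75]])

unchanged = widen(x, 1.0)  # mid + side = left, mid - side = right
collapsed = widen(x, 0.0)  # both channels equal (left + right) / 2
doubled = widen(x, 2.0)    # left/right difference is twice the original

print(torch.allclose(unchanged, x))                # True
print(torch.equal(collapsed[0], collapsed[1]))     # True
```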
### Channel-Aware Validation
Validate expected channel configuration:
```python
def forward(self, x: Tensor) -> Tensor:
    # Require at least 2D tensor
    if x.ndim < 2:
        raise ValueError("Input must be at least 2D (channels, time)")
    # Require stereo
    if x.shape[-2] != 2:
        raise ValueError(f"Expected stereo (2 channels), got {x.shape[-2]}")
    # Process stereo audio
    # ...
```
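
Indexing the channel axis as `shape[-2]` rather than `shape[0]` lets the same check accept both `(channels, time)` and `(batch, channels, time)` input. A small standalone sketch, with `require_stereo` as a hypothetical helper name:

```python
import torch


def require_stereo(x: torch.Tensor) -> None:
    """Raise ValueError unless x has shape (..., 2, time)."""
    if x.ndim < 2:
        raise ValueError(
            f"Expected at least 2D (channels, time), got shape {tuple(x.shape)}"
        )
    if x.shape[-2] != 2:
        raise ValueError(f"Expected stereo (2 channels), got {x.shape[-2]}")


require_stereo(torch.zeros(2, 100))     # passes: (channels, time)
require_stereo(torch.zeros(8, 2, 100))  # passes: (batch, channels, time)

try:
    require_stereo(torch.zeros(6, 100))  # 5.1 surround is rejected
    message = ""
except ValueError as err:
    message = str(err)
print(message)  # Expected stereo (2 channels), got 6
```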
## Integration with PyTorch
### Using with DataLoader
Multi-channel effects work seamlessly in PyTorch data pipelines:
```python
import torch
from torch.utils.data import Dataset, DataLoader
from torchfx import Wave


class AudioDataset(Dataset):
    """Dataset with multi-channel audio augmentation."""

    def __init__(self, file_paths, transform=None):
        self.file_paths = file_paths
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        # Load audio
        wave = Wave.from_file(self.file_paths[idx])
        # Apply multi-channel transform
        if self.transform:
            wave = wave | self.transform
        return wave.ys, wave.fs


# Create dataset with multi-channel effect
# (ComplexEffect is the class defined in Pattern 3 above)
dataset = AudioDataset(
    file_paths=["audio1.wav", "audio2.wav", "audio3.wav"],
    transform=ComplexEffect(num_channels=2, fs=44100),
)

# Use with DataLoader. The default collate function stacks tensors, so all
# clips must share the same channel count and length.
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
for batch_audio, batch_fs in dataloader:
    # batch_audio shape: (batch, channels, time)
    print(f"Batch shape: {batch_audio.shape}")
```
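
If the clips differ in length, the default collation cannot stack them. One common workaround, sketched here under the assumption that zero-padding at the end is acceptable for the task, is a custom `collate_fn`. `pad_collate` is a hypothetical name, and the list `items` stands in for a dataset that returns `(audio, fs)` pairs.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader


def pad_collate(items):
    """Collate (audio, fs) pairs whose audio has shape (channels, time)."""
    audio, rates = zip(*items)
    longest = max(a.shape[-1] for a in audio)
    # Right-pad every clip with zeros along the time axis
    padded = [F.pad(a, (0, longest - a.shape[-1])) for a in audio]
    return torch.stack(padded), torch.tensor(rates)


# Stand-in dataset: two stereo clips of different length
items = [(torch.ones(2, 3), 44100), (torch.ones(2, 5), 44100)]
loader = DataLoader(items, batch_size=2, collate_fn=pad_collate)

batch_audio, batch_fs = next(iter(loader))
print(batch_audio.shape)  # torch.Size([2, 2, 5])
```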
### GPU Acceleration
Multi-channel effects automatically support GPU:
```python
import torch
from torchfx import Wave

wave = Wave.from_file("stereo.wav")
# Move to GPU
if torch.cuda.is_available():
    wave = wave.to("cuda")
# Effect runs on GPU (ComplexEffect is the class defined in Pattern 3 above)
fx = ComplexEffect(num_channels=2, fs=wave.fs)
processed = wave | fx
# Move back to CPU for saving
processed.to("cpu").save("output.wav")
```
```{seealso}
{doc}`/guides/advanced/gpu-acceleration` - GPU acceleration guide
```
## Best Practices
### Use ModuleList for Channel Chains
```python
# ✅ GOOD: Proper module registration
self.channels = nn.ModuleList([
    self.create_chain(0),
    self.create_chain(1),
])
# ❌ BAD: Regular list won't register modules
self.channels = [
    self.create_chain(0),
    self.create_chain(1),
]
```
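
The effect of registration is visible in what `parameters()` (and therefore `.to(device)` and `state_dict()`) can see. In this sketch `nn.Linear` layers stand in for per-channel filter chains, and the class names are made up for the comparison.

```python
from torch import nn


class Registered(nn.Module):
    def __init__(self):
        super().__init__()
        # ModuleList registers each entry as a submodule
        self.channels = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])


class Unregistered(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain Python list is invisible to nn.Module bookkeeping
        self.channels = [nn.Linear(4, 4), nn.Linear(4, 4)]


registered_count = len(list(Registered().parameters()))
unregistered_count = len(list(Unregistered().parameters()))
print(registered_count)    # 4 (weight and bias for each layer)
print(unregistered_count)  # 0
```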
### Handle Variable Channel Counts
```python
# ✅ GOOD: Flexible channel handling
def forward(self, x: Tensor) -> Tensor:
    num_channels = x.shape[0] if x.ndim >= 2 else 1
    if self.channels is None or len(self.channels) != num_channels:
        self._create_channels(num_channels)
    # Process channels
    # ...

# ❌ BAD: Hardcoded channel count
def forward(self, x: Tensor) -> Tensor:
    x[0] = self.process_left(x[0])
    x[1] = self.process_right(x[1])
    # Fails for mono or surround
```
### Preserve Tensor Properties
```python
# ✅ GOOD: Preserve device and dtype
output = torch.zeros_like(input_tensor)
# ❌ BAD: May create tensor on wrong device
output = torch.zeros(input_tensor.shape)
```
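
A quick check makes the difference concrete. The sketch uses a `float64` input on the CPU so that it runs anywhere; the same reasoning applies to the device when the input lives on a GPU.

```python
import torch

x = torch.randn(2, 8, dtype=torch.float64)

good = torch.zeros_like(x)   # inherits dtype and device from x
bad = torch.zeros(x.shape)   # falls back to the default dtype and the CPU
# Equivalent to zeros_like when a different shape is needed
explicit = torch.zeros(x.shape, dtype=x.dtype, device=x.device)

print(good.dtype)      # torch.float64
print(bad.dtype)       # the default dtype, torch.float32 unless changed
print(explicit.dtype)  # torch.float64
```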
### Validate Input Shape
```python
# ✅ GOOD: Clear error messages
if x.ndim < 2:
    raise ValueError(
        f"Expected at least 2D tensor (channels, time), got shape {x.shape}"
    )
if x.shape[0] != self.expected_channels:
    raise ValueError(
        f"Expected {self.expected_channels} channels, got {x.shape[0]}"
    )
```
## Common Pitfalls
### In-Place Modifications
```python
# ❌ WRONG: Modifying input in-place can cause issues
def forward(self, x: Tensor) -> Tensor:
    for i in range(x.shape[0]):
        x[i] = self.process(x[i])  # In-place modification
    return x

# ✅ CORRECT: Create output tensor
def forward(self, x: Tensor) -> Tensor:
    output = torch.zeros_like(x)
    for i in range(x.shape[0]):
        output[i] = self.process(x[i])
    return output
```
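
The practical consequence is that the caller's tensor changes behind their back. A minimal demonstration in plain PyTorch, with two hypothetical functions that both halve the signal:

```python
import torch


def halve_in_place(x: torch.Tensor) -> torch.Tensor:
    for i in range(x.shape[0]):
        x[i] = x[i] * 0.5  # writes into the caller's tensor
    return x


def halve_copy(x: torch.Tensor) -> torch.Tensor:
    # Builds a new tensor and leaves the input untouched
    return torch.stack([ch * 0.5 for ch in x])


a = torch.ones(2, 4)
out_copy = halve_copy(a)
print(a[0, 0].item())  # 1.0, the input is unchanged

b = torch.ones(2, 4)
out_in_place = halve_in_place(b)
print(b[0, 0].item())  # 0.5, the input was overwritten
```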
### Broadcasting Errors
```python
# ❌ WRONG: Shape mismatch
max_val = torch.max(torch.abs(x), dim=1).values # Shape: (channels,)
normalized = x / max_val * peak # Error: can't broadcast (channels,) to (channels, time)
# ✅ CORRECT: Use keepdim=True
max_val = torch.max(torch.abs(x), dim=1, keepdim=True).values # Shape: (channels, 1)
normalized = x / max_val * peak # Works: broadcasts correctly
```
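
The shapes can be checked on a tiny tensor. Broadcasting aligns dimensions from the right, so a `(channels,)` tensor is matched against the time axis; the sketch below shows the failure and the `keepdim=True` fix.

```python
import torch

x = torch.tensor([[0.5, -2.0, 1.0],
                  [0.1, 0.2, -0.4]])  # (channels=2, time=3)

peaks_flat = x.abs().amax(dim=1)  # shape (2,)
try:
    x / peaks_flat  # aligns 2 against the time axis of size 3
    failed = False
except RuntimeError:
    failed = True
print(failed)  # True

peaks = x.abs().amax(dim=1, keepdim=True)  # shape (2, 1)
normalized = x / peaks                     # broadcasts along time
print(normalized.abs().amax(dim=1))        # each channel now peaks at 1.0
```

If the number of samples happened to equal the number of channels, the first form would not raise at all and would silently divide along the wrong axis, which is another reason to prefer `keepdim=True`.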
### Forgetting Sample Rate
```python
# ❌ WRONG: No fs validation
def forward(self, x: Tensor) -> Tensor:
    # self.fs might be None!
    return self.filter(x)

# ✅ CORRECT: Validate fs
def forward(self, x: Tensor) -> Tensor:
    if self.fs is None:
        raise ValueError("Sample rate must be set before processing")
    return self.filter(x)
```
## Related Concepts
- {doc}`/guides/core-concepts/wave` - Wave class and tensor handling
- {doc}`custom-effects` - Creating custom effects
- {doc}`/guides/advanced/pytorch-integration` - PyTorch integration patterns
- {doc}`/guides/advanced/gpu-acceleration` - GPU acceleration
## External Resources
- [PyTorch Audio Documentation](https://pytorch.org/audio/stable/index.html) - torchaudio tensor conventions
- [Multi-Channel Audio on Wikipedia](https://en.wikipedia.org/wiki/Surround_sound) - Multi-channel audio formats
- [Mid/Side Processing](https://en.wikipedia.org/wiki/Stereophonic_sound#M/S_technique:_mid/side_stereophony) - Stereo imaging technique
## References
```{bibliography}
:filter: docname in docnames
:style: alpha
```