Wave - Digital Audio Representation#
The Wave class is the foundation of TorchFX’s audio processing system. It wraps PyTorch tensors with audio-specific metadata and methods, making it easy to work with digital audio signals.
What is a Wave?#
In digital audio processing, a signal is represented as a discrete sequence of samples. The Wave class encapsulates:
- Audio Data (tensor): A 2D PyTorch tensor with shape (channels, samples)
- Sample Rate (sampling frequency or fs): The number of samples per second (e.g., 44100 Hz)
- Metadata: Optional information about the audio (encoding, bit depth, etc.)
- Device: Where the audio data lives (CPU or CUDA GPU)
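To make these concrete, here is a minimal sketch that builds a Wave by hand and inspects each piece; it assumes the attribute names (ys, fs, device) shown in the class diagram below:

import torch
import torchfx as fx

# Wrap one second of stereo silence at 44.1 kHz.
wave = fx.Wave(torch.zeros(2, 44100), fs=44100)
print(wave.fs)        # 44100
print(wave.ys.shape)  # torch.Size([2, 44100])
print(wave.device)    # cpu (the default)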
classDiagram
    class Wave {
        +Tensor ys
        +int fs
        +dict metadata
        -Device __device
        +from_file(path) Wave$
        +save(path) None
        +to(device) Wave
        +transform(func) Wave
        +__or__(fx) Wave
        +channels() int
        +duration(unit) float
        +get_channel(index) Wave
        +merge(waves) Wave$
    }
    class Tensor {
        <<PyTorch>>
        Shape: (channels, samples)
    }
    Wave --> Tensor : wraps
    note for Wave "Container for digital audio\nwith sample rate and metadata"
    note for Tensor "2D tensor:\n- Dim 0: channels\n- Dim 1: time (samples)"
Creating Waves#
From Audio Files#
The most common way to create a Wave is to load it from an audio file:
import torchfx as fx
# Load from file (sample rate inferred from file)
wave = fx.Wave.from_file("audio.wav")
print(f"Sample rate: {wave.fs} Hz")
print(f"Channels: {wave.channels()}")
print(f"Duration: {wave.duration('sec')} seconds")
print(f"Samples: {len(wave)}")
print(f"Metadata: {wave.metadata}")
Supported formats: WAV, FLAC, MP3, OGG (depends on torchaudio backend)
From NumPy/PyTorch Arrays#
You can create a Wave from existing array data:
import torch
import numpy as np
import torchfx as fx
# From PyTorch tensor
stereo_data = torch.randn(2, 44100) # 2 channels, 1 second at 44.1kHz
wave = fx.Wave(stereo_data, fs=44100)
# From NumPy array
mono_data = np.random.randn(1, 22050) # 1 channel, 0.5 seconds at 44.1kHz
wave = fx.Wave(mono_data, fs=44100)
Shape requirement: Audio data must be 2D with shape (channels, samples). For mono audio, use shape (1, samples).
Synthetic Signals#
Generate test signals for development and testing:
import torch
import torchfx as fx
# Generate a sine wave
fs = 44100
duration = 1.0 # seconds
frequency = 440 # Hz (A4 note)
t = torch.arange(int(fs * duration)) / fs  # sample times in seconds, spaced exactly 1/fs apart
sine = torch.sin(2 * torch.pi * frequency * t).unsqueeze(0)  # add channel dimension
wave = fx.Wave(sine, fs=fs)
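Continuing the snippet above, the same pattern extends to multi-channel test signals; a small sketch (torch.stack and torch.randn_like are standard PyTorch) that puts the sine on the left channel and white noise on the right:

# Stereo test signal: 440 Hz sine left, low-level white noise right.
noise = 0.1 * torch.randn_like(t)
stereo = torch.stack([torch.sin(2 * torch.pi * frequency * t), noise])
stereo_wave = fx.Wave(stereo, fs=fs)  # shape (2, samples)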
Working with Waves#
Device Management#
Wave objects can be moved between CPU and GPU for accelerated processing:
import torchfx as fx
wave = fx.Wave.from_file("audio.wav")
# Move to GPU
wave.to("cuda")
# Or use the device property
wave.device = "cuda"
# Check current device
print(wave.device) # "cuda"
# Move back to CPU
wave.to("cpu")
This is useful for batch processing large datasets or real-time effects that benefit from GPU acceleration.
See also
PyTorch CUDA Semantics - Understanding device management in PyTorch
Channel Operations#
Work with individual channels or multi-channel audio:
import torchfx as fx
wave = fx.Wave.from_file("stereo.wav")
# Get number of channels
num_channels = wave.channels() # 2
# Extract a specific channel (returns new Wave)
left = wave.get_channel(0)
right = wave.get_channel(1)
# Process individual channels
processed_left = left | SomeEffect()
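A common split-process-recombine pattern, sketched with the placeholder SomeEffect above and the merge() behavior described in the Merging Waves section below:

# Process each channel independently, then stack them back into stereo.
processed_left = wave.get_channel(0) | SomeEffect()
processed_right = wave.get_channel(1) | SomeEffect()
stereo = fx.Wave.merge([processed_left, processed_right], split_channels=True)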
Duration and Length#
Get audio duration in different units:
import torchfx as fx
wave = fx.Wave.from_file("audio.wav")
# Duration in seconds
duration_sec = wave.duration("sec") # 3.5
# Duration in milliseconds
duration_ms = wave.duration("ms") # 3500.0
# Length in samples
num_samples = len(wave) # 154350 (at 44.1kHz)
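Duration and length are tied together by the sample rate; a quick sanity check, assuming duration("sec") is computed as len(wave) / wave.fs:

# Duration in seconds is the sample count divided by the sample rate.
assert abs(wave.duration("sec") - len(wave) / wave.fs) < 1e-9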
Merging Waves#
Combine multiple waves into one:
import torchfx as fx
wave1 = fx.Wave.from_file("audio1.wav")
wave2 = fx.Wave.from_file("audio2.wav")
# Mix waves (sum channels)
mixed = fx.Wave.merge([wave1, wave2], split_channels=False)
# Keep channels separate (stack as new channels)
stacked = fx.Wave.merge([wave1, wave2], split_channels=True)
Note: All waves must have the same sample rate. Shorter waves are zero-padded when mixing.
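The padding behavior can be verified with two synthetic waves of different lengths; a minimal sketch, assuming merge() zero-pads as described in the note above:

import torch
import torchfx as fx

a = fx.Wave(torch.ones(1, 100), fs=8000)
b = fx.Wave(torch.ones(1, 60), fs=8000)
mixed = fx.Wave.merge([a, b], split_channels=False)
# The overlapping samples sum to 2.0; the zero-padded tail stays at 1.0.
print(mixed.ys[0, 59])  # 2.0
print(mixed.ys[0, 60])  # 1.0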
Transformations#
Functional Transformations#
Apply any function to the underlying tensor:
import torch
import torchfx as fx
wave = fx.Wave.from_file("audio.wav")
# Apply FFT
spectrum = wave.transform(torch.fft.fft)
# Apply custom function
def amplify(x):
    return x * 2.0
louder = wave.transform(amplify)
# Chain transformations
processed = wave.transform(torch.fft.fft).transform(torch.abs)
The transform() method returns a new Wave with the same sample rate and metadata.
Pipeline Operator#
The recommended way to apply effects is using the pipeline operator (|):
import torchfx as fx
from torchfx.filter import iir
wave = fx.Wave.from_file("audio.wav")
# Apply single effect
filtered = wave | iir.LoButterworth(cutoff=1000)
# Chain multiple effects
processed = wave | iir.HiButterworth(100) | fx.effect.Reverb() | fx.effect.Normalize()
The pipeline operator:
- Automatically configures effects with the wave's sample rate
- Returns a new Wave (operations are immutable)
- Supports any FX subclass or torch.nn.Module
See also
Pipeline Operator - Functional Composition - Deep dive into the pipeline operator
Saving Waves#
Basic Saving#
Save processed audio back to disk:
import torchfx as fx
wave = fx.Wave.from_file("input.wav")
processed = wave | SomeEffect()
# Save with default settings (format inferred from extension)
processed.save("output.wav")
Advanced Saving Options#
Control encoding and bit depth:
import torchfx as fx
wave = fx.Wave.from_file("audio.wav")
# Save as 24-bit FLAC
wave.save("output.flac", encoding="PCM_S", bits_per_sample=24)
# Save as 32-bit float WAV
wave.save("output.wav", encoding="PCM_F", bits_per_sample=32)
# Force specific format (overrides extension)
wave.save("output.ogg", format="wav")
Encoding options:
- PCM_S: Signed integer PCM (most common)
- PCM_U: Unsigned integer PCM
- PCM_F: Floating-point PCM (32-bit or 64-bit)
- ULAW, ALAW: Compressed formats
Bit depth options: 8, 16, 24, 32, 64 (depending on encoding)
Note
The save() method automatically creates parent directories and moves data to CPU before saving.
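For example, saving into a directory that does not exist yet should work directly, relying on the directory-creation behavior described in the note:

# Parent directories such as processed/ are created automatically.
wave.save("processed/output.flac", encoding="PCM_S", bits_per_sample=24)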
Implementation Details#
Internal Representation#
sequenceDiagram
    participant User
    participant Wave
    participant Tensor
    participant Device
    User->>Wave: from_file("audio.wav")
    Wave->>Tensor: Load audio data
    Tensor-->>Wave: shape (channels, samples)
    Wave->>Device: Set device="cpu"
    Wave-->>User: Return Wave object
    User->>Wave: wave.to("cuda")
    Wave->>Tensor: tensor.to("cuda")
    Wave->>Device: Update __device="cuda"
    Wave-->>User: Return self
    User->>Wave: wave | Effect()
    Wave->>Wave: __or__(effect)
    Wave->>Wave: __update_config(effect)
    Wave->>Wave: transform(effect.forward)
    Wave-->>User: Return new Wave
Sample Rate Configuration#
When using the pipeline operator, Wave automatically configures effects:
1. Checks if the effect is an FX instance
2. Sets the effect's fs attribute if it is None
3. For filters, calls compute_coefficients() if not yet computed
4. Applies the effect's forward() method to the audio tensor
This ensures effects always have the correct sample rate without manual configuration.
import torchfx as fx
from torchfx.filter import iir
wave = fx.Wave.from_file("audio.wav") # fs = 44100
# Filter automatically configured with fs=44100
filtered = wave | iir.LoButterworth(cutoff=1000)
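For intuition, here is a hypothetical, stripped-down re-implementation of that configuration logic; SimpleWave is illustrative only and not TorchFX's actual source:

import torch
from torch import nn

class SimpleWave:
    """Illustrative stand-in for fx.Wave, showing only the __or__ logic."""

    def __init__(self, ys: torch.Tensor, fs: int):
        self.ys, self.fs = ys, fs

    def __or__(self, effect: nn.Module) -> "SimpleWave":
        # Step 2 above: inject the sample rate if the effect exposes an unset fs.
        if getattr(effect, "fs", "absent") is None:
            effect.fs = self.fs
        # Step 4 above: apply forward() and wrap the result in a new wave.
        return SimpleWave(effect(self.ys), self.fs)

louder = SimpleWave(torch.randn(1, 8), fs=44100) | nn.Identity()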
Mathematical Representation#
A discrete-time signal in TorchFX is represented as:

\[\mathbf{x}[n] = \begin{bmatrix} x_1[n] \\ x_2[n] \\ \vdots \\ x_C[n] \end{bmatrix}\]

where:
- \(\mathbf{x}[n]\) is the multi-channel signal at sample index \(n\)
- \(x_c[n]\) is the signal for channel \(c\) at sample \(n\)
- \(C\) is the number of channels
- \(n \in [0, N-1]\), where \(N\) is the total number of samples
The sample rate \(f_s\) determines the relationship between sample index and time:

\[t = \frac{n}{f_s}\]

For example, sample index \(n = 22050\) at \(f_s = 44100\) Hz corresponds to \(t = 0.5\) seconds.
Best Practices#
Memory Management#
import torchfx as fx
# ✅ GOOD: Chain effects in one pipeline expression; intermediates are freed immediately
wave = fx.Wave.from_file("audio.wav")
processed = wave | Effect1() | Effect2() | Effect3()
# ❌ BAD: Named intermediates keep every stage's tensor alive
wave = fx.Wave.from_file("audio.wav")
wave1 = wave | Effect1()
wave2 = wave1 | Effect2()
wave3 = wave2 | Effect3()
Batch Processing#
import torchfx as fx
from pathlib import Path
# Process multiple files
input_dir = Path("audio_files")
output_dir = Path("processed")
effect_chain = Effect1() | Effect2()
for audio_file in input_dir.glob("*.wav"):
    wave = fx.Wave.from_file(audio_file)
    processed = wave | effect_chain
    processed.save(output_dir / audio_file.name)
GPU Acceleration#
import torchfx as fx
import torch
# Check GPU availability
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
wave = fx.Wave.from_file("audio.wav").to(device)
processed = wave | HeavyEffect()
processed.to("cpu").save("output.wav") # Move back to CPU for saving
Common Pitfalls#
Incorrect Tensor Shape#
import torch
import torchfx as fx
# ❌ WRONG: 1D tensor
mono = torch.randn(44100)
wave = fx.Wave(mono, fs=44100) # Error!
# ✅ CORRECT: 2D tensor with channel dimension
mono = torch.randn(1, 44100)
wave = fx.Wave(mono, fs=44100)
Sample Rate Mismatch#
import torchfx as fx
wave1 = fx.Wave.from_file("audio_44k.wav") # fs = 44100
wave2 = fx.Wave.from_file("audio_48k.wav") # fs = 48000
# ❌ WRONG: Cannot merge waves with different sample rates
mixed = fx.Wave.merge([wave1, wave2]) # ValueError!
# ✅ CORRECT: Resample first (using external library)
import torchaudio.transforms as T
resampler = T.Resample(orig_freq=48000, new_freq=44100)
wave2_resampled = fx.Wave(resampler(wave2.ys), fs=44100)
mixed = fx.Wave.merge([wave1, wave2_resampled])
External Resources#
Digital Signal on Wikipedia - Understanding discrete-time signals
Sample Rate on Wikipedia - How sampling works
PyTorch Tensor Documentation - Working with tensors
torchaudio I/O Documentation - Loading and saving audio