Core Classes#
The core module provides the fundamental classes for working with audio in TorchFX.
Wave#
The Wave class handles audio data representation and I/O operations.
- class torchfx.Wave(ys, fs, device='cpu', metadata={})[source]#
Bases: object
A discrete-time waveform representing multi-channel audio signals.
The Wave class is the fundamental data structure in torchfx, wrapping a PyTorch tensor containing audio samples along with its sampling frequency. It provides device management (CPU/CUDA), file I/O integration with torchaudio, functional transformations, and a pipeline operator for chaining effects and filters.
- Parameters:
- ys#
Audio signal tensor with shape (channels, samples). Each row represents a channel, and each column represents a time sample.
- Type:
Tensor
- fs#
Sampling frequency in Hz (samples per second). Used by filters and effects to compute time-domain parameters.
- Type:
int
- metadata#
Optional metadata dictionary containing audio file information such as num_frames, num_channels, bits_per_sample, and encoding.
- Type:
dict
- device#
Current device location (“cpu” or “cuda”) where the tensor resides. Reading the property returns the current device; assign to it or call the to() method to move the Wave between devices.
- Type:
Device
Examples
Create a Wave from an array:
>>> import torch
>>> from torchfx import Wave
>>> samples = torch.randn(1, 44100)  # 1 second of mono audio at 44.1kHz
>>> wave = Wave(samples, fs=44100)
Load from an audio file:
>>> wave = Wave.from_file("audio.wav") >>> print(f"Channels: {wave.channels()}, Duration: {wave.duration('sec')}s")
Process on GPU with pipeline operator:
>>> from torchfx.filter import iir
>>> wave = Wave.from_file("input.wav").to("cuda")
>>> result = wave | iir.LoButterworth(1000, order=4)
Multi-channel processing:
>>> stereo = Wave.from_file("stereo.wav")
>>> left = stereo.get_channel(0)
>>> right = stereo.get_channel(1)
>>> # Process channels independently
>>> processed_left = left | some_effect
>>> processed_right = right | other_effect
>>> result = Wave.merge([processed_left, processed_right], split_channels=True)
Notes
The Wave class follows an immutability pattern for most operations. Methods like transform(), __or__(), and get_channel() return new Wave objects rather than modifying in place. Only device management methods (to(), device setter) modify the Wave in place while returning self for method chaining.
The pipe operator (|) automatically configures FX modules by setting the sampling frequency if not already set and computing filter coefficients before first use. This eliminates boilerplate and prevents common configuration errors.
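For example (a minimal sketch of this behavior, reusing the LoButterworth filter from the earlier example):
>>> from torchfx.filter import iir
>>> wave = Wave.from_file("audio.wav")          # fs is read from the file
>>> lowpass = iir.LoButterworth(1000, order=4)  # no sampling frequency given here
>>> filtered = wave | lowpass                   # the pipe sets fs and computes the coefficients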
Methods
- from_file(path, *args, **kwargs): Load a Wave object from an audio file.
- save(path, format, encoding, bits_per_sample, **kwargs): Save the wave to an audio file.
- to(device): Move the Wave object to a specific device (CPU or CUDA).
- transform(func, *args, **kwargs): Apply a functional transformation to the audio signal tensor.
- get_channel(index): Extract a specific channel as a new Wave object.
- merge(waves, split_channels): Combine multiple Wave objects into a single Wave.
- channels()[source]#
Return the number of audio channels in the wave.
- Returns:
Number of channels. Returns 1 for mono, 2 for stereo, or higher values for multi-channel audio.
- Return type:
int
Examples
Check channel count:
>>> wave = Wave.from_file("audio.wav")
>>> print(f"This file has {wave.channels()} channel(s)")
Process based on channel count:
>>> wave = Wave.from_file("audio.wav")
>>> if wave.channels() == 1:
...     print("Mono audio")
... elif wave.channels() == 2:
...     print("Stereo audio")
... else:
...     print(f"Multi-channel audio with {wave.channels()} channels")
Extract all channels:
>>> wave = Wave.from_file("multichannel.wav") >>> channel_list = [wave.get_channel(i) for i in range(wave.channels())]
Notes
This method returns the first dimension of the ys tensor shape. The tensor follows the convention (channels, samples), so channels() returns ys.shape[0].
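For instance (a minimal check of the shape convention described above):
>>> import torch
>>> wave = Wave(torch.randn(2, 44100), fs=44100)  # ys.shape == (2, 44100)
>>> wave.channels()
2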
See also
get_channel: Extract a specific channel as a new Wave object
merge: Combine multiple Wave objects with different merge strategies
- duration(unit)[source]#
Calculate the duration of the audio signal.
Computes the time length of the audio based on the number of samples and the sampling frequency. The duration can be returned in either seconds or milliseconds.
- Parameters:
unit ({"sec", "ms"}) – Unit for the returned duration. Use “sec” for seconds or “ms” for milliseconds.
- Returns:
Duration in the specified time unit. The return type is annotated as Second (float) when unit=”sec” or Millisecond (float) when unit=”ms”.
- Return type:
float
Examples
Get duration in seconds:
>>> wave = Wave.from_file("audio.wav")
>>> duration_sec = wave.duration("sec")
>>> print(f"Duration: {duration_sec:.2f} seconds")
Get duration in milliseconds:
>>> wave = Wave.from_file("audio.wav")
>>> duration_ms = wave.duration("ms")
>>> print(f"Duration: {duration_ms:.0f} ms")
Compare durations:
>>> wave1 = Wave.from_file("short.wav")
>>> wave2 = Wave.from_file("long.wav")
>>> if wave1.duration("sec") < wave2.duration("sec"):
...     print("wave1 is shorter")
Calculate processing time estimate:
>>> wave = Wave.from_file("audio.wav")
>>> duration = wave.duration("sec")
>>> # Estimate processing time (example: 10x realtime)
>>> estimated_time = duration * 10
>>> print(f"Estimated processing: {estimated_time:.2f} seconds")
Use in validation:
>>> wave = Wave.from_file("audio.wav")
>>> max_duration_sec = 60.0
>>> if wave.duration("sec") > max_duration_sec:
...     print("Audio file is too long")
Notes
The duration is calculated using the formula:
duration = (number_of_samples / sampling_frequency) * multiplier
Where:
- number_of_samples = len(self) = self.ys.shape[1]
- sampling_frequency = self.fs
- multiplier = 1000 for milliseconds, 1 for seconds
For example, a Wave with 44100 samples at 44100 Hz has a duration of:
- 1.0 second (44100 / 44100 * 1)
- 1000.0 milliseconds (44100 / 44100 * 1000)
The duration is independent of the number of channels; it represents the time length of the audio signal.
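A quick numerical check of the formula, using the constructor example from the class documentation:
>>> import torch
>>> wave = Wave(torch.randn(1, 44100), fs=44100)
>>> wave.duration("sec")   # 44100 / 44100 * 1
1.0
>>> wave.duration("ms")    # 44100 / 44100 * 1000
1000.0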
See also
__len__: Get the number of samples in the wave
channels: Get the number of channels
- classmethod from_file(path, *args, **kwargs)[source]#
Load a Wave object from an audio file.
This classmethod uses torchaudio.load to read audio files, automatically detecting the format and extracting metadata. Supported formats include WAV, MP3, FLAC, OGG, and others depending on the available torchaudio backend.
- Parameters:
path (str or Path) – Path to the audio file to load. Can be a string or pathlib.Path object.
*args – Additional positional arguments passed to torchaudio.load.
**kwargs – Additional keyword arguments passed to torchaudio.load. Common options include frame_offset, num_frames, normalize, channels_first, and format.
- Returns:
New Wave object containing the loaded audio data, sampling frequency, and extracted metadata.
- Return type:
Wave
Examples
Load a WAV file:
>>> wave = Wave.from_file("audio.wav")
>>> print(f"Loaded {wave.channels()} channels at {wave.fs}Hz")
Load a specific portion of a file:
>>> # Load 1 second starting at 2 seconds
>>> wave = Wave.from_file("long_audio.wav", frame_offset=88200, num_frames=44100)
Load and check metadata:
>>> wave = Wave.from_file("audio.flac") >>> print(wave.metadata) {'num_frames': 220500, 'num_channels': 2, 'bits_per_sample': 16, ...}
Load different formats:
>>> wav_wave = Wave.from_file("audio.wav")
>>> mp3_wave = Wave.from_file("audio.mp3")
>>> flac_wave = Wave.from_file("audio.flac")
Load with normalization:
>>> # Normalize to [-1, 1] range
>>> wave = Wave.from_file("audio.wav", normalize=True)
Notes
The method automatically extracts metadata including num_frames, num_channels, bits_per_sample, and encoding when available. If metadata extraction fails, an empty metadata dictionary is used instead.
The loaded audio tensor will be on CPU by default. Use the to() method or device parameter in subsequent processing to move to GPU:
>>> wave = Wave.from_file("audio.wav").to("cuda")
Format support depends on the torchaudio backend (SoX or FFmpeg). Check torchaudio documentation for your installation’s supported formats.
See also
__init__: Direct constructor for creating Wave from array data
save: Save a Wave object to an audio file
torchaudio.load: Underlying function used for loading
- get_channel(index)[source]#
Extract a specific channel as a new Wave object.
This method creates a new Wave containing only the specified channel from the original multi-channel audio. The new Wave has the same sampling frequency and can be processed independently.
- Parameters:
index (int) – Zero-based index of the channel to extract. For stereo audio, use 0 for left channel and 1 for right channel.
- Returns:
New Wave object containing only the specified channel with shape (1, samples).
- Return type:
Wave
Examples
Extract left and right channels from stereo:
>>> stereo = Wave.from_file("stereo.wav")
>>> left_channel = stereo.get_channel(0)
>>> right_channel = stereo.get_channel(1)
Process channels independently:
>>> from torchfx.filter import iir
>>> stereo = Wave.from_file("stereo.wav")
>>> # Apply different filters to each channel
>>> left = stereo.get_channel(0) | iir.LoButterworth(1000)
>>> right = stereo.get_channel(1) | iir.HiButterworth(1000)
>>> # Merge back to stereo
>>> result = Wave.merge([left, right], split_channels=True)
Extract all channels as a list:
>>> wave = Wave.from_file("multichannel.wav") >>> channels = [wave.get_channel(i) for i in range(wave.channels())]
Mix a stereo file down to mono by summing the channels:
>>> stereo = Wave.from_file("stereo.wav")
>>> left = stereo.get_channel(0)
>>> right = stereo.get_channel(1)
>>> # Mix to mono (merge sums the two channels)
>>> mono = Wave.merge([left, right], split_channels=False)
Notes
The returned Wave object is independent of the original. Modifications to the returned Wave do not affect the original multi-channel Wave.
The index must be within the valid range [0, channels()-1]. Python’s standard indexing rules apply, so negative indices are supported (e.g., -1 for the last channel).
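For example, the last channel can be selected with a negative index, as noted above:
>>> wave = Wave.from_file("multichannel.wav")
>>> last = wave.get_channel(-1)   # equivalent to wave.get_channel(wave.channels() - 1)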
- classmethod merge(waves, split_channels=False)[source]#
Combine multiple Wave objects into a single Wave.
This classmethod provides two merging strategies: mixing (summing waveforms element-wise) or channel concatenation (preserving each wave as separate channels). All waves must have the same sampling frequency.
- Parameters:
- waves (list of Wave) – Wave objects to combine. All waves must share the same sampling frequency.
- split_channels (bool, default False) – If False, the waves are mixed by element-wise summation. If True, each input wave is kept as a separate channel in the output.
- Returns:
New Wave object containing the merged audio with the same sampling frequency as the input waves.
- Return type:
Wave
- Raises:
ValueError – If no waves are provided or if waves have different sampling frequencies.
Examples
Mix two mono waves into one (sum strategy):
>>> wave1 = Wave.from_file("voice.wav")
>>> wave2 = Wave.from_file("music.wav")
>>> mixed = Wave.merge([wave1, wave2], split_channels=False)
>>> # Result: mono wave with voice and music mixed together
Combine two mono waves into stereo (concatenate strategy):
>>> left = Wave.from_file("left.wav")
>>> right = Wave.from_file("right.wav")
>>> stereo = Wave.merge([left, right], split_channels=True)
>>> # Result: stereo wave with left and right channels
Merge after independent processing:
>>> from torchfx.filter import iir
>>> stereo = Wave.from_file("stereo.wav")
>>> left = stereo.get_channel(0) | iir.LoButterworth(1000)
>>> right = stereo.get_channel(1) | iir.HiButterworth(1000)
>>> result = Wave.merge([left, right], split_channels=True)
Mix multiple sources with different effects:
>>> wave1 = Wave.from_file("track1.wav") | effect1 >>> wave2 = Wave.from_file("track2.wav") | effect2 >>> wave3 = Wave.from_file("track3.wav") | effect3 >>> final_mix = Wave.merge([wave1, wave2, wave3], split_channels=False)
Create multi-channel output from mono sources:
>>> channels = [Wave.from_file(f"channel_{i}.wav") for i in range(8)]
>>> multichannel = Wave.merge(channels, split_channels=True)
>>> print(f"Created {multichannel.channels()} channel audio")
Notes
Merge Strategy Comparison:
When split_channels=False (mixing):
- Waves are summed element-wise
- If waves have different lengths, shorter ones are zero-padded
- Result has the same number of channels as the input waves
- Use for: audio mixing, layering multiple sounds
When split_channels=True (concatenation):
- Waves are concatenated along the channel dimension
- Each input wave becomes a separate channel in the output
- Result channel count = sum of all input channel counts
- Use for: creating stereo/multichannel output from mono sources
Length Handling:
When merging waves of different lengths:
- The output length is the maximum length among all input waves
- Shorter waves are zero-padded to match the longest wave
Validation:
All input waves must have identical sampling frequencies. This is enforced because summing or stacking signals sampled at different rates would misalign them in time; resample the waves to a common rate before merging.
Device Compatibility:
All waves should be on the same device. The merged Wave will be on the device of the first input wave.
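A minimal sketch of the length handling described above (mixing a half-second wave into a one-second wave; the shorter input is zero-padded):
>>> import torch
>>> short = Wave(torch.randn(1, 22050), fs=44100)   # 0.5 s
>>> full = Wave(torch.randn(1, 44100), fs=44100)    # 1.0 s
>>> mixed = Wave.merge([short, full], split_channels=False)
>>> len(mixed)   # output length equals the longest input
44100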
See also
get_channel: Extract individual channels from multi-channel audio
channels: Get the number of channels in a Wave
__init__: Constructor for creating Wave from array data
- save(path, format=None, encoding=None, bits_per_sample=None, **kwargs)[source]#
Save the wave to an audio file.
- Parameters:
path (str or Path) – The path where to save the audio file.
format (str, optional) – Override the audio format. If not specified, the format is inferred from the file extension. Valid values include: “wav”, “flac”.
encoding (str, optional) – Changes the encoding for supported formats (wav, flac). Valid values: “PCM_S” (signed int), “PCM_U” (unsigned int), “PCM_F” (float), “ULAW”, “ALAW”.
bits_per_sample (int, optional) – Changes the bit depth for supported formats. Valid values: 8, 16, 24, 32, 64.
**kwargs – Additional keyword arguments to pass to torchaudio.save.
- Return type:
None
Examples
Save as WAV file:
>>> wave = Wave.from_file("input.wav") >>> wave.save("output.wav")
Save as FLAC with specific encoding:
>>> wave.save("output.flac", encoding="PCM_S", bits_per_sample=24)
Save with high bit depth:
>>> wave.save("output.wav", encoding="PCM_F", bits_per_sample=32)
Notes
The method automatically creates parent directories if they don’t exist.
The audio data is moved to CPU before saving.
Supported formats depend on the available torchaudio backend.
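For example (a small sketch of the behavior noted above), a Wave processed on the GPU can be saved directly to a nested output path:
>>> wave = Wave.from_file("audio.wav").to("cuda")
>>> wave.save("results/processed/output.wav")   # parent directories are created; data is moved to CPU first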
See also
from_file: Load a wave from an audio file.
torchaudio.save: Underlying function used for saving.
- to(device)[source]#
Move the Wave object to a specific device (CPU or CUDA).
This method transfers the internal audio tensor to the specified device and returns self to enable method chaining. The device transfer uses PyTorch’s standard tensor movement mechanism.
- Parameters:
device ({"cpu", "cuda"}) – Target device to move the Wave object to.
- Returns:
Returns self for method chaining support.
- Return type:
Wave
Examples
Move to GPU:
>>> wave = Wave.from_file("audio.wav")
>>> wave.to("cuda")
Conditional device selection:
>>> import torch
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> wave = Wave.from_file("audio.wav").to(device)
Method chaining pattern:
>>> from torchfx.filter import iir
>>> result = (Wave.from_file("input.wav")
...     .to("cuda")
...     | iir.LoButterworth(1000))
Process on GPU, then move back to CPU:
>>> wave = Wave.from_file("audio.wav")
>>> result = (wave.to("cuda") | some_filter).to("cpu")
Notes
Unlike most Wave methods, to() modifies the Wave in place but returns self to support fluent interface patterns. Effects and filters applied via the pipeline operator will automatically inherit the Wave’s device location.
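Because to() returns self rather than a copy, the returned object is the same Wave:
>>> wave = Wave.from_file("audio.wav")
>>> moved = wave.to("cuda")
>>> moved is wave   # modified in place, returned for chaining
True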
- transform(func, *args, **kwargs)[source]#
Apply a functional transformation to the audio signal tensor.
This method applies an arbitrary function to the audio tensor, creating a new Wave object with the transformed signal while preserving the sampling frequency. The original Wave object remains unchanged, following an immutability pattern.
- Parameters:
func (Callable[..., Tensor]) – Function that takes a tensor as its first argument and returns a tensor. Can be any PyTorch or torchaudio function, or a custom callable.
*args – Additional positional arguments passed to func after the tensor.
**kwargs – Additional keyword arguments passed to func.
- Returns:
New Wave object with the transformed signal and the same sampling frequency.
- Return type:
Wave
Examples
Apply FFT transformation:
>>> import torch
>>> wave = Wave.from_file("audio.wav")
>>> freq_domain = wave.transform(torch.fft.fft)
Apply normalization:
>>> def normalize_peak(tensor):
...     return tensor / tensor.abs().max()
>>> normalized = wave.transform(normalize_peak)
Apply torchaudio transforms with parameters:
>>> import torchaudio.transforms as T
>>> # Resample to 16kHz (requires passing sample rate)
>>> resampled = wave.transform(
...     T.Resample(wave.fs, 16000).forward
... )
Apply custom function with arguments:
>>> def add_noise(tensor, noise_level=0.01):
...     noise = torch.randn_like(tensor) * noise_level
...     return tensor + noise
>>> noisy = wave.transform(add_noise, noise_level=0.05)
Chain transformations:
>>> wave = Wave.from_file("audio.wav")
>>> result = (wave
...     .transform(lambda x: x / x.abs().max())  # Normalize
...     .transform(torch.fft.fft)                # FFT
...     .transform(torch.fft.ifft)               # IFFT
...     .transform(torch.real))                  # Extract real part
Notes
The transform method creates a new Wave object rather than modifying in place, supporting functional programming patterns. The original Wave is unchanged.
The function must accept a tensor as its first argument and return a tensor. The returned tensor shape should maintain the (channels, samples) convention, though the number of samples may change.
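For example, trimming with a custom function changes the sample count while keeping the (channels, samples) convention (assuming an integer sampling rate):
>>> wave = Wave.from_file("audio.wav")
>>> first_second = wave.transform(lambda x: x[:, : wave.fs])   # keep only the first second
>>> first_second.channels() == wave.channels()
True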
See also
__or__: Pipeline operator for applying nn.Module effects and filters
to: Move the Wave to a different device before transformation
FX#
The FX base class is the foundation for all effects and filters.
- class torchfx.FX(*args, **kwargs)[source]#
Bases: ABC, Module
Abstract base class for all audio effects and filters.
FX serves as the foundation for all effects in torchfx, combining PyTorch’s nn.Module with abstract base class requirements. This design ensures effects are compatible with PyTorch’s module system while enforcing a consistent interface across all effect implementations.
All effects must implement the abstract __init__ and forward methods. The forward method receives audio tensors of shape (…, time) and returns processed tensors.
Inheriting from nn.Module provides:
- GPU/CPU device management (.to(), .cuda(), .cpu())
- Parameter and buffer registration
- Integration with nn.Sequential for effect chaining
- Serialization support (state_dict, load_state_dict)
- Gradient computation (when not using @torch.no_grad())
- forward(x: Tensor) → Tensor[source]#
Process input tensor and return transformed output. Must be implemented by all subclasses.
Notes
When creating custom effects:
- Always call super().__init__() in your constructor
- Implement forward() to process tensors of shape (…, time)
- Use the @torch.no_grad() decorator for inference-only effects
- Validate parameters in __init__ using assertions
- For sample-rate dependent effects, accept an optional fs parameter
The FX base class uses the strategy pattern for extensibility. Effects can accept strategy objects to customize processing behavior without modifying the core effect implementation.
See also
Gain: Volume adjustment effect
Normalize: Amplitude normalization effect
Reverb: Reverb effect using a feedback delay network
Delay: Multi-tap delay effect with BPM synchronization
Examples
Create a simple custom effect:
>>> import torch
>>> from torchfx.effect import FX
>>>
>>> class SimpleGain(FX):
...     def __init__(self, gain: float) -> None:
...         super().__init__()
...         assert gain > 0, "Gain must be positive"
...         self.gain = gain
...
...     @torch.no_grad()
...     def forward(self, waveform: torch.Tensor) -> torch.Tensor:
...         return waveform * self.gain
Use in a pipeline:
>>> import torchfx as fx
>>> wave = fx.Wave.from_file("audio.wav")
>>> effect = SimpleGain(0.5)
>>> processed = wave | effect
Chain multiple effects:
>>> result = wave | SimpleGain(0.5) | fx.Normalize(peak=1.0)
Create effects with strategies:
>>> from abc import ABC, abstractmethod
>>>
>>> class ProcessingStrategy(ABC):
...     @abstractmethod
...     def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
...         pass
>>>
>>> class StrategyEffect(FX):
...     def __init__(self, strategy: ProcessingStrategy) -> None:
...         super().__init__()
...         self.strategy = strategy
...
...     def forward(self, waveform: torch.Tensor) -> torch.Tensor:
...         return self.strategy(waveform)
References
For detailed examples of custom effect creation, including multi-channel processing and the strategy pattern, see the “Creating Custom Effects” wiki page.
Methods
- forward(x): Define the computation performed at every call.
- abstract forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.