Core Classes#

The core module provides the fundamental classes for working with audio in TorchFX.

Wave#

The Wave class handles audio data representation and I/O operations.

class torchfx.Wave(ys, fs, device='cpu', metadata={})[source]#

Bases: object

A discrete-time waveform representing multi-channel audio signals.

The Wave class is the fundamental data structure in torchfx, wrapping a PyTorch tensor containing audio samples along with its sampling frequency. It provides device management (CPU/CUDA), file I/O integration with torchaudio, functional transformations, and a pipeline operator for chaining effects and filters.

Parameters:
ys#

Audio signal tensor with shape (channels, samples). Each row represents a channel, and each column represents a time sample.

Type:

Tensor

fs#

Sampling frequency in Hz (samples per second). Used by filters and effects to compute time-domain parameters.

Type:

int

metadata#

Optional metadata dictionary containing audio file information such as num_frames, num_channels, bits_per_sample, and encoding.

Type:

dict

device#

Current device location (“cpu” or “cuda”) where the tensor resides. Assign to this property or call the to() method to move the Wave between devices.

Type:

Device

Examples

Create a Wave from an array:

>>> import torch
>>> from torchfx import Wave
>>> samples = torch.randn(1, 44100)  # 1 second of mono audio at 44.1kHz
>>> wave = Wave(samples, fs=44100)

Load from an audio file:

>>> wave = Wave.from_file("audio.wav")
>>> print(f"Channels: {wave.channels()}, Duration: {wave.duration('sec')}s")

Process on GPU with pipeline operator:

>>> from torchfx.filter import iir
>>> wave = Wave.from_file("input.wav").to("cuda")
>>> result = wave | iir.LoButterworth(1000, order=4)

Multi-channel processing:

>>> stereo = Wave.from_file("stereo.wav")
>>> left = stereo.get_channel(0)
>>> right = stereo.get_channel(1)
>>> # Process channels independently
>>> processed_left = left | some_effect
>>> processed_right = right | other_effect
>>> result = Wave.merge([processed_left, processed_right], split_channels=True)

Notes

The Wave class follows an immutability pattern for most operations. Methods like transform(), __or__(), and get_channel() return new Wave objects rather than modifying in place. Only device management methods (to(), device setter) modify the Wave in place while returning self for method chaining.

The pipe operator (|) automatically configures FX modules by setting the sampling frequency if not already set and computing filter coefficients before first use. This eliminates boilerplate and prevents common configuration errors.
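
As a sketch of this behavior (assuming, as in the filter examples on this page, that iir.HiButterworth takes the cutoff frequency and derives its remaining parameters from the Wave):

>>> from torchfx.filter import iir
>>> wave = Wave.from_file("audio.wav")
>>> hpf = iir.HiButterworth(200)  # no sampling frequency given here
>>> filtered = wave | hpf  # fs is taken from the Wave and coefficients are computed before use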

See also

from_file

Load a Wave from an audio file

merge

Combine multiple Wave objects into one

transform

Apply functional transformations to the signal

to

Move the Wave to a different device

Methods

from_file

Load a Wave object from an audio file.

save

Save the wave to an audio file.

to

Move the Wave object to a specific device (CPU or CUDA).

transform

Apply a functional transformation to the audio signal tensor.

get_channel

Extract a specific channel as a new Wave object.

merge

Combine multiple Wave objects into a single Wave.

channels()[source]#

Return the number of audio channels in the wave.

Returns:

Number of channels. Returns 1 for mono, 2 for stereo, or higher values for multi-channel audio.

Return type:

int

Examples

Check channel count:

>>> wave = Wave.from_file("audio.wav")
>>> print(f"This file has {wave.channels()} channel(s)")

Process based on channel count:

>>> wave = Wave.from_file("audio.wav")
>>> if wave.channels() == 1:
...     print("Mono audio")
... elif wave.channels() == 2:
...     print("Stereo audio")
... else:
...     print(f"Multi-channel audio with {wave.channels()} channels")

Extract all channels:

>>> wave = Wave.from_file("multichannel.wav")
>>> channel_list = [wave.get_channel(i) for i in range(wave.channels())]

Notes

This method returns the first dimension of the ys tensor shape. The tensor follows the convention (channels, samples), so channels() returns ys.shape[0].
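
A minimal check of this relationship:

>>> wave = Wave.from_file("audio.wav")
>>> wave.channels() == wave.ys.shape[0]
True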

See also

get_channel

Extract a specific channel as a new Wave object

merge

Combine multiple Wave objects with different merge strategies

duration(unit)[source]#

Calculate the duration of the audio signal.

Computes the time length of the audio based on the number of samples and the sampling frequency. The duration can be returned in either seconds or milliseconds.

Parameters:

unit ({"sec", "ms"}) – Unit for the returned duration. Use “sec” for seconds or “ms” for milliseconds.

Returns:

Duration in the specified time unit. The return type is annotated as Second (float) when unit=”sec” or Millisecond (float) when unit=”ms”.

Return type:

float

Examples

Get duration in seconds:

>>> wave = Wave.from_file("audio.wav")
>>> duration_sec = wave.duration("sec")
>>> print(f"Duration: {duration_sec:.2f} seconds")

Get duration in milliseconds:

>>> wave = Wave.from_file("audio.wav")
>>> duration_ms = wave.duration("ms")
>>> print(f"Duration: {duration_ms:.0f} ms")

Compare durations:

>>> wave1 = Wave.from_file("short.wav")
>>> wave2 = Wave.from_file("long.wav")
>>> if wave1.duration("sec") < wave2.duration("sec"):
...     print("wave1 is shorter")

Calculate processing time estimate:

>>> wave = Wave.from_file("audio.wav")
>>> duration = wave.duration("sec")
>>> # Estimate processing time (example: 10x realtime)
>>> estimated_time = duration * 10
>>> print(f"Estimated processing: {estimated_time:.2f} seconds")

Use in validation:

>>> wave = Wave.from_file("audio.wav")
>>> max_duration_sec = 60.0
>>> if wave.duration("sec") > max_duration_sec:
...     print("Audio file is too long")

Notes

The duration is calculated using the formula:

duration = (number_of_samples / sampling_frequency) * multiplier

where:

  • number_of_samples = len(self) = self.ys.shape[1]

  • sampling_frequency = self.fs

  • multiplier = 1000 for milliseconds, 1 for seconds

For example, a Wave with 44100 samples at 44100 Hz has a duration of:

  • 1.0 second (44100 / 44100 * 1)

  • 1000.0 milliseconds (44100 / 44100 * 1000)

The duration is independent of the number of channels; it represents the time length of the audio signal.
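
A small sketch of the formula, using the constructor from the class-level examples:

>>> import torch
>>> one_second = Wave(torch.randn(1, 44100), fs=44100)
>>> one_second.duration("sec")  # 44100 / 44100 * 1
1.0
>>> one_second.duration("ms")  # 44100 / 44100 * 1000
1000.0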

See also

__len__

Get the number of samples in the wave

channels

Get the number of channels

classmethod from_file(path, *args, **kwargs)[source]#

Load a Wave object from an audio file.

This classmethod uses torchaudio.load to read audio files, automatically detecting the format and extracting metadata. Supported formats include WAV, MP3, FLAC, OGG, and others depending on the available torchaudio backend.

Parameters:
  • path (str or Path) – Path to the audio file to load. Can be a string or pathlib.Path object.

  • *args – Additional positional arguments passed to torchaudio.load.

  • **kwargs – Additional keyword arguments passed to torchaudio.load. Common options include frame_offset, num_frames, normalize, channels_first, and format.

Returns:

New Wave object containing the loaded audio data, sampling frequency, and extracted metadata.

Return type:

Wave

Examples

Load a WAV file:

>>> wave = Wave.from_file("audio.wav")
>>> print(f"Loaded {wave.channels()} channels at {wave.fs}Hz")

Load a specific portion of a file:

>>> # Load 1 second starting at 2 seconds
>>> wave = Wave.from_file("long_audio.wav", frame_offset=88200, num_frames=44100)

Load and check metadata:

>>> wave = Wave.from_file("audio.flac")
>>> print(wave.metadata)
{'num_frames': 220500, 'num_channels': 2, 'bits_per_sample': 16, ...}

Load different formats:

>>> wav_wave = Wave.from_file("audio.wav")
>>> mp3_wave = Wave.from_file("audio.mp3")
>>> flac_wave = Wave.from_file("audio.flac")

Load with normalization:

>>> # Normalize to [-1, 1] range
>>> wave = Wave.from_file("audio.wav", normalize=True)

Notes

The method automatically extracts metadata including num_frames, num_channels, bits_per_sample, and encoding when available. If metadata extraction fails, an empty metadata dictionary is used instead.

The loaded audio tensor will be on CPU by default. Use the to() method or device parameter in subsequent processing to move to GPU:

>>> wave = Wave.from_file("audio.wav").to("cuda")

Format support depends on the torchaudio backend (SoX or FFmpeg). Check torchaudio documentation for your installation’s supported formats.

See also

__init__

Direct constructor for creating Wave from array data

save

Save a Wave object to an audio file

torchaudio.load

Underlying function used for loading

get_channel(index)[source]#

Extract a specific channel as a new Wave object.

This method creates a new Wave containing only the specified channel from the original multi-channel audio. The new Wave has the same sampling frequency and can be processed independently.

Parameters:

index (int) – Zero-based index of the channel to extract. For stereo audio, use 0 for left channel and 1 for right channel.

Returns:

New Wave object containing only the specified channel with shape (1, samples).

Return type:

Wave

Examples

Extract left and right channels from stereo:

>>> stereo = Wave.from_file("stereo.wav")
>>> left_channel = stereo.get_channel(0)
>>> right_channel = stereo.get_channel(1)

Process channels independently:

>>> from torchfx.filter import iir
>>> stereo = Wave.from_file("stereo.wav")
>>> # Apply different filters to each channel
>>> left = stereo.get_channel(0) | iir.LoButterworth(1000)
>>> right = stereo.get_channel(1) | iir.HiButterworth(1000)
>>> # Merge back to stereo
>>> result = Wave.merge([left, right], split_channels=True)

Extract all channels as a list:

>>> wave = Wave.from_file("multichannel.wav")
>>> channels = [wave.get_channel(i) for i in range(wave.channels())]

Process mono from stereo by averaging:

>>> stereo = Wave.from_file("stereo.wav")
>>> left = stereo.get_channel(0)
>>> right = stereo.get_channel(1)
>>> # Mix to mono (this uses merge to sum channels)
>>> mono = Wave.merge([left, right], split_channels=False)

Notes

The returned Wave object is independent of the original. Modifications to the returned Wave do not affect the original multi-channel Wave.

The index must be within the valid range [0, channels()-1]. Python’s standard indexing rules apply, so negative indices are supported (e.g., -1 for the last channel).
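
For example, using negative indexing as noted above:

>>> wave = Wave.from_file("multichannel.wav")
>>> last_channel = wave.get_channel(-1)  # last channel, per standard Python indexing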

See also

channels

Get the total number of channels

merge

Combine multiple Wave objects back into multi-channel audio

classmethod merge(waves, split_channels=False)[source]#

Combine multiple Wave objects into a single Wave.

This classmethod provides two merging strategies: mixing (summing waveforms element-wise) or channel concatenation (preserving each wave as separate channels). All waves must have the same sampling frequency.

Parameters:
  • waves (Sequence[Wave]) – Sequence of Wave objects to merge. Must contain at least one wave.

  • split_channels (bool, optional) – Determines the merge strategy. If False (default), waves are mixed by summing them element-wise; if True, waves are concatenated along the channel dimension.

Returns:

New Wave object containing the merged audio with the same sampling frequency as the input waves.

Return type:

Wave

Raises:

ValueError – If no waves are provided or if waves have different sampling frequencies.

Examples

Mix two mono waves into one (sum strategy):

>>> wave1 = Wave.from_file("voice.wav")
>>> wave2 = Wave.from_file("music.wav")
>>> mixed = Wave.merge([wave1, wave2], split_channels=False)
>>> # Result: mono wave with voice and music mixed together

Combine two mono waves into stereo (concatenate strategy):

>>> left = Wave.from_file("left.wav")
>>> right = Wave.from_file("right.wav")
>>> stereo = Wave.merge([left, right], split_channels=True)
>>> # Result: stereo wave with left and right channels

Merge after independent processing:

>>> from torchfx.filter import iir
>>> stereo = Wave.from_file("stereo.wav")
>>> left = stereo.get_channel(0) | iir.LoButterworth(1000)
>>> right = stereo.get_channel(1) | iir.HiButterworth(1000)
>>> result = Wave.merge([left, right], split_channels=True)

Mix multiple sources with different effects:

>>> wave1 = Wave.from_file("track1.wav") | effect1
>>> wave2 = Wave.from_file("track2.wav") | effect2
>>> wave3 = Wave.from_file("track3.wav") | effect3
>>> final_mix = Wave.merge([wave1, wave2, wave3], split_channels=False)

Create multi-channel output from mono sources:

>>> channels = [Wave.from_file(f"channel_{i}.wav") for i in range(8)]
>>> multichannel = Wave.merge(channels, split_channels=True)
>>> print(f"Created {multichannel.channels()} channel audio")

Notes

Merge Strategy Comparison:

When split_channels=False (mixing):

  • Waves are summed element-wise.

  • If waves have different lengths, shorter ones are zero-padded.

  • The result has the same number of channels as the input waves.

  • Use for: audio mixing, layering multiple sounds.

When split_channels=True (concatenation):

  • Waves are concatenated along the channel dimension.

  • Each input wave becomes a separate channel in the output.

  • The result channel count is the sum of all input channel counts.

  • Use for: creating stereo/multichannel output from mono sources.

Length Handling:

When merging waves of different lengths:

  • The output length is the maximum length among all input waves.

  • Shorter waves are zero-padded to match the longest wave (see the sketch at the end of these notes).

Validation:

All input waves must have identical sampling frequencies. This is enforced because summing or stacking waves sampled at different rates would misalign them in time; resample to a common rate before merging.

Device Compatibility:

All waves should be on the same device. The merged Wave will be on the device of the first input wave.
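
A sketch of the length handling and zero-padding described above (the sample counts are illustrative; __len__ returns the number of samples, as noted under duration()):

>>> import torch
>>> shorter = Wave(torch.randn(1, 22050), fs=44100)  # 0.5 s
>>> longer = Wave(torch.randn(1, 44100), fs=44100)   # 1.0 s
>>> mixed = Wave.merge([shorter, longer], split_channels=False)
>>> len(mixed)  # output length matches the longest input
44100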

See also

get_channel

Extract individual channels from multi-channel audio

channels

Get the number of channels in a Wave

__init__

Constructor for creating Wave from array data

save(path, format=None, encoding=None, bits_per_sample=None, **kwargs)[source]#

Save the wave to an audio file.

Parameters:
  • path (str or Path) – The path where to save the audio file.

  • format (str, optional) – Override the audio format. If not specified, the format is inferred from the file extension. Valid values include: “wav”, “flac”.

  • encoding (str, optional) – Changes the encoding for supported formats (wav, flac). Valid values: “PCM_S” (signed int), “PCM_U” (unsigned int), “PCM_F” (float), “ULAW”, “ALAW”.

  • bits_per_sample (int, optional) – Changes the bit depth for supported formats. Valid values: 8, 16, 24, 32, 64.

  • **kwargs – Additional keyword arguments to pass to torchaudio.save.

Return type:

None

Examples

Save as WAV file:

>>> wave = Wave.from_file("input.wav")
>>> wave.save("output.wav")

Save as FLAC with specific encoding:

>>> wave.save("output.flac", encoding="PCM_S", bits_per_sample=24)

Save with high bit depth:

>>> wave.save("output.wav", encoding="PCM_F", bits_per_sample=32)

Notes

  • The method automatically creates parent directories if they don’t exist.

  • The audio data is moved to CPU before saving.

  • Supported formats depend on the available torchaudio backend.
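
A small sketch of the directory-creation behavior noted above (the nested output path is hypothetical):

>>> wave = Wave.from_file("input.wav")
>>> wave.save("renders/session_01/output.wav")  # "renders/session_01/" is created if missing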

See also

from_file

Load a wave from an audio file.

torchaudio.save

Underlying function used for saving.

to(device)[source]#

Move the Wave object to a specific device (CPU or CUDA).

This method transfers the internal audio tensor to the specified device and returns self to enable method chaining. The device transfer uses PyTorch’s standard tensor movement mechanism.

Parameters:

device ({"cpu", "cuda"}) – Target device to move the Wave object to.

Returns:

Returns self for method chaining support.

Return type:

Wave

Examples

Move to GPU:

>>> wave = Wave.from_file("audio.wav")
>>> wave.to("cuda")

Conditional device selection:

>>> import torch
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> wave = Wave.from_file("audio.wav").to(device)

Method chaining pattern:

>>> from torchfx.filter import iir
>>> result = (Wave.from_file("input.wav")
...     .to("cuda")
...     | iir.LoButterworth(1000))

Process on GPU, then move back to CPU:

>>> wave = Wave.from_file("audio.wav")
>>> result = (wave.to("cuda") | some_filter).to("cpu")

Notes

Unlike most Wave methods, to() modifies the Wave in place but returns self to support fluent interface patterns. Effects and filters applied via the pipeline operator will automatically inherit the Wave’s device location.

See also

device

Property for getting or setting the current device

save

Method that automatically moves to CPU before saving

transform(func, *args, **kwargs)[source]#

Apply a functional transformation to the audio signal tensor.

This method applies an arbitrary function to the audio tensor, creating a new Wave object with the transformed signal while preserving the sampling frequency. The original Wave object remains unchanged, following an immutability pattern.

Parameters:
  • func (Callable[..., Tensor]) – Function that takes a tensor as its first argument and returns a tensor. Can be any PyTorch or torchaudio function, or a custom callable.

  • *args – Additional positional arguments passed to func after the tensor.

  • **kwargs – Additional keyword arguments passed to func.

Returns:

New Wave object with the transformed signal and the same sampling frequency.

Return type:

Wave

Examples

Apply FFT transformation:

>>> import torch
>>> wave = Wave.from_file("audio.wav")
>>> freq_domain = wave.transform(torch.fft.fft)

Apply normalization:

>>> def normalize_peak(tensor):
...     return tensor / tensor.abs().max()
>>> normalized = wave.transform(normalize_peak)

Apply torchaudio transforms with parameters:

>>> import torchaudio.transforms as T
>>> # Resample to 16 kHz; note that transform() keeps the original fs on the result
>>> resampled = wave.transform(
...     T.Resample(wave.fs, 16000).forward
... )

Apply custom function with arguments:

>>> def add_noise(tensor, noise_level=0.01):
...     noise = torch.randn_like(tensor) * noise_level
...     return tensor + noise
>>> noisy = wave.transform(add_noise, noise_level=0.05)

Chain transformations:

>>> wave = Wave.from_file("audio.wav")
>>> result = (wave
...     .transform(lambda x: x / x.abs().max())  # Normalize
...     .transform(torch.fft.fft)                # FFT
...     .transform(torch.fft.ifft)               # IFFT
...     .transform(torch.real))                  # Extract real part

Notes

The transform method creates a new Wave object rather than modifying in place, supporting functional programming patterns. The original Wave is unchanged.

The function must accept a tensor as its first argument and return a tensor. The returned tensor shape should maintain the (channels, samples) convention, though the number of samples may change.

See also

__or__

Pipeline operator for applying nn.Module effects and filters

to

Move the Wave to a different device before transformation

property device: Literal['cpu', 'cuda'] | device#

The device where this Wave’s tensor is located. Assigning a new value to this property moves the object to that device.
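
Examples

A sketch of the setter behavior (assumes a CUDA device is available):

>>> wave = Wave.from_file("audio.wav")
>>> wave.device = "cuda"  # assignment moves the underlying tensor, like wave.to("cuda")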

See also

Wave.to

FX#

The FX base class is the foundation for all effects and filters.

class torchfx.FX(*args, **kwargs)[source]#

Bases: Module, ABC

Abstract base class for all audio effects and filters.

FX serves as the foundation for all effects in torchfx, combining PyTorch’s nn.Module with abstract base class requirements. This design ensures effects are compatible with PyTorch’s module system while enforcing a consistent interface across all effect implementations.

All effects must implement the abstract __init__ and forward methods. The forward method receives audio tensors of shape (…, time) and returns processed tensors.

Inheriting from nn.Module provides:

  • GPU/CPU device management (.to(), .cuda(), .cpu())

  • Parameter and buffer registration

  • Integration with nn.Sequential for effect chaining (see the nn.Sequential sketch in the Examples below)

  • Serialization support (state_dict, load_state_dict)

  • Gradient computation (when not using @torch.no_grad())

Parameters:
  • *args (tuple) – Positional arguments passed to nn.Module.

  • **kwargs (dict) – Keyword arguments passed to nn.Module.

forward(x: Tensor) → Tensor[source]#

Process input tensor and return transformed output. Must be implemented by all subclasses.

Parameters:

x (Tensor)

Return type:

Tensor

Notes

When creating custom effects:

  1. Always call super().__init__() in your constructor

  2. Implement forward() to process tensors of shape (…, time)

  3. Use @torch.no_grad() decorator for inference-only effects

  4. Validate parameters in __init__ using assertions

  5. For sample-rate dependent effects, accept optional fs parameter

The FX base class uses the strategy pattern for extensibility. Effects can accept strategy objects to customize processing behavior without modifying the core effect implementation.

See also

Gain

Volume adjustment effect

Normalize

Amplitude normalization effect

Reverb

Reverb effect using feedback delay network

Delay

Multi-tap delay effect with BPM synchronization

Examples

Create a simple custom effect:

>>> import torch
>>> from torchfx.effect import FX
>>>
>>> class SimpleGain(FX):
...     def __init__(self, gain: float) -> None:
...         super().__init__()
...         assert gain > 0, "Gain must be positive"
...         self.gain = gain
...
...     @torch.no_grad()
...     def forward(self, waveform: torch.Tensor) -> torch.Tensor:
...         return waveform * self.gain

Use in a pipeline:

>>> import torchfx as fx
>>> wave = fx.Wave.from_file("audio.wav")
>>> effect = SimpleGain(0.5)
>>> processed = wave | effect

Chain multiple effects:

>>> result = wave | SimpleGain(0.5) | fx.Normalize(peak=1.0)
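
Compose with nn.Sequential (a sketch using the SimpleGain effect defined above, applied through transform() since an nn.Sequential is an ordinary callable on tensors):

>>> import torch.nn as nn
>>> chain = nn.Sequential(SimpleGain(0.5), SimpleGain(2.0))
>>> processed = wave.transform(chain)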

Create effects with strategies:

>>> from abc import ABC, abstractmethod
>>>
>>> class ProcessingStrategy(ABC):
...     @abstractmethod
...     def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
...         pass
>>>
>>> class StrategyEffect(FX):
...     def __init__(self, strategy: ProcessingStrategy) -> None:
...         super().__init__()
...         self.strategy = strategy
...
...     def forward(self, waveform: torch.Tensor) -> torch.Tensor:
...         return self.strategy(waveform)
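
Following point 5 in the Notes above, a sketch of a sample-rate dependent effect (it assumes the pipe operator assigns the Wave's sampling frequency to an fs attribute when one has not been set):

>>> class LinearFadeIn(FX):
...     def __init__(self, fade_ms: float, fs: int | None = None) -> None:
...         super().__init__()
...         assert fade_ms > 0, "Fade length must be positive"
...         self.fade_ms = fade_ms
...         self.fs = fs
...
...     @torch.no_grad()
...     def forward(self, waveform: torch.Tensor) -> torch.Tensor:
...         assert self.fs is not None, "Sampling frequency must be set"
...         # Convert the fade length from milliseconds to samples
...         n = min(int(self.fs * self.fade_ms / 1000), waveform.shape[-1])
...         ramp = torch.linspace(0.0, 1.0, n, device=waveform.device)
...         out = waveform.clone()
...         out[..., :n] *= ramp
...         return out
>>>
>>> faded = fx.Wave.from_file("audio.wav") | LinearFadeIn(fade_ms=50)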

References

For detailed examples of custom effect creation, including multi-channel processing and the strategy pattern, see the “Creating Custom Effects” wiki page.

Methods

forward

Define the computation performed at every call.

abstract forward(x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters:

x (Tensor)

Return type:

Tensor