API

TorchFX provides two main classes: torchfx.Wave, which holds audio data, and torchfx.FX, the base class for audio effects and transformations. The library also provides a set of built-in effects and filters that can be applied to audio data.

class torchfx.Wave(ys: ArrayLike, fs: int, device: 'cpu' | 'cuda' | device = 'cpu')

Bases: object

A discrete time waveform.

ys

The signal.

Type:

Tensor

fs

The sampling frequency.

Type:

int
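The two attributes are enough to derive the quantities that the methods of this class report. The following is a minimal pure-Python sketch of that relationship, not the torchfx implementation: `MiniWave` and `halve` are hypothetical names, nested lists stand in for tensors, and the signal is assumed to be laid out as `[channels][samples]` (the layout returned by `torchaudio.load`).

```python
from dataclasses import dataclass
from typing import Callable, List

Signal = List[List[float]]  # assumed layout: [channels][samples]


@dataclass
class MiniWave:
    """Hypothetical stand-in for a waveform container (not torchfx.Wave)."""

    ys: Signal
    fs: int

    def channels(self) -> int:
        # One inner list per channel.
        return len(self.ys)

    def duration(self, unit: str) -> float:
        # Number of samples divided by the sampling frequency gives seconds.
        seconds = len(self.ys[0]) / self.fs
        if unit == "sec":
            return seconds
        if unit == "ms":
            return seconds * 1000.0
        raise ValueError(f"unknown unit: {unit!r}")

    def get_channel(self, index: int) -> "MiniWave":
        # Keep the [channels][samples] layout with a single channel.
        return MiniWave([list(self.ys[index])], self.fs)

    def transform(self, func: Callable[..., Signal], *args, **kwargs) -> "MiniWave":
        # Apply func to the signal and wrap the result in a new object,
        # leaving the original untouched.
        return MiniWave(func(self.ys, *args, **kwargs), self.fs)


def halve(ys: Signal) -> Signal:
    """Example transformation: scale every sample by 0.5."""
    return [[0.5 * y for y in channel] for channel in ys]


wave = MiniWave(ys=[[1.0] * 8000, [0.0] * 8000], fs=8000)
quiet = wave.transform(halve)
print(wave.channels(), wave.duration("sec"), wave.duration("ms"))
print(quiet.get_channel(0).ys[0][0])
```

With 8000 samples per channel at 8000 Hz, the sketch reports two channels and a duration of one second.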

channels() -> int

Return the number of channels of the wave.

Returns:

The number of channels of the wave.

Return type:

int

property device : 'cpu' | 'cuda' | device

Return the device on which this object is located. Assigning a new value moves the object to that device.

See also

Wave.to

duration(unit: 'sec' | 'ms') -> float

Return the length of the wave in seconds or milliseconds.

Parameters:
unit : {"sec", "ms"}

The unit of time to return the duration in.

Returns:

The duration of the wave in the specified unit of time.

Return type:

float

Examples

>>> wave = Wave.from_file("path/to/file.wav")
>>> wave.duration("sec")
3.0
>>> wave.duration("ms")
3000.0

classmethod from_file(path: str | Path, *args, **kwargs) -> Wave

Instantiate a wave from an audio file.

Parameters:
path : str or Path

The path to the audio file.

*args

Additional arguments to pass to torchaudio.load.

**kwargs

Additional keyword arguments to pass to torchaudio.load.

Returns:

The wave object.

Return type:

Wave

Examples

>>> wave = Wave.from_file("path/to/file.wav")

get_channel(index: int) -> Wave

Return a specific channel of the wave.

Parameters:
index : int

The index of the channel to return.

Returns:

The wave object with only the specified channel.

Return type:

Wave

Examples

>>> wave = Wave.from_file("path/to/file.wav")
>>> wave.get_channel(0)

classmethod merge(waves: Sequence[Wave], split_channels: bool = False) -> Wave

Merge multiple waves into a single wave.

Parameters:
waves : Sequence[Wave]

The waves to merge.

split_channels : bool, optional

If False, the channels of the waves will be merged into a single channel. If True, the channels will be merged into multiple channels. Default is False.

Returns:

The merged wave object.

Return type:

Wave

Examples

>>> wave1 = Wave.from_file("path/to/file1.wav")
>>> wave2 = Wave.from_file("path/to/file2.wav")
>>> merged_wave = Wave.merge([wave1, wave2])

to(device: 'cpu' | 'cuda' | device) -> Self

Move the wave object to a specific device (cpu or cuda).

Parameters:
device : {"cpu", "cuda"}

The device to move the wave object to.

Returns:

The wave object.

Return type:

Wave

Examples

>>> wave = Wave.from_file("path/to/file.wav")
>>> wave.to("cuda")

transform(func: Callable[..., Tensor], *args, **kwargs) -> Wave

Apply a functional transformation to the signal.

Parameters:
func : Callable[..., Tensor]

The function to apply to the signal.

Returns:

A new wave object with the transformed signal.

Return type:

Wave

Examples

>>> wave = Wave.from_file("path/to/file.wav")
>>> wave.transform(torch.fft.fft)

class torchfx.FX(*args, **kwargs)

Bases: Module, ABC

Abstract base class for all effects. This class defines the interface for all effects in the library. It inherits from torch.nn.Module and provides the basic structure for implementing effects.

abstract forward(x: Tensor) -> Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
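The note about calling the instance rather than forward can be illustrated without torch. The following is a simplified pure-Python sketch of the pattern, not the torch.nn.Module or torchfx code: `MiniFX`, `register_hook`, and `Invert` are hypothetical names, and the hook mechanism is reduced to callbacks that receive the output.

```python
from abc import ABC, abstractmethod
from typing import Callable, List

Samples = List[float]


class MiniFX(ABC):
    """Hypothetical, simplified stand-in for an effect base class."""

    def __init__(self) -> None:
        self._hooks: List[Callable[[Samples], None]] = []

    def register_hook(self, hook: Callable[[Samples], None]) -> None:
        self._hooks.append(hook)

    @abstractmethod
    def forward(self, x: Samples) -> Samples:
        """Subclasses put the effect computation here."""

    def __call__(self, x: Samples) -> Samples:
        # Calling the instance runs forward() and then every registered hook.
        y = self.forward(x)
        for hook in self._hooks:
            hook(y)
        return y


class Invert(MiniFX):
    """Example effect: flip the polarity of the signal."""

    def forward(self, x: Samples) -> Samples:
        return [-v for v in x]


seen: List[Samples] = []
fx = Invert()
fx.register_hook(seen.append)

via_call = fx([0.5, -0.25])      # hooks run
via_forward = fx.forward([1.0])  # hooks are skipped
print(via_call, via_forward, seen)
```

Only the call made through the instance is recorded by the hook, which is why effects are normally applied as `fx(x)` rather than `fx.forward(x)`.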

The filters already implemented are available in the filter module:

class torchfx.filter.AllPass(cutoff: float, Q: float, fs: int | None = None)

Bases: IIR

All-pass filter.

compute_coefficients() -> None

Compute the filter coefficients.

class torchfx.filter.Butterworth(btype: str, cutoff: float, order: int = 4, order_scale: 'db' | 'linear' = 'linear', fs: int | None = None, a: Sequence | None = None, b: Sequence | None = None)

Bases: IIR

Butterworth filter.

compute_coefficients() -> None

Compute the filter coefficients.

class torchfx.filter.Chebyshev1(btype: str, cutoff: float, order: int = 4, ripple: float = 0.1, fs: int | None = None, a: Sequence | None = None, b: Sequence | None = None)

Bases: IIR

Chebyshev type 1 filter.

compute_coefficients() -> None

Compute the filter coefficients.

class torchfx.filter.Chebyshev2(btype: str, cutoff: float, order: int = 4, ripple: float = 0.1, fs: int | None = None)

Bases: IIR

Chebyshev type 2 filter.

compute_coefficients() -> None

Compute the filter coefficients.

class torchfx.filter.DesignableFIR(cutoff: float | Sequence[float], num_taps: int, fs: int | None = None, pass_zero: bool = True, window: 'hann' | 'hamming' | 'blackman' | 'kaiser' | 'boxcar' | 'bartlett' | 'flattop' | 'parzen' | 'bohman' | 'nuttall' | 'barthann' = 'hamming')

Bases: FIR

FIR filter designed using scipy.signal.firwin.

cutoff

Cutoff frequency or frequencies (in Hz) for the filter.

Type:

float | Sequence[float]

num_taps

Number of taps (filter length, i.e. filter order + 1) for the FIR filter.

Type:

int

fs

Sampling frequency (in Hz) of the input signal. If None, the filter will not be designed.

Type:

int | None

pass_zero

If True, the gain at frequency 0 is 1, which gives a lowpass filter for a single cutoff. If False, the gain at frequency 0 is 0, which gives a highpass filter for a single cutoff.

Type:

bool

window

Window type to use for the FIR filter design. Default is “hamming”.

Type:

WindowType

compute_coefficients() -> None

Compute the filter coefficients.
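The windowed-sinc idea behind this kind of design can be sketched in pure Python for the single-cutoff lowpass case with a Hamming window. This is an illustration under those assumptions, not the scipy.signal.firwin or torchfx code; `sinc` and `lowpass_taps` are hypothetical helper names.

```python
import math
from typing import List


def sinc(x: float) -> float:
    """Normalized sinc: sin(pi x) / (pi x), with sinc(0) = 1."""
    return 1.0 if x == 0.0 else math.sin(math.pi * x) / (math.pi * x)


def lowpass_taps(cutoff: float, num_taps: int, fs: int) -> List[float]:
    """Design lowpass FIR taps as a Hamming-windowed sinc."""
    if num_taps < 2:
        raise ValueError("num_taps must be at least 2")
    c = cutoff / (fs / 2)        # cutoff as a fraction of the Nyquist frequency
    center = (num_taps - 1) / 2  # taps are symmetric around this index
    taps = []
    for n in range(num_taps):
        ideal = c * sinc(c * (n - center))  # ideal lowpass impulse response
        window = 0.54 - 0.46 * math.cos(2 * math.pi * n / (num_taps - 1))
        taps.append(ideal * window)
    total = sum(taps)
    return [t / total for t in taps]  # normalize to unit gain at 0 Hz


taps = lowpass_taps(cutoff=1000.0, num_taps=31, fs=8000)
print(len(taps), sum(taps))
```

A filter with 31 taps has order 30, and the symmetric taps give it linear phase.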

class torchfx.filter.FIR(b: ArrayLike)

Bases: AbstractFilter

Efficient FIR filter using conv1d. Supports inputs of shape [T], [C, T], or [B, C, T].

compute_coefficients() -> None

Compute the filter coefficients.

forward(x: Tensor) -> Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
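What an FIR filter computes can be written out as a direct convolution. The sketch below is a pure-Python illustration for a single channel, assuming a causal filter with zero initial state and an output as long as the input; `fir_filter` is a hypothetical name, and the padding and output-length conventions of the conv1d-based class may differ.

```python
from typing import List, Sequence


def fir_filter(b: Sequence[float], x: Sequence[float]) -> List[float]:
    """Direct-form FIR: y[n] = sum_k b[k] * x[n - k], zero initial state."""
    y = []
    for n in range(len(x)):
        acc = 0.0
        for k, coeff in enumerate(b):
            if n - k >= 0:  # samples before the start count as zero
                acc += coeff * x[n - k]
        y.append(acc)
    return y


# Two-tap moving average applied to a constant signal.
smoothed = fir_filter([0.5, 0.5], [1.0, 1.0, 1.0, 1.0])
print(smoothed)  # [0.5, 1.0, 1.0, 1.0]
```

The first output sample is smaller because only one input sample is available at that point.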

class torchfx.filter.HiButterworth(cutoff: float, order: int = 5, order_scale: 'db' | 'linear' = 'linear', fs: int | None = None)

Bases: Butterworth

High-pass Butterworth filter.

class torchfx.filter.HiChebyshev1(cutoff: float, order: int = 4, ripple: float = 0.1, fs: int | None = None)

Bases: Chebyshev1

High-pass Chebyshev type 1 filter.

class torchfx.filter.HiChebyshev2(cutoff: float, order: int = 4, ripple: float = 0.1, fs: int | None = None)

Bases: Chebyshev2

High-pass Chebyshev type 2 filter.

class torchfx.filter.HiLinkwitzRiley(cutoff: float, order: int = 4, order_scale: 'db' | 'linear' = 'linear', fs: int | None = None)

Bases: LinkwitzRiley

High-pass Linkwitz-Riley filter.

class torchfx.filter.HiShelving(cutoff: float, q: float, gain: float, gain_scale: 'db' | 'linear' = 'linear', fs: int | None = None)

Bases: Shelving

High-pass shelving filter.

compute_coefficients() -> None

Compute the filter coefficients.

class torchfx.filter.IIR(fs: int | None = None)

Bases: AbstractFilter

IIR filter. This abstract class defines the interface and basic structure for implementing IIR filters. It inherits from AbstractFilter.

a

The filter’s denominator coefficients.

Type:

Sequence

b

The filter’s numerator coefficients.

Type:

Sequence

fs

The sampling frequency of the filter.

Type:

int | None

cutoff

The cutoff frequency of the filter.

Type:

float

forward(x: Tensor) -> Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

move_coeff(device, dtype=torch.float32)

Move the filter coefficients to the specified device and dtype.
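The role of the two coefficient sequences can be seen in the difference equation an IIR filter evaluates. The sketch below is a pure-Python illustration for a single channel, following the common convention (used by scipy.signal) in which b holds the numerator and a the denominator coefficients; `iir_filter` is a hypothetical name and this is not the library's forward implementation.

```python
from typing import List, Sequence


def iir_filter(b: Sequence[float], a: Sequence[float], x: Sequence[float]) -> List[float]:
    """Evaluate a[0]*y[n] = sum_k b[k]*x[n-k] - sum_{k>=1} a[k]*y[n-k]."""
    y: List[float] = []
    for n in range(len(x)):
        # Feedforward part: weighted current and past inputs.
        acc = sum(b[k] * x[n - k] for k in range(len(b)) if n - k >= 0)
        # Feedback part: weighted past outputs.
        acc -= sum(a[k] * y[n - k] for k in range(1, len(a)) if n - k >= 0)
        y.append(acc / a[0])
    return y


# One-pole lowpass y[n] = 0.5*x[n] + 0.5*y[n-1], driven by a unit impulse.
impulse_response = iir_filter(b=[0.5], a=[1.0, -0.5], x=[1.0, 0.0, 0.0, 0.0])
print(impulse_response)  # [0.5, 0.25, 0.125, 0.0625]
```

The response never reaches exactly zero, which is the "infinite impulse response" that gives this filter family its name.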

class torchfx.filter.LinkwitzRiley(btype: str, cutoff: float, order: int = 4, order_scale: 'db' | 'linear' = 'linear', fs: int | None = None)

Bases: IIR

Linkwitz-Riley filter.

This filter is created by cascading two identical Butterworth filters. The resulting filter has an order that is twice the order of the base Butterworth filter and a -6 dB gain at the cutoff frequency. The order of a Linkwitz-Riley filter must be an even integer.

compute_coefficients() -> None

Compute the filter coefficients.

The method calculates the coefficients for a Butterworth filter of half the specified order and then cascades it with itself by convolving the numerator and denominator coefficients.
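The cascade described above can be checked numerically. The sketch below is a pure-Python illustration, not the library code: it builds a first-order Butterworth lowpass with the bilinear transform, convolves its coefficients with themselves to obtain a second-order Linkwitz-Riley filter, and evaluates the gain at the cutoff. `convolve`, `butter1_lowpass`, and `gain_db` are hypothetical helper names.

```python
import cmath
import math
from typing import List, Sequence, Tuple


def convolve(p: Sequence[float], q: Sequence[float]) -> List[float]:
    """Multiply two polynomials given as coefficient lists."""
    out = [0.0] * (len(p) + len(q) - 1)
    for i, pv in enumerate(p):
        for j, qv in enumerate(q):
            out[i + j] += pv * qv
    return out


def butter1_lowpass(cutoff: float, fs: float) -> Tuple[List[float], List[float]]:
    """First-order Butterworth lowpass via the prewarped bilinear transform."""
    k = math.tan(math.pi * cutoff / fs)
    b = [k / (k + 1), k / (k + 1)]
    a = [1.0, (k - 1) / (k + 1)]
    return b, a


def gain_db(b: Sequence[float], a: Sequence[float], freq: float, fs: float) -> float:
    """Magnitude of H(z) = B(z) / A(z) at the given frequency, in dB."""
    z_inv = cmath.exp(-2j * math.pi * freq / fs)
    num = sum(c * z_inv ** i for i, c in enumerate(b))
    den = sum(c * z_inv ** i for i, c in enumerate(a))
    return 20 * math.log10(abs(num / den))


fs, cutoff = 48000, 1000.0
b, a = butter1_lowpass(cutoff, fs)
b_lr, a_lr = convolve(b, b), convolve(a, a)  # cascade the filter with itself

print(round(gain_db(b, a, cutoff, fs), 2))        # -3.01
print(round(gain_db(b_lr, a_lr, cutoff, fs), 2))  # -6.02
```

Cascading squares the magnitude response, so the -3 dB point of the Butterworth filter becomes the -6 dB point of the Linkwitz-Riley filter, and the order doubles from 1 to 2.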

class torchfx.filter.LoButterworth(cutoff: float, order: int = 5, order_scale: 'db' | 'linear' = 'linear', fs: int | None = None)

Bases: Butterworth

Low-pass Butterworth filter.

class torchfx.filter.LoChebyshev1(cutoff: float, order: int = 4, ripple: float = 0.1, fs: int | None = None)

Bases: Chebyshev1

Low-pass Chebyshev type 1 filter.

class torchfx.filter.LoChebyshev2(cutoff: float, order: int = 4, ripple: float = 0.1, fs: int | None = None)

Bases: Chebyshev2

Low-pass Chebyshev type 2 filter.

class torchfx.filter.LoLinkwitzRiley(cutoff: float, order: int = 4, order_scale: 'db' | 'linear' = 'linear', fs: int | None = None)

Bases: LinkwitzRiley

Low-pass Linkwitz-Riley filter.

class torchfx.filter.LoShelving(cutoff: float, q: float, fs: int | None = None, a: Sequence | None = None, b: Sequence | None = None)

Bases: Shelving

Low-pass shelving filter.

class torchfx.filter.Notch(cutoff: float, Q: float, fs: int | None = None)

Bases: IIR

Notch filter.

compute_coefficients() -> None

Compute the filter coefficients.

The effects already implemented are available in the effects module:

Base class for all effects.

class torchfx.effects.FX(*args, **kwargs)

Bases: Module, ABC

Abstract base class for all effects. This class defines the interface for all effects in the library. It inherits from torch.nn.Module and provides the basic structure for implementing effects.

abstract forward(x: Tensor) -> Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class torchfx.effects.Gain(gain: float, gain_type: str = 'amplitude', clamp: bool = False)

Bases: FX

Adjust volume of waveform.

This effect is the same as torchaudio.transforms.Vol, but it adds the option to clamp or not the output waveform.

Parameters:
gain : float

gain_type : str

clamp : bool

Example

>>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
>>> transform = Gain(gain=0.5, gain_type="amplitude")
>>> quieter_waveform = transform(waveform)

See also

torchaudio.transforms.Vol

Transform to apply gain to a waveform.

Notes

This class is based on torchaudio.transforms.Vol, licensed under the BSD 2-Clause License. See licenses.torchaudio.BSD-2-Clause.txt for details.

forward(waveform: Tensor) -> Tensor

Parameters:
waveform : Tensor

Tensor of audio of dimension (…, time).

Returns:

Tensor of audio of dimension (…, time).

Return type:

Tensor
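How the gain value is interpreted depends on gain_type. The sketch below is a pure-Python illustration that follows the conventions of torchaudio.transforms.Vol ("amplitude", "power", and "db"), on which this class is based. It is not the library code: `apply_gain` is a hypothetical name, and clamping to the range [-1, 1] is an assumption carried over from Vol.

```python
import math
from typing import List, Sequence


def apply_gain(
    waveform: Sequence[float],
    gain: float,
    gain_type: str = "amplitude",
    clamp: bool = False,
) -> List[float]:
    """Scale the samples, interpreting `gain` according to `gain_type`."""
    if gain_type == "amplitude":
        factor = gain                # plain amplitude ratio
    elif gain_type == "power":
        factor = math.sqrt(gain)     # power ratio -> amplitude ratio
    elif gain_type == "db":
        factor = 10 ** (gain / 20)   # decibels -> amplitude ratio
    else:
        raise ValueError(f"unknown gain_type: {gain_type!r}")
    out = [factor * s for s in waveform]
    if clamp:
        # Assumed clamping range, as in torchaudio's Vol transform.
        out = [max(-1.0, min(1.0, s)) for s in out]
    return out


print(apply_gain([0.5], 4.0, gain_type="power"))  # [1.0]
print(apply_gain([0.25], 20.0, gain_type="db"))   # [2.5]
print(apply_gain([0.8], 2.0, clamp=True))         # [1.0]
```

Leaving clamp disabled keeps samples outside [-1, 1] intact, which can be useful when a later stage normalizes or attenuates the signal again.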

class torchfx.effects.NormalizationStrategy

Bases: ABC

Abstract base class for normalization strategies.

class torchfx.effects.Normalize(peak: float = 1.0, strategy: NormalizationStrategy | None = None)

Bases: FX

Normalize the waveform to a given peak value using a selected strategy.

Parameters:
peak : float

The peak value to normalize to. Default is 1.0.

strategy : NormalizationStrategy or None, optional

The normalization strategy to use. Default is None.

Example

>>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
>>> transform = Normalize(peak=0.5)
>>> normalized_waveform = transform(waveform)

forward(waveform: Tensor) -> Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class torchfx.effects.PeakNormalizationStrategy

Bases: NormalizationStrategy

Normalization to the absolute peak value.

class torchfx.effects.PerChannelNormalizationStrategy

Bases: NormalizationStrategy

Normalize each channel independently to its own peak.

class torchfx.effects.PercentileNormalizationStrategy(percentile: float = 99.0)

Bases: NormalizationStrategy

Normalization using a percentile of absolute values.

class torchfx.effects.RMSNormalizationStrategy

Bases: NormalizationStrategy

Normalization to Root Mean Square (RMS) energy.
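The four strategies differ only in the reference level against which the signal is scaled. The sketch below shows one plausible reading of each in pure Python. It is not the library code: all function names are hypothetical, the signal is assumed to be laid out as `[channels][samples]`, and the nearest-rank percentile is an assumption about how the percentile is taken.

```python
import math
from typing import List

Signal = List[List[float]]  # assumed layout: [channels][samples]


def _rescale(ys: Signal, reference: float, target: float) -> Signal:
    """Scale ys so that `reference` maps to `target`; leave silence unchanged."""
    if reference == 0.0:
        return [list(ch) for ch in ys]
    factor = target / reference
    return [[s * factor for s in ch] for ch in ys]


def peak_normalize(ys: Signal, peak: float = 1.0) -> Signal:
    # Reference: the largest absolute sample across all channels.
    reference = max(abs(s) for ch in ys for s in ch)
    return _rescale(ys, reference, peak)


def per_channel_normalize(ys: Signal, peak: float = 1.0) -> Signal:
    # Each channel is scaled against its own largest absolute sample.
    return [peak_normalize([ch], peak)[0] for ch in ys]


def rms_normalize(ys: Signal, peak: float = 1.0) -> Signal:
    # Reference: root mean square over all samples.
    samples = [s for ch in ys for s in ch]
    reference = math.sqrt(sum(s * s for s in samples) / len(samples))
    return _rescale(ys, reference, peak)


def percentile_normalize(ys: Signal, peak: float = 1.0, percentile: float = 99.0) -> Signal:
    # Reference: nearest-rank percentile of the absolute sample values.
    magnitudes = sorted(abs(s) for ch in ys for s in ch)
    rank = max(1, math.ceil(percentile / 100.0 * len(magnitudes)))
    return _rescale(ys, magnitudes[rank - 1], peak)


ys = [[0.5, -0.25], [0.1, 0.05]]
print(peak_normalize(ys))
print(per_channel_normalize(ys))
```

Peak normalization preserves the level difference between channels, while per-channel normalization brings the quieter channel up to the same peak as the louder one.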

class torchfx.effects.Reverb(delay: int = 4410, decay: float = 0.5, mix: float = 0.5)

Bases: FX

Apply a simple reverb effect using a feedback delay network.

Parameters:
delay : int

Delay in samples for the feedback comb filter.

decay : float

Feedback decay factor (0 < decay < 1).

mix : float

Wet/dry mix (0 = dry, 1 = wet).

The reverb effect is computed as:

y[n] = (1 - mix) * x[n] + mix * (x[n] + decay * x[n - delay])

where:
  • x[n] is the input signal,

  • y[n] is the output signal,

  • delay is the number of samples for the delay,

  • decay is the feedback decay factor,

  • mix is the wet/dry mix parameter.

Example

>>> reverb = Reverb(delay=4410, decay=0.5, mix=0.3)
>>> reverberated = reverb(waveform)

forward(waveform: Tensor) -> Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
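The formula given for the Reverb effect, y[n] = (1 - mix) * x[n] + mix * (x[n] + decay * x[n - delay]), can be evaluated sample by sample. The sketch below is a pure-Python illustration for a single channel that takes the formula literally, so the delayed term comes from the input signal, and it assumes samples before the start of the signal are zero. `simple_reverb` is a hypothetical name and this is not the library's forward implementation.

```python
from typing import List, Sequence


def simple_reverb(
    x: Sequence[float],
    delay: int = 4410,
    decay: float = 0.5,
    mix: float = 0.5,
) -> List[float]:
    """Evaluate y[n] = (1 - mix)*x[n] + mix*(x[n] + decay*x[n - delay])."""
    y = []
    for n, dry in enumerate(x):
        # Samples before the start of the signal are treated as zero.
        delayed = x[n - delay] if n >= delay else 0.0
        wet = dry + decay * delayed
        y.append((1 - mix) * dry + mix * wet)
    return y


# A unit impulse produces the dry impulse plus one attenuated echo.
echoed = simple_reverb([1.0, 0.0, 0.0, 0.0], delay=2, decay=0.5, mix=0.5)
print(echoed)  # [1.0, 0.0, 0.25, 0.0]
```

With mix set to 0 the output equals the input, and the echo amplitude grows in proportion to mix * decay.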