# Source code for torchfx.realtime.stream

"""Chunk-based stream processing for large audio files.

This module provides the ``StreamProcessor`` class for processing audio
files in chunks without loading the entire file into memory. This is
essential for processing files that are too large to fit in memory or
for streaming applications.

Classes
-------
StreamProcessor
    Chunk-based file processor for large audio files.

Examples
--------
>>> from torchfx.realtime import StreamProcessor
>>> from torchfx.effect import Gain
>>> processor = StreamProcessor(effects=[Gain(0.5)], chunk_size=65536)
>>> # processor.process_file("large_input.wav", "output.wav")

"""

from __future__ import annotations

from collections.abc import Callable, Generator, Iterable, Sequence
from pathlib import Path
from typing import cast

import soundfile as sf  # type: ignore[import-untyped]
import torch
import torchaudio
from torch import Tensor, nn

from torchfx.effect import FX
from torchfx.filter.__base import AbstractFilter
from torchfx.logging import get_logger
from torchfx.validation import validate_positive

# Module-level logger, obtained from torchfx.logging under the name
# "realtime.stream"; used for the progress messages emitted in this module.
_logger = get_logger("realtime.stream")


class StreamProcessor:
    """Process audio files in chunks without loading the entire file.

    Reads audio in configurable chunk sizes, applies an effect chain to
    each chunk, and writes (or yields) the output progressively. Supports
    an overlap parameter for effects that need context beyond chunk
    boundaries: every chunk after the first re-reads the last ``overlap``
    frames of its predecessor, and those frames are dropped from the
    processed result so each input frame is emitted exactly once.

    Parameters
    ----------
    effects : Sequence[FX] | nn.Sequential
        Chain of effects to apply in order.
    chunk_size : int
        Number of samples per processing chunk. Default is 65536.
    overlap : int
        Number of overlap samples between chunks. Must satisfy
        ``0 <= overlap < chunk_size``. Default is 0.
    device : str
        Processing device (``"cpu"`` or ``"cuda"``). Default is ``"cpu"``.

    Raises
    ------
    ValueError
        If ``overlap`` is negative or not smaller than ``chunk_size``.
    TypeError
        If any element of ``effects`` is not an ``FX`` instance.

    Examples
    --------
    >>> from torchfx.realtime import StreamProcessor
    >>> from torchfx.effect import Gain
    >>> processor = StreamProcessor(effects=[Gain(0.5)])

    """

    def __init__(
        self,
        effects: Sequence[FX] | nn.Sequential,
        chunk_size: int = 65536,
        overlap: int = 0,
        device: str = "cpu",
    ) -> None:
        validate_positive(chunk_size, "chunk_size")
        if overlap < 0:
            raise ValueError(f"Overlap must be non-negative, got {overlap}")
        # overlap < chunk_size guarantees a strictly positive hop size, which
        # the chunk planner relies on to make progress through the file.
        if overlap >= chunk_size:
            raise ValueError(f"Overlap ({overlap}) must be less than chunk_size ({chunk_size})")

        self._effects: list[FX] = self._normalize_effects(effects)
        self._chunk_size = chunk_size
        self._overlap = overlap
        self._device = device

    def __enter__(self) -> StreamProcessor:
        """Return self for use as context manager.

        The context manager form is a convenience for scoping the
        processor lifetime. No special start/stop is needed.

        Examples
        --------
        >>> with StreamProcessor(effects=[Gain(0.5)]) as processor:  # doctest: +SKIP
        ...     processor.process_file("in.wav", "out.wav")

        """
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: object,
    ) -> None:
        """Leave the context; nothing is released and exceptions propagate."""

    @staticmethod
    def _normalize_effects(effects: Sequence[FX] | nn.Sequential) -> list[FX]:
        """Copy ``effects`` into a plain list, rejecting non-``FX`` members.

        Parameters
        ----------
        effects : Sequence[FX] | nn.Sequential
            Effect chain supplied by the caller.

        Returns
        -------
        list[FX]
            The effects in their original order.

        Raises
        ------
        TypeError
            If any element is not an ``FX`` instance.

        """
        modules: Iterable[FX] = (
            cast(Iterable[FX], effects) if isinstance(effects, nn.Sequential) else effects
        )
        normalized: list[FX] = []
        for effect in modules:
            if not isinstance(effect, FX):
                raise TypeError("All effects must inherit from FX when used in StreamProcessor")
            normalized.append(effect)
        return normalized

    def _configure_effects(self, fs: int) -> None:
        """Set sample rate and compute coefficients for all effects.

        Filter cutoffs are validated against the Nyquist frequency for
        *every* effect before any effect is modified, so a rejected file
        never leaves the chain partially reconfigured and coefficients are
        never computed from an invalid cutoff. Coefficients are recomputed
        when the sample rate differs from the one an effect currently holds.

        Parameters
        ----------
        fs : int
            Sample rate in Hz.

        Raises
        ------
        ValueError
            If a filter's cutoff frequency is at or above the Nyquist
            frequency.

        """
        nyquist = fs / 2.0

        # Pass 1: validation only, no mutation.
        for effect in self._effects:
            if isinstance(effect, AbstractFilter) and hasattr(effect, "cutoff"):
                cutoff = effect.cutoff
                if isinstance(cutoff, int | float) and cutoff >= nyquist:
                    raise ValueError(
                        f"{type(effect).__name__} cutoff ({cutoff} Hz) must be "
                        f"below the Nyquist frequency ({nyquist} Hz) for sample rate "
                        f"{fs} Hz. Reduce the cutoff or use a higher sample rate file."
                    )

        # Pass 2: apply the new sample rate and make sure coefficients exist.
        for effect in self._effects:
            if hasattr(effect, "fs") and effect.fs != fs:
                effect.fs = fs  # type: ignore
                # Force coefficient recomputation with the new sample rate.
                if isinstance(effect, AbstractFilter):
                    effect.compute_coefficients()
                # NOTE(review): the reset is applied to any effect exposing a
                # callable ``reset_state`` once its sample rate changed --
                # confirm this matches the intended scope (filters only vs.
                # every stateful effect).
                reset_state = getattr(effect, "reset_state", None)
                if callable(reset_state):
                    reset_state()

            if isinstance(effect, AbstractFilter) and not effect._has_computed_coeff:
                effect.compute_coefficients()

    def _plan_chunks(self, num_frames: int) -> Generator[tuple[int, int, int], None, None]:
        """Yield ``(frame_offset, read_size, trim)`` for each useful chunk.

        ``trim`` is the number of leading frames of the processed chunk that
        duplicate the previous chunk's output and must be dropped. Planning
        stops as soon as a tail read would be no longer than its trim: the
        previous chunk already reached the end of the file, so processing
        that tail would only produce an empty result.

        Parameters
        ----------
        num_frames : int
            Total number of frames in the input.

        Yields
        ------
        tuple[int, int, int]
            Frame offset to read from, number of frames to read, and number
            of leading frames to discard after processing.

        """
        hop_size = self._chunk_size - self._overlap
        for offset in range(0, num_frames, hop_size):
            read_size = min(self._chunk_size, num_frames - offset)
            trim = self._overlap if offset > 0 else 0
            if read_size <= trim:
                return
            yield offset, read_size, trim

    def _apply_effects(self, waveform: Tensor) -> Tensor:
        """Run ``waveform`` through the effect chain and return it on the CPU.

        Parameters
        ----------
        waveform : Tensor
            Audio chunk as loaded from the input file.

        Returns
        -------
        Tensor
            The processed chunk, moved back to the CPU when a non-CPU
            processing device is configured.

        Raises
        ------
        TypeError
            If an effect is not a ``torch.nn.Module``.

        """
        off_cpu = self._device != "cpu"
        if off_cpu:
            waveform = waveform.to(self._device)

        for effect in self._effects:
            if not isinstance(effect, nn.Module):
                raise TypeError("Effects must inherit from torch.nn.Module")
            call_effect = cast(Callable[[Tensor], Tensor], effect)
            waveform = call_effect(waveform)

        if off_cpu:
            waveform = waveform.cpu()
        return waveform

    def _iter_processed_chunks(
        self, input_path: Path, num_frames: int
    ) -> Generator[Tensor, None, None]:
        """Load, process and overlap-trim the input one chunk at a time.

        Shared by :meth:`process_file` and :meth:`process_chunks` so both
        entry points follow exactly the same chunking rules.
        """
        for offset, read_size, trim in self._plan_chunks(num_frames):
            waveform, _sample_rate = torchaudio.load(
                str(input_path),
                frame_offset=offset,
                num_frames=read_size,
            )
            waveform = self._apply_effects(waveform)
            # In overlap mode only the non-overlapping part is new output.
            if trim:
                waveform = waveform[:, trim:]
            yield waveform

    @torch.no_grad()
    def process_file(
        self,
        input_path: str | Path,
        output_path: str | Path,
        format: str | None = None,  # noqa: A002
        subtype: str | None = None,
    ) -> None:
        """Process an audio file chunk by chunk.

        Reads the input file in chunks, applies the effect chain to each
        chunk, and writes the result to the output file. Missing parent
        directories of ``output_path`` are created.

        Parameters
        ----------
        input_path : str | Path
            Path to the input audio file.
        output_path : str | Path
            Path to the output audio file. Must not be the input file.
        format : str | None
            Output format (e.g., ``"WAV"``, ``"FLAC"``). Inferred from
            the extension if None, falling back to ``"WAV"``.
        subtype : str | None
            Output subtype (e.g., ``"PCM_16"``, ``"FLOAT"``). If None,
            ``"FLOAT"`` is used for WAV and the format default otherwise.

        Raises
        ------
        ValueError
            If ``output_path`` refers to the same file as ``input_path``,
            or if a filter cutoff is not below the Nyquist frequency.

        Examples
        --------
        >>> from torchfx.realtime import StreamProcessor
        >>> from torchfx.effect import Gain
        >>> processor = StreamProcessor(effects=[Gain(0.5)])
        >>> # processor.process_file("input.wav", "output.wav")

        """
        input_path = Path(input_path)
        output_path = Path(output_path)

        # The output is opened in write mode (truncating it) before any input
        # chunk is read, so in-place processing would destroy the source.
        if input_path.resolve() == output_path.resolve():
            raise ValueError(
                f"output_path must differ from input_path ({input_path}): "
                "the output file is truncated before the input is read"
            )

        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Get file metadata without loading the audio data.
        info = torchaudio.info(str(input_path))
        fs = info.sample_rate
        num_frames = info.num_frames
        channels = info.num_channels

        _logger.info(
            "Processing %s: %d frames, %d channels, %dHz",
            input_path,
            num_frames,
            channels,
            fs,
        )

        self._configure_effects(fs)

        # Determine output format.
        if format is None:
            ext = output_path.suffix.lower()
            format_map = {".wav": "WAV", ".flac": "FLAC", ".ogg": "OGG"}
            format = format_map.get(ext, "WAV")  # noqa: A001

        if subtype is None:
            subtype = "FLOAT" if format == "WAV" else None

        # Process and write chunks.
        with sf.SoundFile(
            str(output_path),
            mode="w",
            samplerate=fs,
            channels=channels,
            format=format,
            subtype=subtype,
        ) as out_file:
            frames_written = 0
            for chunk in self._iter_processed_chunks(input_path, num_frames):
                # Write: convert (channels, frames) -> (frames, channels).
                out_file.write(chunk.numpy().T)
                frames_written += chunk.shape[-1]
                _logger.debug(
                    "Processed %d / %d frames", min(frames_written, num_frames), num_frames
                )

        _logger.info("Output written to %s", output_path)

    @torch.no_grad()
    def process_chunks(
        self,
        input_path: str | Path,
    ) -> Generator[Tensor, None, None]:
        """Yield processed chunks as tensors.

        Generator API for streaming to another process, network, or
        real-time playback.

        Parameters
        ----------
        input_path : str | Path
            Path to the input audio file.

        Yields
        ------
        Tensor
            Processed audio chunks laid out as ``(channels, frames)``.
            Provided the effects preserve chunk length, the first chunk
            holds up to ``chunk_size`` frames, later chunks up to
            ``chunk_size - overlap`` frames, and the final chunk may be
            shorter. Empty chunks are never yielded.

        Raises
        ------
        ValueError
            If a filter cutoff is not below the Nyquist frequency.

        Examples
        --------
        >>> from torchfx.realtime import StreamProcessor
        >>> from torchfx.effect import Gain
        >>> processor = StreamProcessor(effects=[Gain(0.5)])
        >>> # for chunk in processor.process_chunks("input.wav"):
        >>> #     print(chunk.shape)

        """
        input_path = Path(input_path)
        info = torchaudio.info(str(input_path))
        self._configure_effects(info.sample_rate)
        yield from self._iter_processed_chunks(input_path, info.num_frames)

    @property
    def chunk_size(self) -> int:
        """Processing chunk size in samples."""
        return self._chunk_size

    @property
    def overlap(self) -> int:
        """Overlap between chunks in samples."""
        return self._overlap

    @property
    def effects(self) -> list[FX]:
        """The current effect chain."""
        return self._effects