Testing#
Comprehensive guide to the TorchFX testing infrastructure, including test organization, execution, patterns, and coverage requirements.
Overview#
The TorchFX testing infrastructure uses pytest as the test runner with comprehensive coverage reporting. Tests validate the correctness of audio effects, filters, and the core Wave class through unit and integration tests. The test suite emphasizes strategy pattern testing, multi-channel processing, and edge case validation.
See also
Project Structure - Project organization and layout
Benchmarking - Performance testing
Test Infrastructure Configuration#
Pytest Configuration#
The pytest configuration is defined in pyproject.toml:
| Configuration | Value | Purpose |
|---|---|---|
| `minversion` | (see `pyproject.toml`) | Minimum pytest version required |
| `addopts` | `--strict-markers --tb=short` | Strict marker validation and short tracebacks |
| `testpaths` | `["tests"]` | Directory containing test files |
| `pythonpath` | `["src"]` | Python path for importing torchfx modules |
Configuration details:
--strict-markers: Ensures only registered markers are used, preventing typos
--tb=short: Provides concise traceback output for faster debugging
pythonpath = ["src"]: Allows tests to import torchfx directly without installation
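Put together, these options live under the [tool.pytest.ini_options] table in pyproject.toml. A minimal sketch reconstructed from the settings above (the minversion value is omitted because it is not shown here; the actual file is authoritative):
[tool.pytest.ini_options]
addopts = "--strict-markers --tb=short"
testpaths = ["tests"]
pythonpath = ["src"]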
Coverage Configuration#
Coverage reporting is configured to track code execution:
| Configuration | Value | Purpose |
|---|---|---|
| `source` | `["src/torchfx"]` | Source directory to measure coverage |
| `branch` | `true` | Enable branch coverage analysis |
Branch coverage ensures that both True and False branches of conditional statements are tested, providing more thorough coverage metrics than simple line coverage.
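These settings live under the [tool.coverage.run] table. A minimal sketch matching the table above (the actual pyproject.toml is authoritative):
[tool.coverage.run]
source = ["src/torchfx"]
branch = true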
Test Organization#
graph TB
subgraph "Test Directory Structure"
TestDir["tests/"]
TestEffects["test_effects.py<br/>Primary test file"]
end
subgraph "Test Categories in test_effects.py"
GainTests["Gain Tests<br/>test_gain_*"]
NormalizeTests["Normalize Tests<br/>test_normalize_*"]
StrategyTests["Strategy Tests<br/>test_*_strategy"]
ReverbTests["Reverb Tests<br/>test_reverb_*"]
DelayTests["Delay Tests<br/>test_delay_*"]
MusicalTimeTests["MusicalTime Tests<br/>test_musical_time_*"]
end
subgraph "Source Code Under Test"
EffectModule["src/torchfx/effect.py<br/>FX, Gain, Normalize<br/>Reverb, Delay"]
TypingModule["src/torchfx/typing.py<br/>MusicalTime"]
end
TestDir --> TestEffects
TestEffects --> GainTests
TestEffects --> NormalizeTests
TestEffects --> StrategyTests
TestEffects --> ReverbTests
TestEffects --> DelayTests
TestEffects --> MusicalTimeTests
GainTests -.->|tests| EffectModule
NormalizeTests -.->|tests| EffectModule
StrategyTests -.->|tests| EffectModule
ReverbTests -.->|tests| EffectModule
DelayTests -.->|tests| EffectModule
MusicalTimeTests -.->|tests| TypingModule
Test File Structure#
The primary test file is tests/test_effects.py, which contains all tests for the effect system.
Naming Conventions#
Test functions follow consistent naming patterns:
test_<component>_<scenario>: Tests a specific scenario
test_<component>_invalid_<parameter>: Tests validation and error handling
test_<strategy>_strategy: Tests strategy pattern implementations
Examples:
def test_gain_amplitude():
"""Tests gain in amplitude mode."""
pass
def test_normalize_invalid_peak():
"""Tests peak validation in Normalize."""
pass
def test_peak_normalization_strategy():
"""Tests PeakNormalizationStrategy."""
pass
Test Execution Flow#
sequenceDiagram
participant Dev as Developer
participant Pytest as pytest Runner
participant Test as Test Function
participant SUT as System Under Test
participant Assert as Assertion
Dev->>Pytest: pytest tests/
Pytest->>Pytest: Discover tests in tests/
Pytest->>Pytest: Apply pythonpath = ["src"]
loop For each test function
Pytest->>Test: Execute test_*()
alt Setup required
Test->>Test: Create fixtures/mocks
end
Test->>SUT: Instantiate component
Test->>SUT: Call forward() or __call__()
SUT->>Test: Return processed tensor
Test->>Assert: torch.testing.assert_close()
alt Assertion passes
Assert->>Pytest: Test passed
else Assertion fails
Assert->>Pytest: Test failed (short traceback)
end
end
Pytest->>Dev: Test results + coverage report
Running Tests#
Basic Execution#
Run all tests:
pytest tests/
Run tests with coverage:
pytest tests/ --cov=src/torchfx --cov-report=html
Run a specific test file:
pytest tests/test_effects.py
Run a specific test function:
pytest tests/test_effects.py::test_gain_amplitude
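Parametrized tests (covered later in this guide) generate one test id per parameter set, and a single case can be selected by appending its id. For example, assuming the test_reverb_invalid_delay example shown further below:
pytest "tests/test_effects.py::test_reverb_invalid_delay[0]"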
Verbose Output#
For detailed output showing each test name:
pytest tests/ -v
For even more detailed output including print statements:
pytest tests/ -vv -s
Filtering Tests#
Run tests matching a pattern:
# Run all tests with "gain" in the name
pytest tests/ -k "gain"
# Run all tests with "delay" but not "lazy"
pytest tests/ -k "delay and not lazy"
Run tests with specific markers (requires marker registration):
pytest tests/ -m "slow" # Run only slow tests
pytest tests/ -m "not slow" # Skip slow tests
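Because --strict-markers is enabled, any marker used this way must first be registered. A minimal sketch of the registration in pyproject.toml (the slow marker here is hypothetical; register whichever markers the suite actually uses):
[tool.pytest.ini_options]
markers = [
    "slow: marks tests as slow (deselect with -m 'not slow')",
]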
Coverage Reports#
Generate terminal coverage report:
pytest tests/ --cov=src/torchfx --cov-report=term-missing
Generate HTML coverage report:
pytest tests/ --cov=src/torchfx --cov-report=html
# Open htmlcov/index.html in browser
Generate XML coverage report (for CI):
pytest tests/ --cov=src/torchfx --cov-report=xml
Test Types and Patterns#
Unit Tests#
Unit tests validate individual components in isolation with known inputs and expected outputs.
Example: Basic Gain Test
def test_gain_amplitude():
"""Test that Gain with amplitude mode multiplies waveform correctly."""
waveform = torch.tensor([0.1, -0.2, 0.3])
gain = Gain(gain=2.0, gain_type="amplitude")
out = gain(waveform)
torch.testing.assert_close(out, waveform * 2.0)
Key characteristics:
Tests single component
Predictable inputs and outputs
No external dependencies
Fast execution
Integration Tests#
Integration tests validate interaction between multiple components, such as pipeline chaining or multi-channel processing.
Example: Wave Pipeline Integration
def test_delay_lazy_fs_inference_with_wave():
"""Test that Delay automatically infers fs when used with Wave."""
from torchfx import Wave
delay = Delay(bpm=120, delay_time="1/8", feedback=0.3, mix=0.2)
assert delay.fs is None # Not yet configured
wave = Wave(torch.randn(2, 44100), fs=44100)
_ = wave | delay # Pipeline operator triggers configuration
assert delay.fs == 44100 # Automatically configured
assert delay.delay_samples == 11025 # Calculated from BPM
Key characteristics:
Tests component interactions
Validates pipeline behavior
Tests automatic configuration
May be slower than unit tests
Parametrized Tests#
Parametrized tests run the same test logic with multiple input values, reducing code duplication.
Example: Invalid Parameter Tests
import pytest
@pytest.mark.parametrize("delay", [0, -1])
def test_reverb_invalid_delay(delay):
"""Test that Reverb rejects invalid delay values."""
with pytest.raises(AssertionError):
Reverb(delay=delay, decay=0.5, mix=0.5)
Benefits:
Reduces code duplication
Tests multiple edge cases
Clear test output showing which parameter failed
Easy to add new test cases
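When the automatically generated ids are not descriptive enough, pytest.param can attach explicit ids to each parameter set. A sketch reusing the Reverb validation above:
@pytest.mark.parametrize(
    "delay",
    [pytest.param(0, id="zero"), pytest.param(-1, id="negative")],
)
def test_reverb_invalid_delay_named(delay):
    """Same validation as above, with readable ids in the test output."""
    with pytest.raises(AssertionError):
        Reverb(delay=delay, decay=0.5, mix=0.5)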
Strategy Pattern Tests#
The test suite extensively validates the strategy pattern used in effects. These tests verify that custom strategies can be injected and built-in strategies behave correctly.
Example: Custom Strategy Injection
from torchfx.effect import Normalize, NormalizationStrategy
class DummyStrategy(NormalizationStrategy):
"""Custom strategy that sets all values to peak."""
def __call__(self, waveform, peak):
return waveform * 0 + peak
def test_normalize_custom_strategy():
"""Test that custom normalization strategies work."""
waveform = torch.tensor([0.2, -0.5, 0.4])
norm = Normalize(peak=2.0, strategy=DummyStrategy())
out = norm(waveform)
torch.testing.assert_close(out, torch.full_like(waveform, 2.0))
Key characteristics:
Tests extensibility points
Validates strategy interface
Ensures custom implementations work
Documents strategy usage
Mocking with Monkeypatch#
Some tests use pytest’s monkeypatch fixture to replace external dependencies.
Example: Mocking torchaudio.functional.gain
def test_gain_db(monkeypatch):
"""Test that Gain calls torchaudio.functional.gain with correct params."""
waveform = torch.tensor([0.1, -0.2, 0.3])
called = {}
def fake_gain(waveform, gain):
called["args"] = (waveform, gain)
return waveform + gain
monkeypatch.setattr("torchaudio.functional.gain", fake_gain)
gain = Gain(gain=6.0, gain_type="db")
out = gain(waveform)
assert torch.allclose(out, waveform + 6.0)
assert called["args"][1] == 6.0 # Verify gain parameter
Benefits:
Tests internal logic without side effects
Validates parameter passing
Avoids external dependencies
Enables testing of hard-to-test code
Assertion Patterns#
torch.testing.assert_close#
The primary assertion method for tensor comparisons:
torch.testing.assert_close(actual, expected)
Features:
Element-wise closeness checking
Appropriate tolerances for floating-point comparisons
Clear error messages showing differences
Example:
def test_peak_normalization_strategy():
"""Test peak normalization strategy."""
waveform = torch.tensor([0.2, -0.5, 0.4])
strat = PeakNormalizationStrategy()
out = strat(waveform, 2.0)
torch.testing.assert_close(out, waveform / 0.5 * 2.0)
pytest.approx#
For scalar comparisons with tolerance:
assert value.item() == pytest.approx(expected, abs=1e-5)
Example:
def test_delay_basic():
"""Test basic delay functionality."""
waveform = torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0])
delay = Delay(delay_samples=2, feedback=0.0, mix=1.0, taps=1)
out = delay(waveform)
assert out[2].item() == pytest.approx(1.0, abs=1e-5)
torch.allclose#
Returns a single boolean for tolerance-based tensor comparison, useful inside plain assert statements:
assert torch.allclose(actual, expected, atol=1e-5)
Example:
def test_delay_mix_zero():
"""Test that mix=0 produces only dry signal."""
waveform = torch.randn(10)
delay = Delay(delay_samples=3, feedback=0.5, mix=0.0)
out = delay(waveform)
# Output beyond original length should be zeros
assert torch.allclose(out[10:], torch.zeros(out.size(0) - 10), atol=1e-5)
pytest.raises#
For validating exception handling:
with pytest.raises(ExceptionType):
# Code that should raise exception
Example:
def test_gain_invalid_gain():
    """Test that a negative amplitude gain raises ValueError."""
    with pytest.raises(ValueError):
        Gain(gain=-1.0, gain_type="amplitude")
Advanced usage with match:
def test_delay_lazy_fs_inference_error():
"""Test that missing fs raises helpful error message."""
delay = Delay(bpm=120, delay_time="1/8", feedback=0.3, mix=0.2)
waveform = torch.randn(2, 44100)
with pytest.raises(AssertionError, match="Sample rate \\(fs\\) is required"):
delay(waveform)
Writing New Tests#
Test Structure Template#
Follow this template for new tests:
def test_<component>_<scenario>():
"""Brief description of what this test validates.
Include any important details about edge cases or expected behavior.
"""
# Arrange: Set up test data
waveform = torch.randn(2, 44100)
component = ComponentClass(param1=value1, param2=value2)
# Act: Execute the operation
result = component(waveform)
# Assert: Verify results
torch.testing.assert_close(result, expected_result)
Common Test Patterns#
Testing Effects#
def test_my_effect_basic():
"""Test basic functionality of MyEffect."""
# Create test waveform
waveform = torch.tensor([1.0, 0.5, -0.5, -1.0])
# Create effect
effect = MyEffect(param=value)
# Apply effect
output = effect(waveform)
# Verify output
assert output.shape == waveform.shape
torch.testing.assert_close(output, expected)
Testing Filters#
def test_my_filter_frequency_response():
"""Test filter frequency response characteristics."""
# Create filter
filt = MyFilter(cutoff=1000, fs=44100)
# Create test signal (impulse or sine wave)
impulse = torch.zeros(1000)
impulse[0] = 1.0
# Apply filter
output = filt(impulse)
# Verify characteristics
# (e.g., DC component, high-frequency attenuation)
assert output[0] > 0 # Filter passes signal
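For a more quantitative check, the impulse response can be inspected in the frequency domain. A sketch, assuming the hypothetical MyFilter is a lowpass with a 1 kHz cutoff:
def test_my_filter_attenuates_high_frequencies():
    """Sketch: spectral energy below the cutoff should exceed energy above it."""
    fs = 44100
    filt = MyFilter(cutoff=1000, fs=fs)
    # The impulse response characterizes the filter's frequency response
    impulse = torch.zeros(4096)
    impulse[0] = 1.0
    response = filt(impulse)
    # Magnitude spectrum and the frequency of each bin
    spectrum = torch.fft.rfft(response).abs()
    freqs = torch.fft.rfftfreq(response.numel(), d=1.0 / fs)
    # A lowpass should pass energy well below 1 kHz and attenuate well above it
    assert spectrum[freqs < 500].mean() > spectrum[freqs > 4000].mean()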
Testing Multi-Channel Processing#
def test_effect_multichannel():
"""Test that effect handles multi-channel audio correctly."""
# Create multi-channel waveform (2 channels)
waveform = torch.tensor([
[1.0, 2.0, 3.0, 4.0], # Left channel
[0.5, 1.5, 2.5, 3.5] # Right channel
])
effect = MyEffect(param=value)
output = effect(waveform)
# Verify shape preserved
assert output.shape == waveform.shape
# Verify independent channel processing (if applicable)
# Or verify cross-channel effects (if applicable)
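For stateless, channel-independent effects, one concrete way to verify this is to compare the multi-channel output against per-channel processing (a sketch; MyEffect and its parameters are hypothetical):
def test_effect_channels_processed_independently():
    """Sketch: multi-channel output equals stacked per-channel outputs."""
    effect = MyEffect(param=value)
    stereo = torch.randn(2, 1000)
    full = effect(stereo)
    # Process each channel on its own and stack the results
    per_channel = torch.stack([effect(stereo[0]), effect(stereo[1])])
    torch.testing.assert_close(full, per_channel)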
Testing Error Handling#
def test_effect_invalid_parameter():
"""Test that invalid parameters raise appropriate errors."""
# Test invalid value
with pytest.raises(ValueError, match="Parameter must be positive"):
MyEffect(param=-1.0)
# Test invalid type
with pytest.raises(TypeError):
MyEffect(param="invalid")
# Test assertion errors
with pytest.raises(AssertionError):
MyEffect(param=0.0) # Zero not allowed
Test Fixtures#
Use pytest fixtures for reusable test data:
import pytest
@pytest.fixture
def stereo_waveform():
"""Fixture providing a standard stereo waveform."""
return torch.randn(2, 44100)
@pytest.fixture
def mono_waveform():
"""Fixture providing a standard mono waveform."""
return torch.randn(1, 44100)
def test_with_fixture(stereo_waveform):
"""Test using stereo waveform fixture."""
effect = MyEffect()
output = effect(stereo_waveform)
assert output.shape == stereo_waveform.shape
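Fixtures can also be parametrized, so every dependent test runs once per configuration. A sketch covering mono and stereo inputs with a single fixture (MyEffect is hypothetical):
@pytest.fixture(params=[1, 2], ids=["mono", "stereo"])
def waveform(request):
    """Fixture yielding a mono or a stereo waveform per parametrized run."""
    return torch.randn(request.param, 44100)
def test_effect_preserves_shape(waveform):
    """Runs twice: once with the mono fixture, once with the stereo one."""
    effect = MyEffect()
    assert effect(waveform).shape == waveform.shape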
Coverage Requirements#
Coverage Targets#
The TorchFX project aims for comprehensive test coverage:
Line coverage: >90% for all modules
Branch coverage: >85% for control flow
Strategy coverage: 100% for all built-in strategies
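These targets can be enforced automatically so that CI fails when coverage regresses. A sketch of the relevant pyproject.toml entries (the threshold mirrors the line-coverage target above; the project's actual configuration may differ):
[tool.coverage.report]
fail_under = 90
show_missing = true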
Coverage Analysis#
Generate coverage reports to identify untested code:
# Terminal report with missing lines
pytest tests/ --cov=src/torchfx --cov-report=term-missing
# HTML report for detailed analysis
pytest tests/ --cov=src/torchfx --cov-report=html
Interpreting coverage reports:
Name Stmts Miss Branch BrPart Cover Missing
--------------------------------------------------------------------------
src/torchfx/__init__.py 8 0 0 0 100%
src/torchfx/effect.py 250 15 60 8 92% 125-130, 245
src/torchfx/filter/__base.py 80 5 20 2 91% 45-47
src/torchfx/wave.py 120 8 30 3 90% 88-92
--------------------------------------------------------------------------
TOTAL 458 28 110 13 91%
Stmts: Total number of statements
Miss: Number of statements not executed
Branch: Total number of branches
BrPart: Number of partially covered branches
Cover: Overall coverage percentage
Missing: Line numbers not covered
Common Test Patterns#
Testing BPM Synchronization#
def test_delay_bpm_synced():
"""Test BPM-synchronized delay calculation."""
waveform = torch.randn(2, 44100)
delay = Delay(bpm=120, delay_time="1/8", fs=44100, feedback=0.3, mix=0.2)
# 120 BPM = 0.5 seconds per beat
# 1/8 note = 0.25 seconds = 11025 samples at 44.1kHz
assert delay.delay_samples == 11025
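The same calculation can be exercised across note values with a parametrized test. A sketch, assuming Delay accepts these note strings as in the examples above (values chosen so the sample counts are exact integers at 44.1 kHz):
@pytest.mark.parametrize(
    ("delay_time", "expected_samples"),
    [("1/4", 22050), ("1/8", 11025)],
)
def test_delay_bpm_synced_note_values(delay_time, expected_samples):
    # At 120 BPM a quarter note lasts 0.5 s; samples = seconds * fs
    delay = Delay(bpm=120, delay_time=delay_time, fs=44100, feedback=0.3, mix=0.2)
    assert delay.delay_samples == expected_samples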
Testing Strategy Pattern Extensibility#
def test_custom_strategy():
"""Test that custom strategies can be implemented."""
class CustomStrategy(BaseStrategy):
def apply(self, waveform, **kwargs):
return waveform * 2
effect = MyEffect(strategy=CustomStrategy())
waveform = torch.ones(5)
output = effect(waveform)
assert torch.allclose(output, torch.ones(5) * 2)
Testing Lazy Initialization#
def test_lazy_fs_inference():
"""Test lazy sample rate inference with Wave."""
effect = MyEffect(param=value) # No fs provided
assert effect.fs is None
wave = Wave(torch.randn(2, 44100), fs=44100)
result = wave | effect
assert effect.fs == 44100 # Automatically configured
CI Testing#
Tests run automatically in continuous integration via GitHub Actions:
# .github/workflows/ci.yml
- name: Run tests
run: |
pytest tests/ --cov=src/torchfx --cov-report=xml
- name: Upload coverage
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
CI test matrix:
Python versions: 3.10, 3.11, 3.12, 3.13
Operating systems: Ubuntu, macOS, Windows
PyTorch versions: Latest stable
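A sketch of how that matrix might be declared in the workflow (illustrative; the actual .github/workflows/ci.yml is authoritative):
strategy:
  matrix:
    os: [ubuntu-latest, macos-latest, windows-latest]
    python-version: ["3.10", "3.11", "3.12", "3.13"]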
Best Practices#
Write Descriptive Test Names#
# ✅ GOOD: Descriptive, explains what is tested
def test_gain_amplitude_mode_doubles_waveform():
pass
# ❌ BAD: Vague, unclear what is tested
def test_gain():
pass
Use Meaningful Assertions#
# ✅ GOOD: Clear assertion with context
def test_normalize_peak():
waveform = torch.tensor([0.5, -1.0, 0.75])
norm = Normalize(peak=0.8)
out = norm(waveform)
# Max absolute value should equal target peak
assert torch.max(torch.abs(out)).item() == pytest.approx(0.8)
# ❌ BAD: Unclear what is being tested
def test_normalize():
out = Normalize(peak=0.8)(torch.randn(10))
assert out is not None
Test Edge Cases#
def test_delay_zero_feedback():
"""Test delay with zero feedback produces single echo."""
delay = Delay(delay_samples=100, feedback=0.0, taps=5)
# Should only produce first tap, rest should be silent
pass
def test_delay_max_feedback():
"""Test delay with maximum allowed feedback."""
delay = Delay(delay_samples=100, feedback=0.95) # Max allowed
# Should produce long decay
pass
def test_normalize_already_normalized():
"""Test normalizing already-normalized audio."""
waveform = torch.tensor([0.5, -0.5]) # Already peak=0.5
norm = Normalize(peak=0.5)
out = norm(waveform)
torch.testing.assert_close(out, waveform) # Should be unchanged
Keep Tests Independent#
# ✅ GOOD: Each test is self-contained
def test_gain_1():
waveform = torch.ones(5)
gain = Gain(gain=2.0)
assert torch.allclose(gain(waveform), torch.ones(5) * 2.0)
def test_gain_2():
waveform = torch.zeros(5)
gain = Gain(gain=3.0)
assert torch.allclose(gain(waveform), torch.zeros(5))
# ❌ BAD: Tests depend on order
global_waveform = None
def test_setup():
global global_waveform
global_waveform = torch.ones(5)
def test_gain(): # Depends on test_setup
gain = Gain(gain=2.0)
assert torch.allclose(gain(global_waveform), torch.ones(5) * 2.0)