diff --git a/docs/supported_metrics.md b/docs/supported_metrics.md index a15425b..b41d6af 100644 --- a/docs/supported_metrics.md +++ b/docs/supported_metrics.md @@ -50,7 +50,7 @@ We include x mark if the metric is auto-installed in versa. | 43 | x | Qwen2 Recording Environment - Background | qwen2_speech_background_environment_metric | qwen2_speech_background_environment_metric | [Qwen2 Audio](https://github.com/QwenLM/Qwen2-Audio) | [paper](https://arxiv.org/abs/2407.10759) | | 44 | x | Qwen2 Recording Environment - Quality | qwen2_recording_quality_metric | qwen2_recording_quality_metric | [Qwen2 Audio](https://github.com/QwenLM/Qwen2-Audio) | [paper](https://arxiv.org/abs/2407.10759) | | 45 | x | Qwen2 Recording Environment - Channel Type | qwen2_channel_type_metric | qwen2_channel_type_metric | [Qwen2 Audio](https://github.com/QwenLM/Qwen2-Audio) | [paper](https://arxiv.org/abs/2407.10759) | -| 46 | x | Dimensional Emotion | w2v2_dimensional_emotion | w2v2_dimensional_emotion | [w2v2-how-to](https://github.com/audeering/w2v2-how-to) | [paper](https://arxiv.org/pdf/2203.07378) | +| 46 | x | Dimensional Emotion | emo_vad | arousal_emo_vad, valence_emo_vad, dominance_emo_vad | [w2v2-how-to](https://github.com/audeering/w2v2-how-to) | [paper](https://arxiv.org/pdf/2203.07378) | | 47 | x | Uni-VERSA (Versatile Speech Assessment with a Unified Framework) | universa | universa_{sub_metrics} | [Uni-VERSA](https://huggingface.co/collections/espnet/universa-6834e7c0a28225bffb6e2526) | [paper](https://arxiv.org/abs/2505.20741) | | 48 | x | DNSMOS Pro: A Reduced-Size DNN for Probabilistic MOS of Speech | pseudo_mos | dnsmos_pro_bvcc | [DNSMOSPro](https://github.com/fcumlin/DNSMOSPro/tree/main) | [paper](https://www.isca-archive.org/interspeech_2024/cumlin24_interspeech.html) | | 49 | x | DNSMOS Pro: A Reduced-Size DNN for Probabilistic MOS of Speech | pseudo_mos | dnsmos_pro_nisqa | [DNSMOSPro](https://github.com/fcumlin/DNSMOSPro/tree/main) | [paper](https://www.isca-archive.org/interspeech_2024/cumlin24_interspeech.html) | diff --git a/egs/separate_metrics/cdpam_distance.yaml b/egs/separate_metrics/cdpam_distance.yaml new file mode 100644 index 0000000..a73d5fd --- /dev/null +++ b/egs/separate_metrics/cdpam_distance.yaml @@ -0,0 +1,5 @@ +# CDPAM distance metrics +# CDPAM distance between audio samples +# More info in https://github.com/facebookresearch/audiocraft +# -- cdpam_distance: the CDPAM distance between audio samples +- name: cdpam_distance \ No newline at end of file diff --git a/egs/separate_metrics/chroma_alignment.yaml b/egs/separate_metrics/chroma_alignment.yaml new file mode 100644 index 0000000..858b08b --- /dev/null +++ b/egs/separate_metrics/chroma_alignment.yaml @@ -0,0 +1,20 @@ +# Chroma Alignment related metrics +# Chroma-based distance estimation with dynamic programming alignment +# Uses librosa chroma features (STFT, CQT, CENS) with DTW alignment +# -- chroma_stft_cosine_dtw: STFT chroma features with cosine distance and DTW +# -- chroma_stft_euclidean_dtw: STFT chroma features with euclidean distance and DTW +# -- chroma_cqt_cosine_dtw: CQT chroma features with cosine distance and DTW +# -- chroma_cqt_euclidean_dtw: CQT chroma features with euclidean distance and DTW +# -- chroma_cens_cosine_dtw: CENS chroma features with cosine distance and DTW +# -- chroma_cens_euclidean_dtw: CENS chroma features with euclidean distance and DTW +# -- chroma_stft_cosine_dtw_raw: Raw DTW distance with higher scaling +# -- chroma_stft_cosine_dtw_log: Log-scaled DTW distance +- name: chroma_alignment + sample_rate: 22050 + feature_types: ["stft", "cqt", "cens"] + distance_metrics: ["cosine", "euclidean"] + scale_factor: 100.0 + normalize: True + normalize_by_path: True + return_alignment: False + chroma_kwargs: {} \ No newline at end of file diff --git a/egs/separate_metrics/dpam_distance.yaml b/egs/separate_metrics/dpam_distance.yaml new file mode 100644 index 0000000..ba195df --- /dev/null +++ b/egs/separate_metrics/dpam_distance.yaml @@ -0,0 +1,5 @@ +# DPAM distance metrics +# DPAM distance between audio samples +# More info in https://github.com/adrienchaton/PerceptualAudio_Pytorch +# -- dpam_distance: the DPAM distance between audio samples +- name: dpam_distance \ No newline at end of file diff --git a/egs/separate_metrics/emo_vad.yaml b/egs/separate_metrics/emo_vad.yaml new file mode 100644 index 0000000..579043d --- /dev/null +++ b/egs/separate_metrics/emo_vad.yaml @@ -0,0 +1,7 @@ +# EmoVad related metrics +# Dimensional emotion prediction (arousal, valence, dominance) using w2v2-how-to +# More info in https://github.com/audeering/w2v2-how-to +# -- arousal_emo_vad: the dimensional emotion prediction with w2v2 +# -- valence_emo_vad: the dimensional emotion prediction with w2v2 +# -- dominance_emo_vad: the dimensional emotion prediction with w2v2 +- name: emo_vad \ No newline at end of file diff --git a/egs/separate_metrics/lid.yaml b/egs/separate_metrics/lid.yaml index 750e07a..00c18da 100644 --- a/egs/separate_metrics/lid.yaml +++ b/egs/separate_metrics/lid.yaml @@ -1,10 +1,10 @@ - -# Word error rate with ESPnet-OWSM model +# Language Identification with ESPnet-OWSM model # More model_tag can be from the ESPnet huggingface https://huggingface.co/espnet . # The default model is `espnet/owsm_v3.1_ebf`. -# --lid: the nbest language tag +# --language: the nbest language tag - name: lid model_tag: default nbest: 5 + use_gpu: false diff --git a/egs/separate_metrics/nisqa.yaml b/egs/separate_metrics/nisqa.yaml index 67ee222..f2f90c6 100644 --- a/egs/separate_metrics/nisqa.yaml +++ b/egs/separate_metrics/nisqa.yaml @@ -3,8 +3,9 @@ # -- nisqa_noi_pred: NISQA noise prediction # -- nisqa_dis_pred: NISQA distortion prediction # -- nisqa_col_pred: NISQA color prediction -# --nisqa_loud_pred: NISQA loudness prediction +# -- nisqa_loud_pred: NISQA loudness prediction # NOTE(jiatong): pretrain model can be downloaded with `./tools/setup_nisqa.sh` - name: nisqa nisqa_model_path: ./tools/NISQA/weights/nisqa.tar + use_gpu: false diff --git a/egs/separate_metrics/nomad.yaml b/egs/separate_metrics/nomad.yaml index 49cc4cb..167fdf3 100644 --- a/egs/separate_metrics/nomad.yaml +++ b/egs/separate_metrics/nomad.yaml @@ -2,3 +2,4 @@ # -- nomad: nomad reference-based model - name: nomad model_cache: versa_cache/nomad_pt-models + use_gpu: false diff --git a/egs/separate_metrics/noresqa.yaml b/egs/separate_metrics/noresqa.yaml index 07db66f..e61da95 100644 --- a/egs/separate_metrics/noresqa.yaml +++ b/egs/separate_metrics/noresqa.yaml @@ -1,4 +1,8 @@ # noresqa related metrics -# -- noresqa: non-matching reference based speech quality assessment -- name: noresqa - metric_type: 1 #0: NORESQA-score, 1: NORESQA-MOS \ No newline at end of file +# -- noresqa_mos: NORESQA-MOS (metric_type=1) +# -- noresqa_score: NORESQA-score (metric_type=0) +- name: noresqa_mos + metric_type: 1 # 0: NORESQA-score, 1: NORESQA-MOS + model_tag: default + cache_dir: versa_cache/noresqa_model + use_gpu: false \ No newline at end of file diff --git a/egs/separate_metrics/pesq.yaml b/egs/separate_metrics/pesq.yaml new file mode 100644 index 0000000..a94a401 --- /dev/null +++ b/egs/separate_metrics/pesq.yaml @@ -0,0 +1,11 @@ +# PESQ: Perceptual Evaluation of Speech Quality +# https://www.itu.int/rec/T-REC-P.862 +# +# PESQ is a reference-based metric that measures speech quality +# by comparing a degraded signal to a reference signal. +# +# Supported sample rates: +# - 8kHz: narrowband (nb) mode +# - 16kHz: wideband (wb) mode +# - Other rates: automatically resampled to nearest supported rate +- name: pesq \ No newline at end of file diff --git a/egs/separate_metrics/w2v2_dimensional_emotion.yaml b/egs/separate_metrics/w2v2_dimensional_emotion.yaml deleted file mode 100644 index ec12464..0000000 --- a/egs/separate_metrics/w2v2_dimensional_emotion.yaml +++ /dev/null @@ -1,5 +0,0 @@ -# Dimensional emotion prediction calculated based on w2v2 -# More info in https://github.com/audeering/w2v2-how-to - -# --w2v2_dimensional_emotion: the dimensional emotion prediction with w2v2 -- name: w2v2_dimensional_emotion diff --git a/setup.py b/setup.py index 92a2c68..2a8bd69 100644 --- a/setup.py +++ b/setup.py @@ -1,62 +1,108 @@ from setuptools import setup, find_packages +import os + + +# Read README for long description +def read_readme(): + readme_path = os.path.join(os.path.dirname(__file__), "README.md") + if os.path.exists(readme_path): + with open(readme_path, "r", encoding="utf-8") as f: + return f.read() + return "A package for versatile evaluation of speech and audio" + setup( name="versa-speech-audio-toolkit", version="1.0.0", + author="Jiatong Shi", + author_email="ftshijt@gmail.com", + description="A package for versatile evaluation of speech and audio", + long_description=read_readme(), + long_description_content_type="text/markdown", + url="https://github.com/wavlab-speech/versa.git", packages=find_packages(), + python_requires=">=3.8", + classifiers=[ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Multimedia :: Sound/Audio :: Analysis", + ], + keywords=["speech", "audio", "metrics", "evaluation", "machine learning"], install_requires=[ + # Core ML and Deep Learning + "torch", + "torchaudio", + "transformers>=4.36.2", "accelerate", - "audioread", - "ci-sdr", - "Cython", - "Distance", - "editdistance", - "einops", - "espnet @ git+https://github.com/ftshijt/espnet.git@espnet_inference#egg=espnet", - "espnet-tts-frontend", - "fast-bss-eval", - "fastdtw", "huggingface-hub", - "hydra-core", - "idna", - "importlib-metadata", - "kaggle", - "kaldiio", - "lazy_loader", - "Levenshtein", - "librosa", - "mir-eval", - "omegaconf", - "onnxruntime", - # NOTE(jiatong): use the latest commit for python 3.13 - "openai-whisper @ git+https://github.com/openai/whisper.git", + "safetensors", + "tokenizers", + "einops", "opt-einsum", - "pesq", - "protobuf", + # Audio Processing + "librosa", + "soundfile", + "audioread", + "resampy", + "torchlibrosa", + "pyworld", "pysptk", + # Speech and Audio Evaluation Metrics + "pesq", "pystoi", - "python-dateutil", - "pyworld", - "pyyaml", + "mir-eval", + "fast-bss-eval", + "ci-sdr", + "speechmos", + # Text Processing and Distance Metrics + "Levenshtein", + "editdistance", + "Distance", "rapidfuzz", - "resampy", - "safetensors", - "scikit-learn", "sentencepiece", - "setuptools", - "soundfile", - "speechmos", + # Scientific Computing + "scikit-learn", "sympy", "threadpoolctl", - "tokenizers", - "torch", - "torch-complex", - "torchaudio", - "torchlibrosa", - "s3prl @ git+https://github.com/ftshijt/s3prl.git@numpy2#egg=s3prl", - "transformers>=4.36.2", + # Configuration and Utilities + "hydra-core", + "omegaconf", + "pyyaml", + "protobuf", + "python-dateutil", + "lazy_loader", + # Build and Compatibility + "Cython", + "setuptools", + "importlib-metadata", + "idna", + # Optional/External Services + "kaggle", + "kaldiio", + "fastdtw", + "onnxruntime", + # Git Dependencies - Speech/Audio Frameworks + "espnet @ git+https://github.com/ftshijt/espnet.git@espnet_inference#egg=espnet", + "espnet-tts-frontend", "espnet_model_zoo", + "s3prl", + # Git Dependencies - Audio Models + # NOTE: Using latest commit for Python 3.13 compatibility + "openai-whisper @ git+https://github.com/openai/whisper.git", + # Git Dependencies - Evaluation Metrics "discrete-speech-metrics @ git+https://github.com/ftshijt/DiscreteSpeechMetrics.git@v1.0.2", + # Additional Dependencies + "torch-complex", "cdpam", ], extras_require={ @@ -65,6 +111,18 @@ "pytest-cov>=2.10.0", "black>=22.3.0", "flake8>=4.0.0", + "isort>=5.0.0", + "mypy>=0.900", + ], + "docs": [ + "sphinx>=4.0.0", + "sphinx-rtd-theme>=1.0.0", + "myst-parser>=0.17.0", + ], + "jupyter": [ + "jupyter>=1.0.0", + "ipykernel>=6.0.0", + "matplotlib>=3.3.0", ], }, entry_points={ @@ -72,9 +130,6 @@ "versa-score=versa.bin.scorer:main", ], }, - author="Jiatong Shi", - author_email="ftshijt@gmail.com", - description="A package for versatile evaluation of speech and audio", - url="https://github.com/shinjiwlab/versa.git", - keywords="speech metrics", + include_package_data=True, + zip_safe=False, ) diff --git a/test/test_metrics/test_asvspoof.py b/test/test_metrics/test_asvspoof.py new file mode 100644 index 0000000..5c9551c --- /dev/null +++ b/test/test_metrics/test_asvspoof.py @@ -0,0 +1,175 @@ +import wave +from pathlib import Path + +import numpy as np +import pytest +import torch + +from versa.utterance_metrics.asvspoof_score import ASVSpoofMetric, is_aasist_available + + +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- +def generate_fixed_wav( + filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None +): + """ + Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. + """ + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. + if envelope_func is None: + envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) + else: + envelope = envelope_func(t) + audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. + amplitude = np.iinfo(np.int16).max + data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. + with wave.open(str(filename), "w") as wf: + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. + wf.setframerate(sample_rate) + wf.writeframes(data.tobytes()) + + +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) + return audio_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=16000): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + +@pytest.fixture(scope="session") +def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_wav) + + +# ------------------------------- +# Test Functions +# ------------------------------- +@pytest.mark.skipif(not is_aasist_available(), reason="AASIST not available") +@pytest.mark.parametrize( + "model_tag,use_gpu", + [ + ("default", False), + ], +) +def test_utterance_asvspoof(model_tag, use_gpu, fixed_audio): + """ + Test the ASVspoof metric using the fixed audio. + The test uses deterministic data so that the result is always reproducible. + """ + config = {"model_tag": model_tag, "use_gpu": use_gpu} + + metric = ASVSpoofMetric(config) + metadata = {"sample_rate": 16000} + result = metric.compute(fixed_audio, metadata=metadata) + + asvspoof_score = result["asvspoof_score"] + + # Check that the score is a valid probability (between 0 and 1) + assert ( + 0.0 <= asvspoof_score <= 1.0 + ), f"ASVspoof score {asvspoof_score} is not between 0 and 1" + + # Check that the result contains the expected key + assert "asvspoof_score" in result, "Result should contain 'asvspoof_score' key" + + +@pytest.mark.skipif(not is_aasist_available(), reason="AASIST not available") +def test_asvspoof_metric_metadata(): + """Test that the ASVspoof metric has correct metadata.""" + config = {"use_gpu": False} + metric = ASVSpoofMetric(config) + metadata = metric.get_metadata() + + assert metadata.name == "asvspoof" + assert metadata.category.value == "independent" + assert metadata.metric_type.value == "float" + assert metadata.requires_reference is False + assert metadata.requires_text is False + assert metadata.gpu_compatible is True + assert "torch" in metadata.dependencies + assert "librosa" in metadata.dependencies + assert "numpy" in metadata.dependencies + + +@pytest.mark.skipif(not is_aasist_available(), reason="AASIST not available") +def test_asvspoof_metric_resampling(): + """Test that the ASVspoof metric handles different sample rates correctly.""" + config = {"use_gpu": False} + metric = ASVSpoofMetric(config) + + # Test with 44.1kHz audio (should be resampled to 16kHz) + audio_44k = np.random.random(44100) + metadata_44k = {"sample_rate": 44100} + result_44k = metric.compute(audio_44k, metadata=metadata_44k) + + # Test with 16kHz audio (no resampling needed) + audio_16k = np.random.random(16000) + metadata_16k = {"sample_rate": 16000} + result_16k = metric.compute(audio_16k, metadata=metadata_16k) + + # Both should return valid scores + assert 0.0 <= result_44k["asvspoof_score"] <= 1.0 + assert 0.0 <= result_16k["asvspoof_score"] <= 1.0 + + +@pytest.mark.skipif(not is_aasist_available(), reason="AASIST not available") +def test_asvspoof_metric_invalid_input(): + """Test that the ASVspoof metric handles invalid inputs correctly.""" + config = {"use_gpu": False} + metric = ASVSpoofMetric(config) + + # Test with None input + with pytest.raises(ValueError, match="Predicted signal must be provided"): + metric.compute(None, metadata={"sample_rate": 16000}) + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist(fixed_audio_wav): + """ + Verify that the fixed WAV files were created. + """ + assert Path(fixed_audio_wav).exists() diff --git a/test/test_metrics/test_audiobox_aesthetics.py b/test/test_metrics/test_audiobox_aesthetics.py new file mode 100644 index 0000000..ed9e8e0 --- /dev/null +++ b/test/test_metrics/test_audiobox_aesthetics.py @@ -0,0 +1,237 @@ +import wave +from pathlib import Path + +import numpy as np +import pytest + +from versa.utterance_metrics.audiobox_aesthetics_score import ( + AudioBoxAestheticsMetric, + is_audiobox_aesthetics_available, +) + + +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- +def generate_fixed_wav( + filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None +): + """ + Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. + """ + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. + if envelope_func is None: + envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) + else: + envelope = envelope_func(t) + audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. + amplitude = np.iinfo(np.int16).max + data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. + with wave.open(str(filename), "w") as wf: + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. + wf.setframerate(sample_rate) + wf.writeframes(data.tobytes()) + + +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) + return audio_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=16000): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + +@pytest.fixture(scope="session") +def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_wav) + + +# ------------------------------- +# Test Functions +# ------------------------------- +@pytest.mark.skipif( + not is_audiobox_aesthetics_available(), reason="AudioBox Aesthetics not available" +) +@pytest.mark.parametrize( + "batch_size,precision,use_gpu", + [ + (1, "bf16", False), + (2, "fp32", False), + ], +) +def test_utterance_audiobox_aesthetics(batch_size, precision, use_gpu, fixed_audio): + """ + Test the AudioBox Aesthetics metric using the fixed audio. + The test uses deterministic data so that the result is always reproducible. + """ + config = {"batch_size": batch_size, "precision": precision, "use_gpu": use_gpu} + + metric = AudioBoxAestheticsMetric(config) + metadata = {"sample_rate": 16000} + result = metric.compute(fixed_audio, metadata=metadata) + + # Check that the result contains the expected keys + expected_keys = [ + "audiobox_aesthetics_CE", + "audiobox_aesthetics_CU", + "audiobox_aesthetics_PC", + "audiobox_aesthetics_PQ", + ] + + for key in expected_keys: + assert key in result, f"Result should contain '{key}' key" + assert isinstance(result[key], (int, float)), f"Score {key} should be numeric" + + # Check that all scores are reasonable (not negative for these metrics) + for key in expected_keys: + assert result[key] >= 0, f"Score {key} should be non-negative" + + +@pytest.mark.skipif( + not is_audiobox_aesthetics_available(), reason="AudioBox Aesthetics not available" +) +def test_audiobox_aesthetics_metric_metadata(): + """Test that the AudioBox Aesthetics metric has correct metadata.""" + config = {"use_gpu": False} + metric = AudioBoxAestheticsMetric(config) + metadata = metric.get_metadata() + + assert metadata.name == "audiobox_aesthetics" + assert metadata.category.value == "independent" + assert metadata.metric_type.value == "float" + assert metadata.requires_reference is False + assert metadata.requires_text is False + assert metadata.gpu_compatible is True + assert "audiobox_aesthetics" in metadata.dependencies + assert "numpy" in metadata.dependencies + + +@pytest.mark.skipif( + not is_audiobox_aesthetics_available(), reason="AudioBox Aesthetics not available" +) +def test_audiobox_aesthetics_metric_different_sample_rates(): + """Test that the AudioBox Aesthetics metric handles different sample rates correctly.""" + config = {"use_gpu": False} + metric = AudioBoxAestheticsMetric(config) + + # Test with 44.1kHz audio + audio_44k = np.random.random(44100) + metadata_44k = {"sample_rate": 44100} + result_44k = metric.compute(audio_44k, metadata=metadata_44k) + + # Test with 16kHz audio + audio_16k = np.random.random(16000) + metadata_16k = {"sample_rate": 16000} + result_16k = metric.compute(audio_16k, metadata=metadata_16k) + + # Both should return valid scores with expected keys + expected_keys = [ + "audiobox_aesthetics_CE", + "audiobox_aesthetics_CU", + "audiobox_aesthetics_PC", + "audiobox_aesthetics_PQ", + ] + + for key in expected_keys: + assert key in result_44k, f"44kHz result should contain '{key}' key" + assert key in result_16k, f"16kHz result should contain '{key}' key" + + +@pytest.mark.skipif( + not is_audiobox_aesthetics_available(), reason="AudioBox Aesthetics not available" +) +def test_audiobox_aesthetics_metric_invalid_input(): + """Test that the AudioBox Aesthetics metric handles invalid inputs correctly.""" + config = {"use_gpu": False} + metric = AudioBoxAestheticsMetric(config) + + # Test with None input + with pytest.raises(ValueError, match="Predicted signal must be provided"): + metric.compute(None, metadata={"sample_rate": 16000}) + + +@pytest.mark.skipif( + not is_audiobox_aesthetics_available(), reason="AudioBox Aesthetics not available" +) +def test_audiobox_aesthetics_metric_config_options(): + """Test that the AudioBox Aesthetics metric handles different configuration options.""" + # Test with different batch sizes + config_small_batch = {"batch_size": 1, "use_gpu": False} + metric_small = AudioBoxAestheticsMetric(config_small_batch) + + config_large_batch = {"batch_size": 4, "use_gpu": False} + metric_large = AudioBoxAestheticsMetric(config_large_batch) + + # Test with different precision + config_fp32 = {"precision": "fp32", "use_gpu": False} + metric_fp32 = AudioBoxAestheticsMetric(config_fp32) + + # All should work without errors + audio = np.random.random(16000) + metadata = {"sample_rate": 16000} + + result_small = metric_small.compute(audio, metadata=metadata) + result_large = metric_large.compute(audio, metadata=metadata) + result_fp32 = metric_fp32.compute(audio, metadata=metadata) + + # All should return the same structure + expected_keys = [ + "audiobox_aesthetics_CE", + "audiobox_aesthetics_CU", + "audiobox_aesthetics_PC", + "audiobox_aesthetics_PQ", + ] + + for key in expected_keys: + assert key in result_small + assert key in result_large + assert key in result_fp32 + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist(fixed_audio_wav): + """ + Verify that the fixed WAV files were created. + """ + assert Path(fixed_audio_wav).exists() diff --git a/test/test_metrics/test_cdpam.py b/test/test_metrics/test_cdpam.py deleted file mode 100755 index 8828dda..0000000 --- a/test/test_metrics/test_cdpam.py +++ /dev/null @@ -1,83 +0,0 @@ -import wave -from pathlib import Path - -import numpy as np -import pytest - -from versa.utterance_metrics.cdpam_distance import cdpam_metric, cdpam_model_setup - -# Assume the fixed WAV file fixtures and helper function are defined as in the ASR matching test. -# For example: - - -def generate_fixed_wav( - filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None -): - """ - Generate a deterministic WAV file with a modulated sine wave. - """ - t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) - if envelope_func is None: - envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) - else: - envelope = envelope_func(t) - audio = envelope * np.sin(2 * np.pi * base_freq * t) - amplitude = np.iinfo(np.int16).max - data = (audio * amplitude).astype(np.int16) - with wave.open(str(filename), "w") as wf: - wf.setnchannels(1) - wf.setsampwidth(2) - wf.setframerate(sample_rate) - wf.writeframes(data.tobytes()) - - -def load_wav_as_array(wav_path, sample_rate=16000): - """ - Load a WAV file and convert it to a NumPy array scaled to [-1, 1]. - """ - with wave.open(str(wav_path), "rb") as wf: - frames = wf.getnframes() - audio_data = wf.readframes(frames) - audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) - return audio_array / np.iinfo(np.int16).max - - -@pytest.fixture(scope="session") -def fixed_audio_wav(tmp_path_factory): - tmp_dir = tmp_path_factory.mktemp("audio_data") - audio_file = tmp_dir / "fixed_audio.wav" - generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) - return audio_file - - -@pytest.fixture(scope="session") -def fixed_ground_truth_wav(tmp_path_factory): - tmp_dir = tmp_path_factory.mktemp("audio_data") - gt_file = tmp_dir / "fixed_ground_truth.wav" - # Use a different base frequency for ground truth (e.g. 300 Hz) to simulate a mismatch. - generate_fixed_wav(gt_file, duration=1.0, sample_rate=16000, base_freq=300) - return gt_file - - -@pytest.fixture(scope="session") -def fixed_audio(fixed_audio_wav): - return load_wav_as_array(fixed_audio_wav) - - -@pytest.fixture(scope="session") -def fixed_ground_truth(fixed_ground_truth_wav): - return load_wav_as_array(fixed_ground_truth_wav) - - -# ------------------------------- -# CDPAM Metric Definition and Tests -# ------------------------------- -def test_cdpam_metric_identical(fixed_audio): - """ - When comparing an audio signal with itself, the cdpam distance should be 0.0. - """ - model = cdpam_model_setup() - scores = cdpam_metric(model, fixed_audio, fixed_audio, 16000) - assert ( - scores["cdpam_distance"] == 0.0 - ), f"Expected cdpam distance == 0.0 for identical signals, got {scores['cdpam_distance']}" diff --git a/test/test_metrics/test_cdpam_distance.py b/test/test_metrics/test_cdpam_distance.py new file mode 100644 index 0000000..20ae14f --- /dev/null +++ b/test/test_metrics/test_cdpam_distance.py @@ -0,0 +1,271 @@ +import wave +from pathlib import Path + +import numpy as np +import pytest + +from versa.utterance_metrics.cdpam_distance import ( + CdpamDistanceMetric, + is_cdpam_available, +) + + +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- +def generate_fixed_wav( + filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None +): + """ + Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. + """ + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. + if envelope_func is None: + envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) + else: + envelope = envelope_func(t) + audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. + amplitude = np.iinfo(np.int16).max + data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. + with wave.open(str(filename), "w") as wf: + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. + wf.setframerate(sample_rate) + wf.writeframes(data.tobytes()) + + +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) + return audio_file + + +@pytest.fixture(scope="session") +def fixed_ground_truth_wav(tmp_path_factory): + """ + Create a ground truth WAV file to be used as reference audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + gt_file = tmp_dir / "fixed_ground_truth.wav" + # Use a different base frequency for ground truth (e.g. 300 Hz) to simulate a mismatch. + generate_fixed_wav(gt_file, duration=1.0, sample_rate=16000, base_freq=300) + return gt_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=16000): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + +@pytest.fixture(scope="session") +def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_wav) + + +@pytest.fixture(scope="session") +def fixed_ground_truth(fixed_ground_truth_wav): + """ + Load the ground truth audio file as a NumPy array. + """ + return load_wav_as_array(fixed_ground_truth_wav) + + +# ------------------------------- +# Test Functions +# ------------------------------- +@pytest.mark.skipif(not is_cdpam_available(), reason="CDPAM not available") +@pytest.mark.parametrize( + "use_gpu", + [ + False, + ], +) +def test_utterance_cdpam_distance(use_gpu, fixed_audio, fixed_ground_truth): + """ + Test the CDPAM distance metric using the fixed audio files. + The test uses deterministic data so that the result is always reproducible. + """ + config = {"use_gpu": use_gpu} + + metric = CdpamDistanceMetric(config) + metadata = {"sample_rate": 16000} + result = metric.compute(fixed_audio, fixed_ground_truth, metadata=metadata) + + # Check that the result contains the expected key + assert "cdpam_distance" in result, "Result should contain 'cdpam_distance' key" + + # Check that the result is a float + cdpam_dist = result["cdpam_distance"] + assert isinstance(cdpam_dist, float), "cdpam_distance should be a float" + + # Check that the distance score is reasonable (should be non-negative) + assert cdpam_dist >= 0.0, f"CDPAM distance should be non-negative, got {cdpam_dist}" + + +@pytest.mark.skipif(not is_cdpam_available(), reason="CDPAM not available") +def test_cdpam_distance_metric_metadata(): + """Test that the CDPAM distance metric has correct metadata.""" + config = {"use_gpu": False} + metric = CdpamDistanceMetric(config) + metadata = metric.get_metadata() + + assert metadata.name == "cdpam_distance" + assert metadata.category.value == "dependent" + assert metadata.metric_type.value == "float" + assert metadata.requires_reference is True + assert metadata.requires_text is False + assert metadata.gpu_compatible is True + assert "cdpam" in metadata.dependencies + assert "torch" in metadata.dependencies + assert "librosa" in metadata.dependencies + assert "numpy" in metadata.dependencies + + +@pytest.mark.skipif(not is_cdpam_available(), reason="CDPAM not available") +def test_cdpam_distance_metric_different_sample_rates(): + """Test that the CDPAM distance metric handles different sample rates correctly.""" + config = {"use_gpu": False} + metric = CdpamDistanceMetric(config) + + # Test with 44.1kHz audio (should be resampled to 22.05kHz) + audio_44k_1 = np.random.random(44100) + audio_44k_2 = np.random.random(44100) + metadata_44k = {"sample_rate": 44100} + result_44k = metric.compute(audio_44k_1, audio_44k_2, metadata=metadata_44k) + + # Test with 22.05kHz audio (no resampling needed) + audio_22k_1 = np.random.random(22050) + audio_22k_2 = np.random.random(22050) + metadata_22k = {"sample_rate": 22050} + result_22k = metric.compute(audio_22k_1, audio_22k_2, metadata=metadata_22k) + + # Both should return valid scores with expected keys + assert ( + "cdpam_distance" in result_44k + ), "44kHz result should contain 'cdpam_distance' key" + assert ( + "cdpam_distance" in result_22k + ), "22kHz result should contain 'cdpam_distance' key" + + # Both should return float values + assert isinstance(result_44k["cdpam_distance"], float) + assert isinstance(result_22k["cdpam_distance"], float) + + +@pytest.mark.skipif(not is_cdpam_available(), reason="CDPAM not available") +def test_cdpam_distance_metric_invalid_input(): + """Test that the CDPAM distance metric handles invalid inputs correctly.""" + config = {"use_gpu": False} + metric = CdpamDistanceMetric(config) + + # Test with None predictions + with pytest.raises(ValueError, match="Predicted signal must be provided"): + metric.compute(None, np.random.random(22050), metadata={"sample_rate": 22050}) + + # Test with None references + with pytest.raises(ValueError, match="Reference signal must be provided"): + metric.compute(np.random.random(22050), None, metadata={"sample_rate": 22050}) + + +@pytest.mark.skipif(not is_cdpam_available(), reason="CDPAM not available") +def test_cdpam_distance_metric_config_options(): + """Test that the CDPAM distance metric handles different configuration options.""" + # Test with GPU disabled + config_cpu = {"use_gpu": False} + metric_cpu = CdpamDistanceMetric(config_cpu) + + # All should work without errors + audio1 = np.random.random(22050) + audio2 = np.random.random(22050) + metadata = {"sample_rate": 22050} + + result_cpu = metric_cpu.compute(audio1, audio2, metadata=metadata) + + # Should return the same structure + assert "cdpam_distance" in result_cpu + assert isinstance(result_cpu["cdpam_distance"], float) + + +@pytest.mark.skipif(not is_cdpam_available(), reason="CDPAM not available") +def test_cdpam_distance_metric_identical_signals(): + """Test that the CDPAM distance metric gives zero distance for identical signals.""" + config = {"use_gpu": False} + metric = CdpamDistanceMetric(config) + metadata = {"sample_rate": 22050} + + # Test with identical signals + audio = np.random.random(22050) + result = metric.compute(audio, audio, metadata=metadata) + + # Results should be 0.0 for identical signals + assert ( + result["cdpam_distance"] == 0.0 + ), "Identical signals should have zero distance" + + +@pytest.mark.skipif(not is_cdpam_available(), reason="CDPAM not available") +def test_cdpam_distance_metric_consistent_results(): + """Test that the CDPAM distance metric gives consistent results for the same inputs.""" + config = {"use_gpu": False} + metric = CdpamDistanceMetric(config) + metadata = {"sample_rate": 22050} + + # Test with fixed signals + audio1 = np.random.random(22050) + audio2 = np.random.random(22050) + result1 = metric.compute(audio1, audio2, metadata=metadata) + result2 = metric.compute(audio1, audio2, metadata=metadata) + + # Results should be identical for the same inputs + np.testing.assert_almost_equal( + result1["cdpam_distance"], + result2["cdpam_distance"], + decimal=6, + err_msg="Results should be identical for the same inputs", + ) + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist(fixed_audio_wav, fixed_ground_truth_wav): + """ + Verify that the fixed WAV files were created. + """ + assert Path(fixed_audio_wav).exists() + assert Path(fixed_ground_truth_wav).exists() diff --git a/test/test_metrics/test_chroma_alignment.py b/test/test_metrics/test_chroma_alignment.py new file mode 100644 index 0000000..9a2cff6 --- /dev/null +++ b/test/test_metrics/test_chroma_alignment.py @@ -0,0 +1,327 @@ +import wave +from pathlib import Path + +import numpy as np +import pytest + +from versa.utterance_metrics.chroma_alignment import ChromaAlignmentMetric + + +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- +def generate_fixed_wav( + filename, duration=1.0, sample_rate=22050, base_freq=440, envelope_func=None +): + """ + Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. + """ + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. + if envelope_func is None: + envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) + else: + envelope = envelope_func(t) + audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. + amplitude = np.iinfo(np.int16).max + data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. + with wave.open(str(filename), "w") as wf: + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. + wf.setframerate(sample_rate) + wf.writeframes(data.tobytes()) + + +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 440 Hz sine wave (A4 note). + generate_fixed_wav(audio_file, duration=1.0, sample_rate=22050, base_freq=440) + return audio_file + + +@pytest.fixture(scope="session") +def fixed_ground_truth_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as ground truth. + This one uses a different duration but same frequency to test DTW alignment. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + gt_file = tmp_dir / "fixed_ground_truth.wav" + # Generate a ground truth file with a 440 Hz sine wave but different duration. + generate_fixed_wav(gt_file, duration=1.2, sample_rate=22050, base_freq=440) + return gt_file + + +@pytest.fixture(scope="session") +def different_pitch_wav(tmp_path_factory): + """ + Create a WAV file with a different pitch for testing distance metrics. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + diff_file = tmp_dir / "different_pitch.wav" + # Generate a file with a 554.37 Hz sine wave (C#5 note). + generate_fixed_wav(diff_file, duration=1.0, sample_rate=22050, base_freq=554.37) + return diff_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=22050): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + +@pytest.fixture(scope="session") +def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_wav) + + +@pytest.fixture(scope="session") +def fixed_ground_truth(fixed_ground_truth_wav): + """ + Load the fixed ground truth file as a NumPy array. + """ + return load_wav_as_array(fixed_ground_truth_wav) + + +@pytest.fixture(scope="session") +def different_pitch_audio(different_pitch_wav): + """ + Load the different pitch audio file as a NumPy array. + """ + return load_wav_as_array(different_pitch_wav) + + +# ------------------------------- +# Test Functions +# ------------------------------- +@pytest.mark.parametrize( + "scale_factor,feature_types,distance_metrics", + [ + (100.0, ["stft"], ["cosine"]), + (50.0, ["stft", "cqt"], ["cosine", "euclidean"]), + (200.0, ["stft", "cqt", "cens"], ["cosine"]), + ], +) +def test_utterance_chroma_alignment( + scale_factor, feature_types, distance_metrics, fixed_audio, fixed_ground_truth +): + """ + Test the Chroma Alignment metric using the fixed audio and ground truth. + The test uses deterministic data so that the result is always reproducible. + """ + config = { + "scale_factor": scale_factor, + "feature_types": feature_types, + "distance_metrics": distance_metrics, + "normalize": True, + "normalize_by_path": True, + } + + metric = ChromaAlignmentMetric(config) + metadata = {"sample_rate": 22050} + result = metric.compute(fixed_audio, fixed_ground_truth, metadata=metadata) + + # Check that the result contains the expected keys + for feat_type in feature_types: + for dist_metric in distance_metrics: + key = f"chroma_{feat_type}_{dist_metric}_dtw" + assert key in result, f"Result should contain '{key}' key" + assert isinstance( + result[key], (int, float) + ), f"Score {key} should be numeric" + assert result[key] >= 0, f"Score {key} should be non-negative" + + # Check for additional scaled variants + if "stft" in feature_types and "cosine" in distance_metrics: + assert "chroma_stft_cosine_dtw_raw" in result + assert "chroma_stft_cosine_dtw_log" in result + + +def test_chroma_alignment_metric_metadata(): + """Test that the Chroma Alignment metric has correct metadata.""" + config = {"scale_factor": 100.0} + metric = ChromaAlignmentMetric(config) + metadata = metric.get_metadata() + + assert metadata.name == "chroma_alignment" + assert metadata.category.value == "dependent" + assert metadata.metric_type.value == "float" + assert metadata.requires_reference is True + assert metadata.requires_text is False + assert metadata.gpu_compatible is False + assert "librosa" in metadata.dependencies + assert "numpy" in metadata.dependencies + assert "scipy" in metadata.dependencies + + +def test_chroma_alignment_metric_different_pitches(fixed_audio, different_pitch_audio): + """Test that the Chroma Alignment metric gives higher distances for different pitches.""" + config = {"scale_factor": 100.0} + metric = ChromaAlignmentMetric(config) + metadata = {"sample_rate": 22050} + + # Test with same pitch (should give lower distance) + result_same = metric.compute(fixed_audio, fixed_audio, metadata=metadata) + + # Test with different pitch (should give higher distance) + result_different = metric.compute( + fixed_audio, different_pitch_audio, metadata=metadata + ) + + # The distance should be higher for different pitches + for key in result_same: + if key in result_different and not key.endswith("_log"): + # Log-scaled metric works differently, so skip it + assert ( + result_different[key] >= result_same[key] + ), f"Distance should be higher for different pitches in {key}" + + +def test_chroma_alignment_metric_invalid_input(): + """Test that the Chroma Alignment metric handles invalid inputs correctly.""" + config = {"scale_factor": 100.0} + metric = ChromaAlignmentMetric(config) + + # Test with None input + with pytest.raises( + ValueError, match="Both predicted and ground truth signals must be provided" + ): + metric.compute(None, np.random.random(22050), metadata={"sample_rate": 22050}) + + with pytest.raises( + ValueError, match="Both predicted and ground truth signals must be provided" + ): + metric.compute(np.random.random(22050), None, metadata={"sample_rate": 22050}) + + +def test_chroma_alignment_metric_config_options(): + """Test that the Chroma Alignment metric handles different configuration options.""" + # Test with different scale factors + config_small_scale = { + "scale_factor": 50.0, + "feature_types": ["stft"], + "distance_metrics": ["cosine"], + } + metric_small = ChromaAlignmentMetric(config_small_scale) + + config_large_scale = { + "scale_factor": 200.0, + "feature_types": ["stft"], + "distance_metrics": ["cosine"], + } + metric_large = ChromaAlignmentMetric(config_large_scale) + + # Test with normalization options + config_no_norm = { + "normalize": False, + "feature_types": ["stft"], + "distance_metrics": ["cosine"], + } + metric_no_norm = ChromaAlignmentMetric(config_no_norm) + + # All should work without errors + audio = np.sin(2 * np.pi * 440 * np.linspace(0, 1, 22050)) + audio2 = np.sin(2 * np.pi * 880 * np.linspace(0, 1, 22050)) + metadata = {"sample_rate": 22050} + result_small = metric_small.compute(audio, audio2, metadata=metadata) + result_large = metric_large.compute(audio, audio2, metadata=metadata) + result_no_norm = metric_no_norm.compute(audio, audio2, metadata=metadata) + + # All should return the same structure + assert "chroma_stft_cosine_dtw" in result_small + assert "chroma_stft_cosine_dtw" in result_large + assert "chroma_stft_cosine_dtw" in result_no_norm + + # Scale factor should affect the magnitude + assert ( + result_large["chroma_stft_cosine_dtw"] > result_small["chroma_stft_cosine_dtw"] + ) + + +def test_chroma_alignment_metric_alignment_paths(): + """Test that the Chroma Alignment metric can return alignment paths when requested.""" + config = { + "scale_factor": 100.0, + "feature_types": ["stft"], + "distance_metrics": ["cosine"], + "return_alignment": True, + } + + metric = ChromaAlignmentMetric(config) + metadata = {"sample_rate": 22050} + audio = np.random.random(22050) + + result = metric.compute(audio, audio, metadata=metadata) + + # Should contain alignments when requested + assert "alignments" in result + assert "chroma_stft_cosine_dtw" in result["alignments"] + + +def test_chroma_alignment_metric_multidimensional_input(): + """Test that the Chroma Alignment metric handles multidimensional input correctly.""" + config = { + "scale_factor": 100.0, + "feature_types": ["stft"], + "distance_metrics": ["cosine"], + } + metric = ChromaAlignmentMetric(config) + metadata = {"sample_rate": 22050} + + # Test with 2D input (should be flattened) + audio_2d = np.random.random((22050, 1)) + result_2d = metric.compute(audio_2d, audio_2d, metadata=metadata) + + # Test with 1D input + audio_1d = np.random.random(22050) + result_1d = metric.compute(audio_1d, audio_1d, metadata=metadata) + + # Both should work and give similar results (not exactly the same due to randomness) + assert "chroma_stft_cosine_dtw" in result_2d + assert "chroma_stft_cosine_dtw" in result_1d + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist( + fixed_audio_wav, fixed_ground_truth_wav, different_pitch_wav +): + """ + Verify that the fixed WAV files were created. + """ + assert Path(fixed_audio_wav).exists() + assert Path(fixed_ground_truth_wav).exists() + assert Path(different_pitch_wav).exists() diff --git a/test/test_metrics/test_discrete_speech.py b/test/test_metrics/test_discrete_speech.py index f439ed3..027f08e 100644 --- a/test/test_metrics/test_discrete_speech.py +++ b/test/test_metrics/test_discrete_speech.py @@ -5,137 +5,304 @@ import pytest from versa.utterance_metrics.discrete_speech import ( - discrete_speech_setup, - discrete_speech_metric, + DiscreteSpeechMetric, + is_discrete_speech_available, ) -# Reuse the same helper functions from your STOI test +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- def generate_fixed_wav( filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None ): """ Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. """ t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. if envelope_func is None: envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) else: envelope = envelope_func(t) audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. amplitude = np.iinfo(np.int16).max data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. with wave.open(str(filename), "w") as wf: - wf.setnchannels(1) - wf.setsampwidth(2) + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. wf.setframerate(sample_rate) wf.writeframes(data.tobytes()) -def load_wav_as_array(wav_path, sample_rate=16000): - """ - Load a WAV file and convert it to a NumPy array scaled to [-1, 1]. - """ - with wave.open(str(wav_path), "rb") as wf: - frames = wf.getnframes() - audio_data = wf.readframes(frames) - audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) - return audio_array / np.iinfo(np.int16).max - - +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- @pytest.fixture(scope="session") def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ tmp_dir = tmp_path_factory.mktemp("audio_data") audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) return audio_file @pytest.fixture(scope="session") def fixed_ground_truth_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as ground truth. + This one uses a different base frequency (e.g., 300 Hz) so that the test + intentionally simulates a mismatch. + """ tmp_dir = tmp_path_factory.mktemp("audio_data") gt_file = tmp_dir / "fixed_ground_truth.wav" - # Use a different base frequency for ground truth (e.g. 300 Hz) to simulate a mismatch. + # Generate a ground truth file with a 300 Hz sine wave. generate_fixed_wav(gt_file, duration=1.0, sample_rate=16000, base_freq=300) return gt_file +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=16000): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + @pytest.fixture(scope="session") def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ return load_wav_as_array(fixed_audio_wav) @pytest.fixture(scope="session") def fixed_ground_truth(fixed_ground_truth_wav): + """ + Load the fixed ground truth file as a NumPy array. + """ return load_wav_as_array(fixed_ground_truth_wav) -@pytest.fixture(scope="session") -def discrete_speech_predictors(): - """Set up discrete speech predictors once per test session.""" - return discrete_speech_setup(use_gpu=False) - - # ------------------------------- -# Discrete Speech Metric Tests +# Test Functions # ------------------------------- -def test_discrete_speech_metric_identical(fixed_audio, discrete_speech_predictors): +@pytest.mark.skipif( + not is_discrete_speech_available(), reason="Discrete Speech Metrics not available" +) +@pytest.mark.parametrize( + "use_gpu", + [ + False, + ], +) +def test_utterance_discrete_speech_identical(use_gpu, fixed_audio): """ + Test the Discrete Speech metric using identical audio signals. When comparing an audio signal with itself, the discrete speech scores should be high. """ - scores = discrete_speech_metric( - discrete_speech_predictors, fixed_audio, fixed_audio, 16000 - ) + config = {"use_gpu": use_gpu} + + metric = DiscreteSpeechMetric(config) + metadata = {"sample_rate": 16000} + result = metric.compute(fixed_audio, fixed_audio, metadata=metadata) # Check that all expected metrics are present - assert "speech_bert" in scores - assert "speech_bleu" in scores - assert "speech_token_distance" in scores + assert "speech_bert" in result, "Result should contain 'speech_bert' key" + assert "speech_bleu" in result, "Result should contain 'speech_bleu' key" + assert ( + "speech_token_distance" in result + ), "Result should contain 'speech_token_distance' key" # For identical signals, scores should be relatively high # Note: Perfect scores (1.0) are not always expected for discrete speech metrics assert ( - scores["speech_bert"] > 0.9 - ), f"Expected SpeechBERT score > 0.5 for identical signals, got {scores['speech_bert']}" + result["speech_bert"] > 0.9 + ), f"Expected SpeechBERT score > 0.9 for identical signals, got {result['speech_bert']}" assert ( - scores["speech_bleu"] > 0.9 - ), f"Expected SpeechBLEU score > 0.3 for identical signals, got {scores['speech_bleu']}" + result["speech_bleu"] > 0.9 + ), f"Expected SpeechBLEU score > 0.9 for identical signals, got {result['speech_bleu']}" assert ( - scores["speech_token_distance"] > 0.9 - ), f"Expected SpeechTokenDistance score > 0.3 for identical signals, got {scores['speech_token_distance']}" + result["speech_token_distance"] > 0.9 + ), f"Expected SpeechTokenDistance score > 0.9 for identical signals, got {result['speech_token_distance']}" -def test_discrete_speech_metric_different( - fixed_audio, fixed_ground_truth, discrete_speech_predictors -): +@pytest.mark.skipif( + not is_discrete_speech_available(), reason="Discrete Speech Metrics not available" +) +@pytest.mark.parametrize( + "use_gpu", + [ + False, + ], +) +def test_utterance_discrete_speech_different(use_gpu, fixed_audio, fixed_ground_truth): """ + Test the Discrete Speech metric using different audio signals. When comparing two different fixed signals, the discrete speech scores should be lower than identical signals. """ + config = {"use_gpu": use_gpu} + + metric = DiscreteSpeechMetric(config) + metadata = {"sample_rate": 16000} + # Get scores for identical signals first - identical_scores = discrete_speech_metric( - discrete_speech_predictors, fixed_audio, fixed_audio, 16000 - ) + identical_result = metric.compute(fixed_audio, fixed_audio, metadata=metadata) # Get scores for different signals - different_scores = discrete_speech_metric( - discrete_speech_predictors, fixed_audio, fixed_ground_truth, 16000 + different_result = metric.compute( + fixed_audio, fixed_ground_truth, metadata=metadata ) # Check that all expected metrics are present - assert "speech_bert" in different_scores - assert "speech_bleu" in different_scores - assert "speech_token_distance" in different_scores + assert "speech_bert" in different_result, "Result should contain 'speech_bert' key" + assert "speech_bleu" in different_result, "Result should contain 'speech_bleu' key" + assert ( + "speech_token_distance" in different_result + ), "Result should contain 'speech_token_distance' key" # Different signals should have lower scores than identical signals assert ( - different_scores["speech_bert"] <= identical_scores["speech_bert"] - ), f"Expected SpeechBERT score for different signals ({different_scores['speech_bert']}) to be <= identical signals ({identical_scores['speech_bert']})" - + different_result["speech_bert"] <= identical_result["speech_bert"] + ), f"Expected SpeechBERT score for different signals ({different_result['speech_bert']}) to be <= identical signals ({identical_result['speech_bert']})" assert ( - different_scores["speech_bleu"] <= identical_scores["speech_bleu"] - ), f"Expected SpeechBLEU score for different signals ({different_scores['speech_bleu']}) to be <= identical signals ({identical_scores['speech_bleu']})" - + different_result["speech_bleu"] <= identical_result["speech_bleu"] + ), f"Expected SpeechBLEU score for different signals ({different_result['speech_bleu']}) to be <= identical signals ({identical_result['speech_bleu']})" assert ( - different_scores["speech_token_distance"] - <= identical_scores["speech_token_distance"] - ), f"Expected SpeechTokenDistance score for different signals ({different_scores['speech_token_distance']}) to be <= identical signals ({identical_scores['speech_token_distance']})" + different_result["speech_token_distance"] + <= identical_result["speech_token_distance"] + ), f"Expected SpeechTokenDistance score for different signals ({different_result['speech_token_distance']}) to be <= identical signals ({identical_result['speech_token_distance']})" + + +@pytest.mark.skipif( + not is_discrete_speech_available(), reason="Discrete Speech Metrics not available" +) +def test_discrete_speech_metric_metadata(): + """Test that the Discrete Speech metric has correct metadata.""" + config = {"use_gpu": False} + metric = DiscreteSpeechMetric(config) + metadata = metric.get_metadata() + + assert metadata.name == "discrete_speech" + assert metadata.category.value == "dependent" + assert metadata.metric_type.value == "float" + assert metadata.requires_reference is True + assert metadata.requires_text is False + assert metadata.gpu_compatible is True + assert "discrete_speech_metrics" in metadata.dependencies + assert "librosa" in metadata.dependencies + assert "numpy" in metadata.dependencies + + +@pytest.mark.skipif( + not is_discrete_speech_available(), reason="Discrete Speech Metrics not available" +) +def test_discrete_speech_metric_different_sample_rates(): + """Test that the Discrete Speech metric handles different sample rates correctly.""" + config = {"use_gpu": False} + metric = DiscreteSpeechMetric(config) + + # Test with 44.1kHz audio (should be resampled to 16kHz) + audio_44k = np.random.random(44100) + metadata_44k = {"sample_rate": 44100} + result_44k = metric.compute(audio_44k, audio_44k, metadata=metadata_44k) + + # Test with 16kHz audio (no resampling needed) + audio_16k = np.random.random(16000) + metadata_16k = {"sample_rate": 16000} + result_16k = metric.compute(audio_16k, audio_16k, metadata=metadata_16k) + + # Both should return valid scores with expected keys + expected_keys = ["speech_bert", "speech_bleu", "speech_token_distance"] + + for key in expected_keys: + assert key in result_44k, f"44kHz result should contain '{key}' key" + assert key in result_16k, f"16kHz result should contain '{key}' key" + assert isinstance( + result_44k[key], (int, float) + ), f"Score {key} should be numeric" + assert isinstance( + result_16k[key], (int, float) + ), f"Score {key} should be numeric" + + +@pytest.mark.skipif( + not is_discrete_speech_available(), reason="Discrete Speech Metrics not available" +) +def test_discrete_speech_metric_invalid_input(): + """Test that the Discrete Speech metric handles invalid inputs correctly.""" + config = {"use_gpu": False} + metric = DiscreteSpeechMetric(config) + + # Test with None input + with pytest.raises( + ValueError, match="Both predicted and ground truth signals must be provided" + ): + metric.compute(None, np.random.random(16000), metadata={"sample_rate": 16000}) + + with pytest.raises( + ValueError, match="Both predicted and ground truth signals must be provided" + ): + metric.compute(np.random.random(16000), None, metadata={"sample_rate": 16000}) + + +@pytest.mark.skipif( + not is_discrete_speech_available(), reason="Discrete Speech Metrics not available" +) +def test_discrete_speech_metric_config_options(): + """Test that the Discrete Speech metric handles different configuration options.""" + # Test with GPU disabled + config_cpu = {"use_gpu": False} + metric_cpu = DiscreteSpeechMetric(config_cpu) + + # Test with different sample rate + config_custom_sr = {"use_gpu": False, "sample_rate": 22050} + metric_custom_sr = DiscreteSpeechMetric(config_custom_sr) + + # All should work without errors + audio = np.random.random(16000) + metadata = {"sample_rate": 16000} + + result_cpu = metric_cpu.compute(audio, audio, metadata=metadata) + result_custom_sr = metric_custom_sr.compute(audio, audio, metadata=metadata) + + # All should return the same structure + expected_keys = ["speech_bert", "speech_bleu", "speech_token_distance"] + + for key in expected_keys: + assert key in result_cpu + assert key in result_custom_sr + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist(fixed_audio_wav, fixed_ground_truth_wav): + """ + Verify that the fixed WAV files were created. + """ + assert Path(fixed_audio_wav).exists() + assert Path(fixed_ground_truth_wav).exists() diff --git a/test/test_metrics/test_dpam.py b/test/test_metrics/test_dpam.py deleted file mode 100755 index e9557bc..0000000 --- a/test/test_metrics/test_dpam.py +++ /dev/null @@ -1,83 +0,0 @@ -import wave -from pathlib import Path - -import numpy as np -import pytest - -from versa.utterance_metrics.dpam_distance import dpam_metric, dpam_model_setup - -# Assume the fixed WAV file fixtures and helper function are defined as in the ASR matching test. -# For example: - - -def generate_fixed_wav( - filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None -): - """ - Generate a deterministic WAV file with a modulated sine wave. - """ - t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) - if envelope_func is None: - envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) - else: - envelope = envelope_func(t) - audio = envelope * np.sin(2 * np.pi * base_freq * t) - amplitude = np.iinfo(np.int16).max - data = (audio * amplitude).astype(np.int16) - with wave.open(str(filename), "w") as wf: - wf.setnchannels(1) - wf.setsampwidth(2) - wf.setframerate(sample_rate) - wf.writeframes(data.tobytes()) - - -def load_wav_as_array(wav_path, sample_rate=16000): - """ - Load a WAV file and convert it to a NumPy array scaled to [-1, 1]. - """ - with wave.open(str(wav_path), "rb") as wf: - frames = wf.getnframes() - audio_data = wf.readframes(frames) - audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) - return audio_array / np.iinfo(np.int16).max - - -@pytest.fixture(scope="session") -def fixed_audio_wav(tmp_path_factory): - tmp_dir = tmp_path_factory.mktemp("audio_data") - audio_file = tmp_dir / "fixed_audio.wav" - generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) - return audio_file - - -@pytest.fixture(scope="session") -def fixed_ground_truth_wav(tmp_path_factory): - tmp_dir = tmp_path_factory.mktemp("audio_data") - gt_file = tmp_dir / "fixed_ground_truth.wav" - # Use a different base frequency for ground truth (e.g. 300 Hz) to simulate a mismatch. - generate_fixed_wav(gt_file, duration=1.0, sample_rate=16000, base_freq=300) - return gt_file - - -@pytest.fixture(scope="session") -def fixed_audio(fixed_audio_wav): - return load_wav_as_array(fixed_audio_wav) - - -@pytest.fixture(scope="session") -def fixed_ground_truth(fixed_ground_truth_wav): - return load_wav_as_array(fixed_ground_truth_wav) - - -# ------------------------------- -# DPAM Metric Definition and Tests -# ------------------------------- -def test_dpam_metric_identical(fixed_audio): - """ - When comparing an audio signal with itself, the dpam distance should be 0.0. - """ - model = dpam_model_setup() - scores = dpam_metric(model, fixed_audio, fixed_audio, 16000) - assert ( - scores["dpam_distance"] == 0.0 - ), f"Expected dpam distance == 0.0 for identical signals, got {scores['dpam_distance']}" diff --git a/test/test_metrics/test_dpam_distance.py b/test/test_metrics/test_dpam_distance.py new file mode 100644 index 0000000..b07d936 --- /dev/null +++ b/test/test_metrics/test_dpam_distance.py @@ -0,0 +1,266 @@ +import wave +from pathlib import Path + +import numpy as np +import pytest + +from versa.utterance_metrics.dpam_distance import DpamDistanceMetric + + +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- +def generate_fixed_wav( + filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None +): + """ + Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. + """ + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. + if envelope_func is None: + envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) + else: + envelope = envelope_func(t) + audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. + amplitude = np.iinfo(np.int16).max + data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. + with wave.open(str(filename), "w") as wf: + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. + wf.setframerate(sample_rate) + wf.writeframes(data.tobytes()) + + +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) + return audio_file + + +@pytest.fixture(scope="session") +def fixed_ground_truth_wav(tmp_path_factory): + """ + Create a ground truth WAV file to be used as reference audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + gt_file = tmp_dir / "fixed_ground_truth.wav" + # Use a different base frequency for ground truth (e.g. 300 Hz) to simulate a mismatch. + generate_fixed_wav(gt_file, duration=1.0, sample_rate=16000, base_freq=300) + return gt_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=16000): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + +@pytest.fixture(scope="session") +def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_wav) + + +@pytest.fixture(scope="session") +def fixed_ground_truth(fixed_ground_truth_wav): + """ + Load the ground truth audio file as a NumPy array. + """ + return load_wav_as_array(fixed_ground_truth_wav) + + +# ------------------------------- +# Test Functions +# ------------------------------- +@pytest.mark.parametrize( + "use_gpu", + [ + False, + ], +) +def test_utterance_dpam_distance(use_gpu, fixed_audio, fixed_ground_truth): + """ + Test the DPAM distance metric using the fixed audio files. + The test uses deterministic data so that the result is always reproducible. + """ + config = {"use_gpu": use_gpu} + + metric = DpamDistanceMetric(config) + metadata = {"sample_rate": 16000} + result = metric.compute(fixed_audio, fixed_ground_truth, metadata=metadata) + + # Check that the result contains the expected key + assert "dpam_distance" in result, "Result should contain 'dpam_distance' key" + + # Check that the result is a float + dpam_dist = result["dpam_distance"] + assert isinstance(dpam_dist, float), "dpam_distance should be a float" + + # Check that the distance score is reasonable (should be non-negative) + assert dpam_dist >= 0.0, f"DPAM distance should be non-negative, got {dpam_dist}" + + +def test_dpam_distance_metric_metadata(): + """Test that the DPAM distance metric has correct metadata.""" + config = {"use_gpu": False} + metric = DpamDistanceMetric(config) + metadata = metric.get_metadata() + + assert metadata.name == "dpam_distance" + assert metadata.category.value == "dependent" + assert metadata.metric_type.value == "float" + assert metadata.requires_reference is True + assert metadata.requires_text is False + assert metadata.gpu_compatible is True + assert "torch" in metadata.dependencies + assert "librosa" in metadata.dependencies + assert "numpy" in metadata.dependencies + assert "filelock" in metadata.dependencies + + +def test_dpam_distance_metric_different_sample_rates(): + """Test that the DPAM distance metric handles different sample rates correctly.""" + config = {"use_gpu": False} + metric = DpamDistanceMetric(config) + + # Test with 44.1kHz audio (should be resampled to 22.05kHz) + audio_44k_1 = np.random.random(44100) + audio_44k_2 = np.random.random(44100) + metadata_44k = {"sample_rate": 44100} + result_44k = metric.compute(audio_44k_1, audio_44k_2, metadata=metadata_44k) + + # Test with 22.05kHz audio (no resampling needed) + audio_22k_1 = np.random.random(22050) + audio_22k_2 = np.random.random(22050) + metadata_22k = {"sample_rate": 22050} + result_22k = metric.compute(audio_22k_1, audio_22k_2, metadata=metadata_22k) + + # Both should return valid scores with expected keys + assert ( + "dpam_distance" in result_44k + ), "44kHz result should contain 'dpam_distance' key" + assert ( + "dpam_distance" in result_22k + ), "22kHz result should contain 'dpam_distance' key" + + # Both should return float values + assert isinstance(result_44k["dpam_distance"], float) + assert isinstance(result_22k["dpam_distance"], float) + + +def test_dpam_distance_metric_invalid_input(): + """Test that the DPAM distance metric handles invalid inputs correctly.""" + config = {"use_gpu": False} + metric = DpamDistanceMetric(config) + + # Test with None predictions + with pytest.raises(ValueError, match="Predicted signal must be provided"): + metric.compute(None, np.random.random(22050), metadata={"sample_rate": 22050}) + + # Test with None references + with pytest.raises(ValueError, match="Reference signal must be provided"): + metric.compute(np.random.random(22050), None, metadata={"sample_rate": 22050}) + + +def test_dpam_distance_metric_config_options(): + """Test that the DPAM distance metric handles different configuration options.""" + # Test with GPU disabled + config_cpu = {"use_gpu": False} + metric_cpu = DpamDistanceMetric(config_cpu) + + # Test with custom cache directory + config_custom_cache = {"use_gpu": False, "cache_dir": "custom_cache"} + metric_custom_cache = DpamDistanceMetric(config_custom_cache) + + # All should work without errors + audio1 = np.random.random(22050) + audio2 = np.random.random(22050) + metadata = {"sample_rate": 22050} + + result_cpu = metric_cpu.compute(audio1, audio2, metadata=metadata) + result_custom_cache = metric_custom_cache.compute(audio1, audio2, metadata=metadata) + + # All should return the same structure + assert "dpam_distance" in result_cpu + assert "dpam_distance" in result_custom_cache + assert isinstance(result_cpu["dpam_distance"], float) + assert isinstance(result_custom_cache["dpam_distance"], float) + + +def test_dpam_distance_metric_identical_signals(): + """Test that the DPAM distance metric gives zero distance for identical signals.""" + config = {"use_gpu": False} + metric = DpamDistanceMetric(config) + metadata = {"sample_rate": 22050} + + # Test with identical signals + audio = np.random.random(22050) + result = metric.compute(audio, audio, metadata=metadata) + + # Results should be 0.0 for identical signals + assert result["dpam_distance"] == 0.0, "Identical signals should have zero distance" + + +def test_dpam_distance_metric_consistent_results(): + """Test that the DPAM distance metric gives consistent results for the same inputs.""" + config = {"use_gpu": False} + metric = DpamDistanceMetric(config) + metadata = {"sample_rate": 22050} + + # Test with fixed signals + audio1 = np.random.random(22050) + audio2 = np.random.random(22050) + result1 = metric.compute(audio1, audio2, metadata=metadata) + result2 = metric.compute(audio1, audio2, metadata=metadata) + + # Results should be identical for the same inputs + np.testing.assert_almost_equal( + result1["dpam_distance"], + result2["dpam_distance"], + decimal=6, + err_msg="Results should be identical for the same inputs", + ) + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist(fixed_audio_wav, fixed_ground_truth_wav): + """ + Verify that the fixed WAV files were created. + """ + assert Path(fixed_audio_wav).exists() + assert Path(fixed_ground_truth_wav).exists() diff --git a/test/test_metrics/test_emo_similarity.py b/test/test_metrics/test_emo_similarity.py new file mode 100644 index 0000000..c781e96 --- /dev/null +++ b/test/test_metrics/test_emo_similarity.py @@ -0,0 +1,278 @@ +import wave +from pathlib import Path + +import numpy as np +import pytest + +from versa.utterance_metrics.emo_similarity import Emo2vecMetric, is_emo2vec_available + + +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- +def generate_fixed_wav( + filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None +): + """ + Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. + """ + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. + if envelope_func is None: + envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) + else: + envelope = envelope_func(t) + audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. + amplitude = np.iinfo(np.int16).max + data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. + with wave.open(str(filename), "w") as wf: + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. + wf.setframerate(sample_rate) + wf.writeframes(data.tobytes()) + + +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) + return audio_file + + +@pytest.fixture(scope="session") +def fixed_audio_wav_2(tmp_path_factory): + """ + Create a second fixed WAV file to be used as reference audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio_2.wav" + # Generate an audio file with a 200 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=200) + return audio_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=16000): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + +@pytest.fixture(scope="session") +def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_wav) + + +@pytest.fixture(scope="session") +def fixed_audio_2(fixed_audio_wav_2): + """ + Load the second fixed audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_wav_2) + + +# ------------------------------- +# Test Functions +# ------------------------------- +@pytest.mark.skipif(not is_emo2vec_available(), reason="Emo2vec not available") +@pytest.mark.parametrize( + "use_gpu", + [ + False, + ], +) +def test_utterance_emotion(use_gpu, fixed_audio, fixed_audio_2): + """ + Test the Emotion metric using the fixed audio files. + The test uses deterministic data so that the result is always reproducible. + """ + config = {"use_gpu": use_gpu} + + metric = Emo2vecMetric(config) + metadata = {"sample_rate": 16000} + result = metric.compute(fixed_audio, fixed_audio_2, metadata=metadata) + + # Check that the result contains the expected key + assert ( + "emotion_similarity" in result + ), "Result should contain 'emotion_similarity' key" + + # Check that the result is a float + emotion_sim = result["emotion_similarity"] + assert isinstance(emotion_sim, float), "emotion_similarity should be a float" + + # Check that the similarity score is reasonable (between -1 and 1 for cosine similarity) + assert ( + -1.0 <= emotion_sim <= 1.0 + ), f"Emotion similarity should be between -1 and 1, got {emotion_sim}" + + +@pytest.mark.skipif(not is_emo2vec_available(), reason="Emo2vec not available") +def test_emotion_metric_metadata(): + """Test that the Emotion metric has correct metadata.""" + config = {"use_gpu": False} + metric = Emo2vecMetric(config) + metadata = metric.get_metadata() + + assert metadata.name == "emotion" + assert metadata.category.value == "dependent" + assert metadata.metric_type.value == "float" + assert metadata.requires_reference is True + assert metadata.requires_text is False + assert metadata.gpu_compatible is True + assert "emo2vec_versa" in metadata.dependencies + assert "librosa" in metadata.dependencies + assert "numpy" in metadata.dependencies + + +@pytest.mark.skipif(not is_emo2vec_available(), reason="Emo2vec not available") +def test_emotion_metric_different_sample_rates(): + """Test that the Emotion metric handles different sample rates correctly.""" + config = {"use_gpu": False} + metric = Emo2vecMetric(config) + + # Test with 44.1kHz audio (should be resampled to 16kHz) + audio_44k_1 = np.random.random(44100) + audio_44k_2 = np.random.random(44100) + metadata_44k = {"sample_rate": 44100} + result_44k = metric.compute(audio_44k_1, audio_44k_2, metadata=metadata_44k) + + # Test with 16kHz audio (no resampling needed) + audio_16k_1 = np.random.random(16000) + audio_16k_2 = np.random.random(16000) + metadata_16k = {"sample_rate": 16000} + result_16k = metric.compute(audio_16k_1, audio_16k_2, metadata=metadata_16k) + + # Both should return valid scores with expected keys + assert ( + "emotion_similarity" in result_44k + ), "44kHz result should contain 'emotion_similarity' key" + assert ( + "emotion_similarity" in result_16k + ), "16kHz result should contain 'emotion_similarity' key" + + # Both should return float values + assert isinstance(result_44k["emotion_similarity"], float) + assert isinstance(result_16k["emotion_similarity"], float) + + +@pytest.mark.skipif(not is_emo2vec_available(), reason="Emo2vec not available") +def test_emotion_metric_invalid_input(): + """Test that the Emotion metric handles invalid inputs correctly.""" + config = {"use_gpu": False} + metric = Emo2vecMetric(config) + + # Test with None predictions + with pytest.raises(ValueError, match="Predicted signal must be provided"): + metric.compute(None, np.random.random(16000), metadata={"sample_rate": 16000}) + + # Test with None references + with pytest.raises(ValueError, match="Reference signal must be provided"): + metric.compute(np.random.random(16000), None, metadata={"sample_rate": 16000}) + + +@pytest.mark.skipif(not is_emo2vec_available(), reason="Emo2vec not available") +def test_emotion_metric_config_options(): + """Test that the Emotion metric handles different configuration options.""" + # Test with GPU disabled + config_cpu = {"use_gpu": False} + metric_cpu = Emo2vecMetric(config_cpu) + + # Test with different model tag + config_custom_model = {"use_gpu": False, "model_tag": "base"} + metric_custom_model = Emo2vecMetric(config_custom_model) + + # All should work without errors + audio1 = np.random.random(16000) + audio2 = np.random.random(16000) + metadata = {"sample_rate": 16000} + + result_cpu = metric_cpu.compute(audio1, audio2, metadata=metadata) + result_custom_model = metric_custom_model.compute(audio1, audio2, metadata=metadata) + + # All should return the same structure + assert "emotion_similarity" in result_cpu + assert "emotion_similarity" in result_custom_model + assert isinstance(result_cpu["emotion_similarity"], float) + assert isinstance(result_custom_model["emotion_similarity"], float) + + +@pytest.mark.skipif(not is_emo2vec_available(), reason="Emo2vec not available") +def test_emotion_metric_identical_signals(): + """Test that the Emotion metric gives high similarity for identical signals.""" + config = {"use_gpu": False} + metric = Emo2vecMetric(config) + metadata = {"sample_rate": 16000} + + # Test with identical signals + audio = np.random.random(16000) + result = metric.compute(audio, audio, metadata=metadata) + + # Results should be very close to 1.0 for identical signals + assert ( + result["emotion_similarity"] > 0.99 + ), "Identical signals should have very high similarity" + + +@pytest.mark.skipif(not is_emo2vec_available(), reason="Emo2vec not available") +def test_emotion_metric_consistent_results(): + """Test that the Emotion metric gives consistent results for the same inputs.""" + config = {"use_gpu": False} + metric = Emo2vecMetric(config) + metadata = {"sample_rate": 16000} + + # Test with fixed signals + audio1 = np.random.random(16000) + audio2 = np.random.random(16000) + result1 = metric.compute(audio1, audio2, metadata=metadata) + result2 = metric.compute(audio1, audio2, metadata=metadata) + + # Results should be identical for the same inputs + np.testing.assert_almost_equal( + result1["emotion_similarity"], + result2["emotion_similarity"], + decimal=6, + err_msg="Results should be identical for the same inputs", + ) + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist(fixed_audio_wav, fixed_audio_wav_2): + """ + Verify that the fixed WAV files were created. + """ + assert Path(fixed_audio_wav).exists() + assert Path(fixed_audio_wav_2).exists() diff --git a/test/test_metrics/test_emo_vad.py b/test/test_metrics/test_emo_vad.py index 6ffee75..80ac31b 100644 --- a/test/test_metrics/test_emo_vad.py +++ b/test/test_metrics/test_emo_vad.py @@ -4,68 +4,296 @@ import numpy as np import pytest -from versa.utterance_metrics.emo_vad import dim_emo_pred, w2v2_emo_dim_setup - -# Assume the fixed WAV file fixtures and helper function are defined as in the dimentionsal emotion prediction test. -# For example: +from versa.utterance_metrics.emo_vad import EmoVadMetric, is_transformers_available +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- def generate_fixed_wav( filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None ): """ Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. """ t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. if envelope_func is None: envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) else: envelope = envelope_func(t) audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. amplitude = np.iinfo(np.int16).max data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. with wave.open(str(filename), "w") as wf: - wf.setnchannels(1) - wf.setsampwidth(2) + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. wf.setframerate(sample_rate) wf.writeframes(data.tobytes()) +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) + return audio_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- def load_wav_as_array(wav_path, sample_rate=16000): """ - Load a WAV file and convert it to a NumPy array scaled to [-1, 1]. + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. """ with wave.open(str(wav_path), "rb") as wf: frames = wf.getnframes() audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) return audio_array / np.iinfo(np.int16).max -@pytest.fixture(scope="session") -def fixed_audio_wav(tmp_path_factory): - tmp_dir = tmp_path_factory.mktemp("audio_data") - audio_file = tmp_dir / "fixed_audio.wav" - generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) - return audio_file - - @pytest.fixture(scope="session") def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ return load_wav_as_array(fixed_audio_wav) # ------------------------------- -# emo_vad Metric Definition and Tests +# Test Functions +# ------------------------------- +@pytest.mark.skipif( + not is_transformers_available(), reason="Transformers not available" +) +@pytest.mark.parametrize( + "use_gpu", + [ + False, + ], +) +def test_utterance_emo_vad(use_gpu, fixed_audio): + """ + Test the EmoVad metric using the fixed audio. + The test uses deterministic data so that the result is always reproducible. + """ + config = {"use_gpu": use_gpu} + + metric = EmoVadMetric(config) + metadata = {"sample_rate": 16000} + result = metric.compute(fixed_audio, metadata=metadata) + + # Check that the result contains the expected key + assert "arousal_emo_vad" in result, "Result should contain 'arousal_emo_vad' key" + assert "valence_emo_vad" in result, "Result should contain 'valence_emo_vad' key" + assert ( + "dominance_emo_vad" in result + ), "Result should contain 'dominance_emo_vad' key" + + # Check that the result is a numpy array with 3 values (arousal, valence, dominance) + arousal = result["arousal_emo_vad"] + valence = result["valence_emo_vad"] + dominance = result["dominance_emo_vad"] + assert isinstance(arousal, float), "arousal_emo_vad should be a float" + assert isinstance(valence, float), "valence_emo_vad should be a float" + assert isinstance(dominance, float), "dominance_emo_vad should be a float" + + # Check that all values are numeric and reasonable (emotion scores are typically between 0 and 1) + assert ( + 0.0 <= arousal <= 1.0 + ), f"Arousal score should be between 0 and 1, got {arousal}" + assert ( + 0.0 <= valence <= 1.0 + ), f"Valence score should be between 0 and 1, got {valence}" + assert ( + 0.0 <= dominance <= 1.0 + ), f"Dominance score should be between 0 and 1, got {dominance}" + + +@pytest.mark.skipif( + not is_transformers_available(), reason="Transformers not available" +) +def test_emo_vad_metric_metadata(): + """Test that the EmoVad metric has correct metadata.""" + config = {"use_gpu": False} + metric = EmoVadMetric(config) + metadata = metric.get_metadata() + + assert metadata.name == "emo_vad" + assert metadata.category.value == "independent" + assert metadata.metric_type.value == "float" + assert metadata.requires_reference is False + assert metadata.requires_text is False + assert metadata.gpu_compatible is True + assert "transformers" in metadata.dependencies + assert "torch" in metadata.dependencies + assert "librosa" in metadata.dependencies + assert "numpy" in metadata.dependencies + + +@pytest.mark.skipif( + not is_transformers_available(), reason="Transformers not available" +) +def test_emo_vad_metric_different_sample_rates(): + """Test that the EmoVad metric handles different sample rates correctly.""" + config = {"use_gpu": False} + metric = EmoVadMetric(config) + + # Test with 44.1kHz audio (should be resampled to 16kHz) + audio_44k = np.random.random(44100) + metadata_44k = {"sample_rate": 44100} + result_44k = metric.compute(audio_44k, metadata=metadata_44k) + + # Test with 16kHz audio (no resampling needed) + audio_16k = np.random.random(16000) + metadata_16k = {"sample_rate": 16000} + result_16k = metric.compute(audio_16k, metadata=metadata_16k) + + # Both should return valid scores with expected keys + assert ( + "arousal_emo_vad" in result_44k + ), "44kHz result should contain 'arousal_emo_vad' key" + assert ( + "valence_emo_vad" in result_44k + ), "44kHz result should contain 'valence_emo_vad' key" + assert ( + "dominance_emo_vad" in result_44k + ), "44kHz result should contain 'dominance_emo_vad' key" + assert ( + "arousal_emo_vad" in result_16k + ), "16kHz result should contain 'arousal_emo_vad' key" + assert ( + "valence_emo_vad" in result_16k + ), "16kHz result should contain 'valence_emo_vad' key" + assert ( + "dominance_emo_vad" in result_16k + ), "16kHz result should contain 'dominance_emo_vad' key" + + # Both should return numpy arrays with 3 values + assert ( + type(result_44k["arousal_emo_vad"]) == float + ), "arousal_emo_vad should be a float" + assert ( + type(result_44k["valence_emo_vad"]) == float + ), "valence_emo_vad should be a float" + assert ( + type(result_44k["dominance_emo_vad"]) == float + ), "dominance_emo_vad should be a float" + assert ( + type(result_16k["arousal_emo_vad"]) == float + ), "arousal_emo_vad should be a float" + assert ( + type(result_16k["valence_emo_vad"]) == float + ), "valence_emo_vad should be a float" + assert ( + type(result_16k["dominance_emo_vad"]) == float + ), "dominance_emo_vad should be a float" + + +@pytest.mark.skipif( + not is_transformers_available(), reason="Transformers not available" +) +def test_emo_vad_metric_invalid_input(): + """Test that the EmoVad metric handles invalid inputs correctly.""" + config = {"use_gpu": False} + metric = EmoVadMetric(config) + + # Test with None input + with pytest.raises(ValueError, match="Predicted signal must be provided"): + metric.compute(None, metadata={"sample_rate": 16000}) + + +@pytest.mark.skipif( + not is_transformers_available(), reason="Transformers not available" +) +def test_emo_vad_metric_config_options(): + """Test that the EmoVad metric handles different configuration options.""" + # Test with GPU disabled + config_cpu = {"use_gpu": False} + metric_cpu = EmoVadMetric(config_cpu) + + # Test with different model tag + config_custom_model = { + "use_gpu": False, + "model_tag": "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim", + } + metric_custom_model = EmoVadMetric(config_custom_model) + + # All should work without errors + audio = np.random.random(16000) + metadata = {"sample_rate": 16000} + + result_cpu = metric_cpu.compute(audio, metadata=metadata) + result_custom_model = metric_custom_model.compute(audio, metadata=metadata) + + # All should return the same structure + assert "arousal_emo_vad" in result_cpu + assert "valence_emo_vad" in result_cpu + assert "dominance_emo_vad" in result_cpu + assert "arousal_emo_vad" in result_custom_model + assert "valence_emo_vad" in result_custom_model + assert "dominance_emo_vad" in result_custom_model + assert ( + type(result_cpu["arousal_emo_vad"]) == float + ), "arousal_emo_vad should be a float" + assert ( + type(result_cpu["valence_emo_vad"]) == float + ), "valence_emo_vad should be a float" + assert ( + type(result_cpu["dominance_emo_vad"]) == float + ), "dominance_emo_vad should be a float" + + +@pytest.mark.skipif( + not is_transformers_available(), reason="Transformers not available" +) +def test_emo_vad_metric_identical_signals(): + """Test that the EmoVad metric gives consistent results for identical signals.""" + config = {"use_gpu": False} + metric = EmoVadMetric(config) + metadata = {"sample_rate": 16000} + + # Test with identical signals + audio = np.random.random(16000) + result1 = metric.compute(audio, metadata=metadata) + result2 = metric.compute(audio, metadata=metadata) + + # Results should be identical for the same input + np.testing.assert_array_almost_equal( + result1["arousal_emo_vad"], + result2["arousal_emo_vad"], + decimal=6, + err_msg="Results should be identical for the same input", + ) + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) # ------------------------------- -def test_emo_vad_metric_identical(fixed_audio): +def test_fixed_wav_files_exist(fixed_audio_wav): """ - When comparing an audio signal with itself, the STOI score should be 1.0. + Verify that the fixed WAV files were created. """ - emo_utils = w2v2_emo_dim_setup() - scores = dim_emo_pred(emo_utils, fixed_audio, 16000) - assert scores["aro_val_dom_emo"] == pytest.approx( - np.array([0.3982302, 0.43092448, 0.41154572], dtype=np.float32), - rel=1e-3, - abs=1e-6, - ), f"Expected aro_val_dom_emo of [0.3982302, 0.43092448, 0.41154572] for identical signals, got {scores['aro_val_dom_emo']}" + assert Path(fixed_audio_wav).exists() diff --git a/test/test_metrics/test_nisqa.py b/test/test_metrics/test_nisqa.py new file mode 100644 index 0000000..ecc9dab --- /dev/null +++ b/test/test_metrics/test_nisqa.py @@ -0,0 +1,308 @@ +#!/usr/bin/env python3 + +# Copyright 2024 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Unit tests for NISQA metric.""" + +import wave +from pathlib import Path +from unittest.mock import Mock, patch + +import numpy as np +import pytest +import torch + +from versa.utterance_metrics.nisqa import ( + NisqaMetric, + nisqa_metric, + nisqa_model_setup, +) + + +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- +def generate_fixed_wav( + filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None +): + """ + Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. + """ + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. + if envelope_func is None: + envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) + else: + envelope = envelope_func(t) + audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. + amplitude = np.iinfo(np.int16).max + data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. + with wave.open(str(filename), "w") as wf: + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. + wf.setframerate(sample_rate) + wf.writeframes(data.tobytes()) + + +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) + return audio_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=16000): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + +@pytest.fixture(scope="session") +def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_wav) + + +# ------------------------------- +# Mock NISQA Model Fixture +# ------------------------------- +@pytest.fixture +def mock_nisqa_model(): + """Create a mock NISQA model for testing.""" + model = Mock() + model.device = "cpu" + model.args = {"model": "NISQA"} + return model + + +# ------------------------------- +# Test NISQA Metric Class +# ------------------------------- +class TestNisqaMetric: + """Test the NisqaMetric class.""" + + def test_initialization_without_model_path(self): + """Test that initialization fails without model path.""" + config = {"use_gpu": False} + with pytest.raises(ValueError, match="NISQA model path must be provided"): + NisqaMetric(config) + + @patch("versa.utterance_metrics.nisqa.torch.load") + @patch("versa.utterance_metrics.nisqa.NL.NISQA") + def test_initialization_success(self, mock_nisqa_class, mock_torch_load): + """Test successful initialization of NisqaMetric.""" + # Mock the checkpoint + mock_checkpoint = { + "args": { + "model": "NISQA", + "ms_seg_length": 15, + "ms_n_mels": 48, + "cnn_model": "resnet", + "cnn_c_out_1": 32, + "cnn_c_out_2": 32, + "cnn_c_out_3": 32, + "cnn_kernel_size": 3, + "cnn_dropout": 0.1, + "cnn_pool_1": 2, + "cnn_pool_2": 2, + "cnn_pool_3": 2, + "cnn_fc_out_h": 128, + "td": "lstm", + "td_sa_d_model": 128, + "td_sa_nhead": 8, + "td_sa_pos_enc": "sin", + "td_sa_num_layers": 2, + "td_sa_h": 128, + "td_sa_dropout": 0.1, + "td_lstm_h": 128, + "td_lstm_num_layers": 2, + "td_lstm_dropout": 0.1, + "td_lstm_bidirectional": True, + "td_2": "lstm", + "td_2_sa_d_model": 128, + "td_2_sa_nhead": 8, + "td_2_sa_pos_enc": "sin", + "td_2_sa_num_layers": 2, + "td_2_sa_h": 128, + "td_2_sa_dropout": 0.1, + "td_2_lstm_h": 128, + "td_2_lstm_num_layers": 2, + "td_2_lstm_dropout": 0.1, + "td_2_lstm_bidirectional": True, + "pool": "att", + "pool_att_h": 128, + "pool_att_dropout": 0.1, + }, + "model_state_dict": {}, + } + mock_torch_load.return_value = mock_checkpoint + + # Mock the NISQA model + mock_model = Mock() + mock_model.load_state_dict.return_value = ([], []) # No missing/unexpected keys + mock_nisqa_class.return_value = mock_model + + config = { + "nisqa_model_path": "./tools/NISQA/weights/nisqa.tar", + "use_gpu": False, + } + + metric = NisqaMetric(config) + assert metric.model is not None + assert metric.model.device == "cpu" + + def test_compute_with_none_predictions(self): + """Test that compute raises error with None predictions.""" + config = { + "nisqa_model_path": "./tools/NISQA/weights/nisqa.tar", + "use_gpu": False, + } + metric = NisqaMetric(config) + + with pytest.raises(ValueError, match="Predicted signal must be provided"): + metric.compute(None) + + @patch("versa.utterance_metrics.nisqa.NL.versa_eval_mos") + def test_compute_success(self, mock_eval_mos, mock_nisqa_model): + """Test successful computation of NISQA scores.""" + # Mock the evaluation function + mock_eval_mos.return_value = { + "mos_pred": [[0.5]], + "noi_pred": [[1.0]], + "dis_pred": [[2.0]], + "col_pred": [[1.5]], + "loud_pred": [[1.2]], + } + + config = { + "nisqa_model_path": "./tools/NISQA/weights/nisqa.tar", + "use_gpu": False, + } + metric = NisqaMetric(config) + metric.model = mock_nisqa_model + + audio = np.random.random(16000) + metadata = {"sample_rate": 16000} + + result = metric.compute(audio, metadata=metadata) + + assert "nisqa_mos_pred" in result + assert "nisqa_noi_pred" in result + assert "nisqa_dis_pred" in result + assert "nisqa_col_pred" in result + assert "nisqa_loud_pred" in result + assert result["nisqa_mos_pred"] == 0.5 + + def test_get_metadata(self): + """Test that get_metadata returns correct metadata.""" + config = { + "nisqa_model_path": "./tools/NISQA/weights/nisqa.tar", + "use_gpu": False, + } + metric = NisqaMetric(config) + + metadata = metric.get_metadata() + assert metadata.name == "nisqa" + assert metadata.category.value == "independent" + assert metadata.metric_type.value == "float" + assert not metadata.requires_reference + assert not metadata.requires_text + assert metadata.gpu_compatible + + +# ------------------------------- +# Integration Tests +# ------------------------------- +@pytest.mark.integration +class TestNisqaIntegration: + """Integration tests for NISQA metric.""" + + @pytest.mark.parametrize( + "sample_rate,use_gpu", + [ + (16000, False), + (22050, False), + (48000, False), + ], + ) + def test_nisqa_with_different_sample_rates(self, sample_rate, use_gpu, fixed_audio): + """Test NISQA with different sample rates.""" + # Skip if NISQA dependencies are not available + try: + import versa.utterance_metrics.nisqa_utils.nisqa_lib + except ImportError: + pytest.skip("NISQA dependencies not available") + + # This test would require a real NISQA model file + # For now, we'll just test the basic structure + config = { + "use_gpu": use_gpu, + } + + # Test that the metric can be instantiated (without actual model loading) + with pytest.raises(ValueError, match="NISQA model path must be provided"): + NisqaMetric(config) + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist(fixed_audio_wav): + """ + Verify that the fixed WAV files were created. + """ + assert Path(fixed_audio_wav).exists() + + +# ------------------------------- +# Test Registration Function +# ------------------------------- +def test_register_nisqa_metric(): + """Test the registration function.""" + from versa.utterance_metrics.nisqa import register_nisqa_metric + + # Mock registry + mock_registry = Mock() + + # Register the metric + register_nisqa_metric(mock_registry) + + # Verify registration was called + mock_registry.register.assert_called_once() + + # Verify the call arguments + call_args = mock_registry.register.call_args + assert call_args[0][0] == NisqaMetric # First argument should be the class + assert call_args[0][1].name == "nisqa" # Second argument should be metadata diff --git a/test/test_metrics/test_nomad.py b/test/test_metrics/test_nomad.py new file mode 100644 index 0000000..2bdded3 --- /dev/null +++ b/test/test_metrics/test_nomad.py @@ -0,0 +1,362 @@ +#!/usr/bin/env python3 + +# Copyright 2024 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Unit tests for NOMAD metric.""" + +import wave +from pathlib import Path +from unittest.mock import Mock, patch + +import numpy as np +import pytest +import torch + +from versa.utterance_metrics.nomad import ( + NomadMetric, + is_nomad_available, +) + + +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- +def generate_fixed_wav( + filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None +): + """ + Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. + """ + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. + if envelope_func is None: + envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) + else: + envelope = envelope_func(t) + audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. + amplitude = np.iinfo(np.int16).max + data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. + with wave.open(str(filename), "w") as wf: + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. + wf.setframerate(sample_rate) + wf.writeframes(data.tobytes()) + + +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) + return audio_file + + +@pytest.fixture(scope="session") +def fixed_ground_truth_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as ground truth. + This one uses a different base frequency (e.g., 300 Hz) so that the test + intentionally simulates a mismatch. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + gt_file = tmp_dir / "fixed_ground_truth.wav" + # Generate a ground truth file with a 300 Hz sine wave. + generate_fixed_wav(gt_file, duration=1.0, sample_rate=16000, base_freq=300) + return gt_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=16000): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + +@pytest.fixture(scope="session") +def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_wav) + + +@pytest.fixture(scope="session") +def fixed_ground_truth(fixed_ground_truth_wav): + """ + Load the fixed ground truth file as a NumPy array. + """ + return load_wav_as_array(fixed_ground_truth_wav) + + +# ------------------------------- +# Mock NOMAD Model Fixture +# ------------------------------- +@pytest.fixture +def mock_nomad_model(): + """Create a mock NOMAD model for testing.""" + model = Mock() + model.predict.return_value = 0.5 # Mock prediction value + return model + + +# ------------------------------- +# Test NOMAD Metric Class +# ------------------------------- +class TestNomadMetric: + """Test the NomadMetric class.""" + + def test_initialization_without_nomad(self): + """Test that initialization fails without nomad dependency.""" + with patch("versa.utterance_metrics.nomad.NOMAD_AVAILABLE", False): + config = {"use_gpu": False, "model_cache": "test_cache"} + with pytest.raises(ImportError, match="nomad is not installed"): + NomadMetric(config) + + @patch("versa.utterance_metrics.nomad.Nomad") + def test_initialization_success(self, mock_nomad_class): + """Test successful initialization of NomadMetric.""" + # Mock the NOMAD class + mock_model = Mock() + mock_nomad_class.return_value = mock_model + + config = { + "use_gpu": False, + "model_cache": "test_cache", + } + + metric = NomadMetric(config) + assert metric.model is not None + mock_nomad_class.assert_called_once_with(device="cpu", cache_dir="test_cache") + + def test_compute_with_none_predictions(self): + """Test that compute raises error with None predictions.""" + with patch("versa.utterance_metrics.nomad.Nomad") as mock_nomad_class: + mock_model = Mock() + mock_nomad_class.return_value = mock_model + + config = {"use_gpu": False, "model_cache": "test_cache"} + metric = NomadMetric(config) + + with pytest.raises(ValueError, match="Predicted signal must be provided"): + metric.compute(None, np.random.random(16000)) + + def test_compute_with_none_references(self): + """Test that compute raises error with None references.""" + with patch("versa.utterance_metrics.nomad.Nomad") as mock_nomad_class: + mock_model = Mock() + mock_nomad_class.return_value = mock_model + + config = {"use_gpu": False, "model_cache": "test_cache"} + metric = NomadMetric(config) + + with pytest.raises(ValueError, match="Reference signal must be provided"): + metric.compute(np.random.random(16000), None) + + @patch("versa.utterance_metrics.nomad.librosa.resample") + def test_compute_success(self, mock_resample, mock_nomad_model): + """Test successful computation of NOMAD score.""" + # Mock the resample function + mock_resample.side_effect = lambda x, orig_sr, target_sr: x + + config = {"use_gpu": False, "model_cache": "test_cache"} + metric = NomadMetric(config) + metric.model = mock_nomad_model + + audio = np.random.random(16000) + gt_audio = np.random.random(16000) + metadata = {"sample_rate": 16000} + + result = metric.compute(audio, gt_audio, metadata=metadata) + + assert "nomad" in result + assert result["nomad"] == 0.5 + mock_nomad_model.predict.assert_called_once() + + @patch("versa.utterance_metrics.nomad.librosa.resample") + def test_compute_with_resampling(self, mock_resample, mock_nomad_model): + """Test computation with resampling.""" + # Mock the resample function + mock_resample.side_effect = lambda x, orig_sr, target_sr: x + + config = {"use_gpu": False, "model_cache": "test_cache"} + metric = NomadMetric(config) + metric.model = mock_nomad_model + + audio = np.random.random(8000) # Different sample rate + gt_audio = np.random.random(8000) + metadata = {"sample_rate": 8000} + + result = metric.compute(audio, gt_audio, metadata=metadata) + + assert "nomad" in result + # Verify resampling was called + assert mock_resample.call_count == 2 + + def test_get_metadata(self): + """Test that get_metadata returns correct metadata.""" + with patch("versa.utterance_metrics.nomad.Nomad") as mock_nomad_class: + mock_model = Mock() + mock_nomad_class.return_value = mock_model + + config = {"use_gpu": False, "model_cache": "test_cache"} + metric = NomadMetric(config) + + metadata = metric.get_metadata() + assert metadata.name == "nomad" + assert metadata.category.value == "dependent" + assert metadata.metric_type.value == "float" + assert metadata.requires_reference + assert not metadata.requires_text + assert metadata.gpu_compatible + + +# ------------------------------- +# Test Utility Functions +# ------------------------------- +class TestUtilityFunctions: + """Test utility functions.""" + + @patch("versa.utterance_metrics.nomad.NOMAD_AVAILABLE", True) + def test_is_nomad_available_true(self): + """Test is_nomad_available when NOMAD is available.""" + assert is_nomad_available() is True + + @patch("versa.utterance_metrics.nomad.NOMAD_AVAILABLE", False) + def test_is_nomad_available_false(self): + """Test is_nomad_available when NOMAD is not available.""" + assert is_nomad_available() is False + + +# ------------------------------- +# Integration Tests +# ------------------------------- +@pytest.mark.integration +class TestNomadIntegration: + """Integration tests for NOMAD metric.""" + + @pytest.mark.parametrize( + "sample_rate,use_gpu", + [ + (16000, False), + (22050, False), + (48000, False), + ], + ) + def test_nomad_with_different_sample_rates( + self, sample_rate, use_gpu, fixed_audio, fixed_ground_truth + ): + """Test NOMAD with different sample rates.""" + # Skip if NOMAD dependencies are not available + if not is_nomad_available(): + pytest.skip("NOMAD dependencies not available") + + # This test would require a real NOMAD model file + # For now, we'll just test the basic structure + config = { + "use_gpu": use_gpu, + "model_cache": "test_cache", + } + + # Test that the metric can be instantiated (without actual model loading) + with patch("versa.utterance_metrics.nomad.Nomad") as mock_nomad_class: + mock_model = Mock() + mock_nomad_class.return_value = mock_model + + metric = NomadMetric(config) + assert metric.model is not None + + +# ------------------------------- +# Example Test Function Using the Reused WAV Files +# ------------------------------- +@pytest.mark.parametrize( + "use_gpu,cache_dir", + [ + (False, "test_cache"), + (True, "test_cache"), + ], +) +def test_utterance_nomad(use_gpu, cache_dir, fixed_audio, fixed_ground_truth): + """ + Test the NOMAD metric using the fixed audio and ground truth. + The test uses deterministic data so that the result is always reproducible. + """ + with patch("versa.utterance_metrics.nomad.Nomad") as mock_nomad_class: + mock_model = Mock() + mock_model.predict.return_value = 0.5 + mock_nomad_class.return_value = mock_model + + # Use the new class-based API + config = {"use_gpu": use_gpu, "model_cache": cache_dir} + metric = NomadMetric(config) + metadata = {"sample_rate": 16000} + result = metric.compute(fixed_audio, fixed_ground_truth, metadata=metadata) + nomad_score = result["nomad"] + + # We expect the score to be 0.5 based on our mock + assert nomad_score == pytest.approx( + 0.5, rel=1e-3, abs=1e-6 + ), "value from nomad_score {} is mismatch from the defined one {}".format( + nomad_score, 0.5 + ) + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist(fixed_audio_wav, fixed_ground_truth_wav): + """ + Verify that the fixed WAV files were created. + """ + assert Path(fixed_audio_wav).exists() + assert Path(fixed_ground_truth_wav).exists() + + +# ------------------------------- +# Test Registration Function +# ------------------------------- +def test_register_nomad_metric(): + """Test the registration function.""" + from versa.utterance_metrics.nomad import register_nomad_metric + + # Mock registry + mock_registry = Mock() + + # Register the metric + register_nomad_metric(mock_registry) + + # Verify registration was called + mock_registry.register.assert_called_once() + + # Verify the call arguments + call_args = mock_registry.register.call_args + assert call_args[0][0] == NomadMetric # First argument should be the class + assert call_args[0][1].name == "nomad" # Second argument should be metadata diff --git a/test/test_metrics/test_noresqa.py b/test/test_metrics/test_noresqa.py new file mode 100644 index 0000000..959cf95 --- /dev/null +++ b/test/test_metrics/test_noresqa.py @@ -0,0 +1,333 @@ +import wave +from pathlib import Path + +import numpy as np +import pytest +import torch +from packaging.version import parse as V + +from versa.utterance_metrics.noresqa import NoresqaMetric, is_noresqa_available + + +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- +def generate_fixed_wav( + filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None +): + """ + Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. + """ + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. + if envelope_func is None: + envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) + else: + envelope = envelope_func(t) + audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. + amplitude = np.iinfo(np.int16).max + data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. + with wave.open(str(filename), "w") as wf: + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. + wf.setframerate(sample_rate) + wf.writeframes(data.tobytes()) + + +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) + return audio_file + + +@pytest.fixture(scope="session") +def fixed_ground_truth_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as ground truth. + This one uses a different base frequency (e.g., 300 Hz) so that the test + intentionally simulates a mismatch. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + gt_file = tmp_dir / "fixed_ground_truth.wav" + # Generate a ground truth file with a 300 Hz sine wave. + generate_fixed_wav(gt_file, duration=1.0, sample_rate=16000, base_freq=300) + return gt_file + + +@pytest.fixture(scope="session") +def fixed_audio_8k_wav(tmp_path_factory): + """ + Create a fixed WAV file with 8kHz sample rate to test resampling. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio_8k.wav" + # Generate an audio file with a 150 Hz sine wave at 8kHz. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=8000, base_freq=150) + return audio_file + + +@pytest.fixture(scope="session") +def fixed_ground_truth_8k_wav(tmp_path_factory): + """ + Create a fixed WAV file with 8kHz sample rate to test resampling. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + gt_file = tmp_dir / "fixed_ground_truth_8k.wav" + # Generate a ground truth file with a 300 Hz sine wave at 8kHz. + generate_fixed_wav(gt_file, duration=1.0, sample_rate=8000, base_freq=300) + return gt_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=16000): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + +@pytest.fixture(scope="session") +def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_wav) + + +@pytest.fixture(scope="session") +def fixed_ground_truth(fixed_ground_truth_wav): + """ + Load the fixed ground truth file as a NumPy array. + """ + return load_wav_as_array(fixed_ground_truth_wav) + + +@pytest.fixture(scope="session") +def fixed_audio_8k(fixed_audio_8k_wav): + """ + Load the fixed 8kHz audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_8k_wav, sample_rate=8000) + + +@pytest.fixture(scope="session") +def fixed_ground_truth_8k(fixed_ground_truth_8k_wav): + """ + Load the fixed 8kHz ground truth file as a NumPy array. + """ + return load_wav_as_array(fixed_ground_truth_8k_wav, sample_rate=8000) + + +# ------------------------------- +# Test Functions +# ------------------------------- +@pytest.mark.skipif( + not is_noresqa_available(), + reason="noresqa is not available", +) +@pytest.mark.parametrize( + "metric_type,model_tag,use_gpu", + [ + (1, "default", False), # NORESQA-MOS + (0, "default", False), # NORESQA-score + ], +) +def test_noresqa_metric_basic( + metric_type, model_tag, use_gpu, fixed_audio, fixed_ground_truth +): + """ + Test the NORESQA metric with basic configuration. + """ + config = { + "metric_type": metric_type, + "model_tag": model_tag, + "use_gpu": use_gpu, + "cache_dir": "test_cache/noresqa_model", + } + + metric = NoresqaMetric(config) + result = metric.compute( + fixed_audio, fixed_ground_truth, metadata={"sample_rate": 16000} + ) + + # Check that result contains noresqa_score field + if metric_type == 0: + assert "noresqa_score" in result + assert isinstance(result["noresqa_score"], (int, float, np.number)) + assert not np.isnan(result["noresqa_score"]) + assert not np.isinf(result["noresqa_score"]) + elif metric_type == 1: + assert "noresqa_mos" in result + assert isinstance(result["noresqa_mos"], (int, float, np.number)) + assert not np.isnan(result["noresqa_mos"]) + assert not np.isinf(result["noresqa_mos"]) + + +@pytest.mark.skipif( + not is_noresqa_available(), + reason="noresqa is not available", +) +def test_noresqa_metric_resampling(fixed_audio_8k, fixed_ground_truth_8k): + """ + Test the NORESQA metric with audio that needs resampling. + """ + config = { + "metric_type": 1, # NORESQA-MOS + "model_tag": "default", + "use_gpu": False, + "cache_dir": "test_cache/noresqa_model", + } + + metric = NoresqaMetric(config) + result = metric.compute( + fixed_audio_8k, fixed_ground_truth_8k, metadata={"sample_rate": 8000} + ) + + # Check that result contains noresqa_score field + assert "noresqa_mos" in result + assert isinstance(result["noresqa_mos"], (int, float, np.number)) + assert not np.isnan(result["noresqa_mos"]) + assert not np.isinf(result["noresqa_mos"]) + + +@pytest.mark.skipif( + not is_noresqa_available(), + reason="noresqa is not available", +) +def test_noresqa_metric_invalid_input(): + """ + Test the NORESQA metric with invalid input. + """ + config = { + "metric_type": 1, + "model_tag": "default", + "use_gpu": False, + "cache_dir": "test_cache/noresqa_model", + } + + metric = NoresqaMetric(config) + + # Test with None predictions + with pytest.raises(ValueError, match="Predicted signal must be provided"): + metric.compute(None, np.random.random(16000), metadata={"sample_rate": 16000}) + + # Test with None references + with pytest.raises(ValueError, match="Reference signal must be provided"): + metric.compute(np.random.random(16000), None, metadata={"sample_rate": 16000}) + + +@pytest.mark.skipif( + not is_noresqa_available(), + reason="noresqa is not available", +) +@pytest.mark.parametrize("metric_type", [0, 1]) +def test_noresqa_metric_metadata(metric_type): + """ + Test the NORESQA metric metadata. + """ + config = { + "metric_type": metric_type, + "model_tag": "default", + "use_gpu": False, + "cache_dir": "test_cache/noresqa_model", + } + + metric = NoresqaMetric(config) + metadata = metric.get_metadata() + + expected_name = "noresqa_mos" if metric_type == 1 else "noresqa_score" + assert metadata.name == expected_name + assert metadata.category.value == "dependent" + assert metadata.metric_type.value == "float" + assert metadata.requires_reference is True + assert metadata.requires_text is False + assert metadata.gpu_compatible is True + assert metadata.auto_install is False + assert "fairseq" in metadata.dependencies + assert "torch" in metadata.dependencies + assert "librosa" in metadata.dependencies + assert "numpy" in metadata.dependencies + + +def test_noresqa_metric_not_available(): + """ + Test the NORESQA metric when noresqa is not available. + """ + # This test should be skipped if noresqa is available + if is_noresqa_available(): + pytest.skip("noresqa is available, skipping this test") + + config = { + "metric_type": 1, + "model_tag": "default", + "use_gpu": False, + "cache_dir": "test_cache/noresqa_model", + } + + with pytest.raises(ImportError, match="noresqa is not installed"): + NoresqaMetric(config) + + +@pytest.mark.skipif( + not is_noresqa_available(), + reason="noresqa is not available", +) +def test_noresqa_metric_invalid_metric_type(): + """ + Test the NORESQA metric with invalid metric_type. + """ + config = { + "metric_type": 2, # Invalid metric type + "model_tag": "default", + "use_gpu": False, + "cache_dir": "test_cache/noresqa_model", + } + + with pytest.raises(RuntimeError, match="Invalid metric_type"): + NoresqaMetric(config) + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist( + fixed_audio_wav, + fixed_ground_truth_wav, + fixed_audio_8k_wav, + fixed_ground_truth_8k_wav, +): + """ + Verify that the fixed WAV files were created. + """ + assert Path(fixed_audio_wav).exists() + assert Path(fixed_ground_truth_wav).exists() + assert Path(fixed_audio_8k_wav).exists() + assert Path(fixed_ground_truth_8k_wav).exists() diff --git a/test/test_metrics/test_owsm_lid.py b/test/test_metrics/test_owsm_lid.py new file mode 100644 index 0000000..44c0c2c --- /dev/null +++ b/test/test_metrics/test_owsm_lid.py @@ -0,0 +1,240 @@ +import wave +from pathlib import Path + +import numpy as np +import pytest +import torch +from packaging.version import parse as V + +from versa.utterance_metrics.owsm_lid import OwsmLidMetric, is_espnet2_available + + +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- +def generate_fixed_wav( + filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None +): + """ + Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. + """ + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. + if envelope_func is None: + envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) + else: + envelope = envelope_func(t) + audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. + amplitude = np.iinfo(np.int16).max + data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. + with wave.open(str(filename), "w") as wf: + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. + wf.setframerate(sample_rate) + wf.writeframes(data.tobytes()) + + +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) + return audio_file + + +@pytest.fixture(scope="session") +def fixed_audio_8k_wav(tmp_path_factory): + """ + Create a fixed WAV file with 8kHz sample rate to test resampling. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio_8k.wav" + # Generate an audio file with a 150 Hz sine wave at 8kHz. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=8000, base_freq=150) + return audio_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=16000): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + +@pytest.fixture(scope="session") +def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_wav) + + +@pytest.fixture(scope="session") +def fixed_audio_8k(fixed_audio_8k_wav): + """ + Load the fixed 8kHz audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_8k_wav, sample_rate=8000) + + +# ------------------------------- +# Test Functions +# ------------------------------- +@pytest.mark.skipif( + not is_espnet2_available(), + reason="espnet2 is not available", +) +@pytest.mark.parametrize( + "model_tag,nbest,use_gpu", + [ + ("default", 3, False), + ("default", 5, False), + ("espnet/owsm_v3.1_ebf", 3, False), + ], +) +def test_owsm_lid_metric_basic(model_tag, nbest, use_gpu, fixed_audio): + """ + Test the OWSM LID metric with basic configuration. + """ + config = { + "model_tag": model_tag, + "nbest": nbest, + "use_gpu": use_gpu, + } + + metric = OwsmLidMetric(config) + result = metric.compute(fixed_audio, metadata={"sample_rate": 16000}) + + # Check that result contains language field + assert "language" in result + assert isinstance(result["language"][0][0], str) + assert len(result["language"]) > 0 + + +@pytest.mark.skipif( + not is_espnet2_available(), + reason="espnet2 is not available", +) +def test_owsm_lid_metric_resampling(fixed_audio_8k): + """ + Test the OWSM LID metric with audio that needs resampling. + """ + config = { + "model_tag": "default", + "nbest": 3, + "use_gpu": False, + } + + metric = OwsmLidMetric(config) + result = metric.compute(fixed_audio_8k, metadata={"sample_rate": 8000}) + + # Check that result contains language field + assert "language" in result + assert isinstance(result["language"][0][0], str) + assert len(result["language"]) > 0 + + +@pytest.mark.skipif( + not is_espnet2_available(), + reason="espnet2 is not available", +) +def test_owsm_lid_metric_invalid_input(): + """ + Test the OWSM LID metric with invalid input. + """ + config = { + "model_tag": "default", + "nbest": 3, + "use_gpu": False, + } + + metric = OwsmLidMetric(config) + + # Test with None input + with pytest.raises(ValueError, match="Predicted signal must be provided"): + metric.compute(None, metadata={"sample_rate": 16000}) + + +@pytest.mark.skipif( + not is_espnet2_available(), + reason="espnet2 is not available", +) +def test_owsm_lid_metric_metadata(): + """ + Test the OWSM LID metric metadata. + """ + config = { + "model_tag": "default", + "nbest": 3, + "use_gpu": False, + } + + metric = OwsmLidMetric(config) + metadata = metric.get_metadata() + + assert metadata.name == "lid" + assert metadata.category.value == "independent" + assert metadata.metric_type.value == "list" + assert metadata.requires_reference is False + assert metadata.requires_text is False + assert metadata.gpu_compatible is True + assert metadata.auto_install is False + assert "espnet2" in metadata.dependencies + assert "librosa" in metadata.dependencies + assert "numpy" in metadata.dependencies + + +def test_owsm_lid_metric_espnet2_not_available(): + """ + Test the OWSM LID metric when espnet2 is not available. + """ + # This test should be skipped if espnet2 is available + if is_espnet2_available(): + pytest.skip("espnet2 is available, skipping this test") + + config = { + "model_tag": "default", + "nbest": 3, + "use_gpu": False, + } + + with pytest.raises(ImportError, match="espnet2 is not properly installed"): + OwsmLidMetric(config) + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist(fixed_audio_wav, fixed_audio_8k_wav): + """ + Verify that the fixed WAV files were created. + """ + assert Path(fixed_audio_wav).exists() + assert Path(fixed_audio_8k_wav).exists() diff --git a/test/test_metrics/test_pam.py b/test/test_metrics/test_pam.py new file mode 100644 index 0000000..b4fcd5d --- /dev/null +++ b/test/test_metrics/test_pam.py @@ -0,0 +1,337 @@ +import wave +from pathlib import Path + +import numpy as np +import pytest +import torch +from packaging.version import parse as V + +from versa.utterance_metrics.pam import PamMetric, PAM, is_pam_available + + +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- +def generate_fixed_wav( + filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None +): + """ + Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. + """ + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. + if envelope_func is None: + envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) + else: + envelope = envelope_func(t) + audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. + amplitude = np.iinfo(np.int16).max + data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. + with wave.open(str(filename), "w") as wf: + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. + wf.setframerate(sample_rate) + wf.writeframes(data.tobytes()) + + +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) + return audio_file + + +@pytest.fixture(scope="session") +def fixed_audio_44k_wav(tmp_path_factory): + """ + Create a fixed WAV file with 44.1kHz sample rate to test resampling. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio_44k.wav" + # Generate an audio file with a 150 Hz sine wave at 44.1kHz. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=44100, base_freq=150) + return audio_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=16000): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + +@pytest.fixture(scope="session") +def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_wav) + + +@pytest.fixture(scope="session") +def fixed_audio_44k(fixed_audio_44k_wav): + """ + Load the fixed 44.1kHz audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_44k_wav, sample_rate=44100) + + +# ------------------------------- +# Test Functions +# ------------------------------- +@pytest.mark.skipif( + not is_pam_available(), + reason="PAM dependencies are not available", +) +@pytest.mark.parametrize( + "repro,use_gpu", + [ + (True, False), + (False, False), + ], +) +def test_pam_metric_basic(repro, use_gpu, fixed_audio): + """ + Test the PAM metric with basic configuration. + """ + config = { + "repro": repro, + "use_gpu": use_gpu, + "cache_dir": "test_cache/pam", + "text_model": "gpt2", + "text_len": 77, + "transformer_embed_dim": 768, + "audioenc_name": "HTSAT", + "out_emb": 768, + "sampling_rate": 44100, + "duration": 7, + "fmin": 50, + "fmax": 8000, + "n_fft": 1024, + "hop_size": 320, + "mel_bins": 64, + "window_size": 1024, + "d_proj": 1024, + "temperature": 0.003, + "num_classes": 527, + "batch_size": 1024, + "demo": False, + } + + metric = PamMetric(config) + result = metric.compute(fixed_audio, metadata={"sample_rate": 16000}) + + # Check that result contains pam_score field + assert "pam_score" in result + assert isinstance(result["pam_score"], (int, float, np.number)) + assert not np.isnan(result["pam_score"]) + assert not np.isinf(result["pam_score"]) + # PAM score should be between 0 and 1 + assert 0.0 <= result["pam_score"] <= 1.0 + + +@pytest.mark.skipif( + not is_pam_available(), + reason="PAM dependencies are not available", +) +def test_pam_metric_resampling(fixed_audio_44k): + """ + Test the PAM metric with audio that needs resampling. + """ + config = { + "repro": True, + "use_gpu": False, + "cache_dir": "test_cache/pam", + "text_model": "gpt2", + "text_len": 77, + "transformer_embed_dim": 768, + "audioenc_name": "HTSAT", + "out_emb": 768, + "sampling_rate": 44100, + "duration": 7, + "fmin": 50, + "fmax": 8000, + "n_fft": 1024, + "hop_size": 320, + "mel_bins": 64, + "window_size": 1024, + "d_proj": 1024, + "temperature": 0.003, + "num_classes": 527, + "batch_size": 1024, + "demo": False, + } + + metric = PamMetric(config) + result = metric.compute(fixed_audio_44k, metadata={"sample_rate": 44100}) + + # Check that result contains pam_score field + assert "pam_score" in result + assert isinstance(result["pam_score"], (int, float, np.number)) + assert not np.isnan(result["pam_score"]) + assert not np.isinf(result["pam_score"]) + # PAM score should be between 0 and 1 + assert 0.0 <= result["pam_score"] <= 1.0 + + +@pytest.mark.skipif( + not is_pam_available(), + reason="PAM dependencies are not available", +) +def test_pam_metric_invalid_input(): + """ + Test the PAM metric with invalid input. + """ + config = { + "repro": True, + "use_gpu": False, + "cache_dir": "test_cache/pam", + "text_model": "gpt2", + "text_len": 77, + "transformer_embed_dim": 768, + "audioenc_name": "HTSAT", + "out_emb": 768, + "sampling_rate": 44100, + "duration": 7, + "fmin": 50, + "fmax": 8000, + "n_fft": 1024, + "hop_size": 320, + "mel_bins": 64, + "window_size": 1024, + "d_proj": 1024, + "temperature": 0.003, + "num_classes": 527, + "batch_size": 1024, + "demo": False, + } + + metric = PamMetric(config) + + # Test with None input + with pytest.raises(ValueError, match="Predicted signal must be provided"): + metric.compute(None, metadata={"sample_rate": 16000}) + + +@pytest.mark.skipif( + not is_pam_available(), + reason="PAM dependencies are not available", +) +def test_pam_metric_metadata(): + """ + Test the PAM metric metadata. + """ + config = { + "repro": True, + "use_gpu": False, + "cache_dir": "test_cache/pam", + "text_model": "gpt2", + "text_len": 77, + "transformer_embed_dim": 768, + "audioenc_name": "HTSAT", + "out_emb": 768, + "sampling_rate": 44100, + "duration": 7, + "fmin": 50, + "fmax": 8000, + "n_fft": 1024, + "hop_size": 320, + "mel_bins": 64, + "window_size": 1024, + "d_proj": 1024, + "temperature": 0.003, + "num_classes": 527, + "batch_size": 1024, + "demo": False, + } + + metric = PamMetric(config) + metadata = metric.get_metadata() + + assert metadata.name == "pam" + assert metadata.category.value == "independent" + assert metadata.metric_type.value == "float" + assert metadata.requires_reference is False + assert metadata.requires_text is False + assert metadata.gpu_compatible is True + assert metadata.auto_install is False + assert "torch" in metadata.dependencies + assert "torchaudio" in metadata.dependencies + assert "transformers" in metadata.dependencies + assert "huggingface_hub" in metadata.dependencies + assert "numpy" in metadata.dependencies + + +def test_pam_metric_not_available(): + """ + Test the PAM metric when PAM dependencies are not available. + """ + # This test should be skipped if PAM is available + if is_pam_available(): + pytest.skip("PAM dependencies are available, skipping this test") + + config = { + "repro": True, + "use_gpu": False, + "cache_dir": "test_cache/pam", + "text_model": "gpt2", + "text_len": 77, + "transformer_embed_dim": 768, + "audioenc_name": "HTSAT", + "out_emb": 768, + "sampling_rate": 44100, + "duration": 7, + "fmin": 50, + "fmax": 8000, + "n_fft": 1024, + "hop_size": 320, + "mel_bins": 64, + "window_size": 1024, + "d_proj": 1024, + "temperature": 0.003, + "num_classes": 527, + "batch_size": 1024, + "demo": False, + } + + with pytest.raises(RuntimeError, match="Failed to initialize PAM model"): + PamMetric(config) + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist(fixed_audio_wav, fixed_audio_44k_wav): + """ + Verify that the fixed WAV files were created. + """ + assert Path(fixed_audio_wav).exists() + assert Path(fixed_audio_44k_wav).exists() diff --git a/test/test_metrics/test_pesq_score.py b/test/test_metrics/test_pesq_score.py new file mode 100644 index 0000000..7c28186 --- /dev/null +++ b/test/test_metrics/test_pesq_score.py @@ -0,0 +1,367 @@ +import wave +from pathlib import Path + +import numpy as np +import pytest +import torch +from packaging.version import parse as V + +from versa.utterance_metrics.pesq_score import PesqMetric, is_pesq_available + + +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- +def generate_fixed_wav( + filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None +): + """ + Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. + """ + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. + if envelope_func is None: + envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) + else: + envelope = envelope_func(t) + audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. + amplitude = np.iinfo(np.int16).max + data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. + with wave.open(str(filename), "w") as wf: + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. + wf.setframerate(sample_rate) + wf.writeframes(data.tobytes()) + + +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) + return audio_file + + +@pytest.fixture(scope="session") +def fixed_ground_truth_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as ground truth. + This one uses a different base frequency (e.g., 300 Hz) so that the test + intentionally simulates a mismatch. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + gt_file = tmp_dir / "fixed_ground_truth.wav" + # Generate a ground truth file with a 300 Hz sine wave. + generate_fixed_wav(gt_file, duration=1.0, sample_rate=16000, base_freq=300) + return gt_file + + +@pytest.fixture(scope="session") +def fixed_audio_8k_wav(tmp_path_factory): + """ + Create a fixed WAV file with 8kHz sample rate to test resampling. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio_8k.wav" + # Generate an audio file with a 150 Hz sine wave at 8kHz. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=8000, base_freq=150) + return audio_file + + +@pytest.fixture(scope="session") +def fixed_ground_truth_8k_wav(tmp_path_factory): + """ + Create a fixed WAV file with 8kHz sample rate to test resampling. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + gt_file = tmp_dir / "fixed_ground_truth_8k.wav" + # Generate a ground truth file with a 300 Hz sine wave at 8kHz. + generate_fixed_wav(gt_file, duration=1.0, sample_rate=8000, base_freq=300) + return gt_file + + +@pytest.fixture(scope="session") +def fixed_audio_22k_wav(tmp_path_factory): + """ + Create a fixed WAV file with 22.05kHz sample rate to test resampling. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio_22k.wav" + # Generate an audio file with a 150 Hz sine wave at 22.05kHz. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=22050, base_freq=150) + return audio_file + + +@pytest.fixture(scope="session") +def fixed_ground_truth_22k_wav(tmp_path_factory): + """ + Create a fixed WAV file with 22.05kHz sample rate to test resampling. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + gt_file = tmp_dir / "fixed_ground_truth_22k.wav" + # Generate a ground truth file with a 300 Hz sine wave at 22.05kHz. + generate_fixed_wav(gt_file, duration=1.0, sample_rate=22050, base_freq=300) + return gt_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=16000): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + +@pytest.fixture(scope="session") +def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_wav) + + +@pytest.fixture(scope="session") +def fixed_ground_truth(fixed_ground_truth_wav): + """ + Load the fixed ground truth file as a NumPy array. + """ + return load_wav_as_array(fixed_ground_truth_wav) + + +@pytest.fixture(scope="session") +def fixed_audio_8k(fixed_audio_8k_wav): + """ + Load the fixed 8kHz audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_8k_wav, sample_rate=8000) + + +@pytest.fixture(scope="session") +def fixed_ground_truth_8k(fixed_ground_truth_8k_wav): + """ + Load the fixed 8kHz ground truth file as a NumPy array. + """ + return load_wav_as_array(fixed_ground_truth_8k_wav, sample_rate=8000) + + +@pytest.fixture(scope="session") +def fixed_audio_22k(fixed_audio_22k_wav): + """ + Load the fixed 22.05kHz audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_22k_wav, sample_rate=22050) + + +@pytest.fixture(scope="session") +def fixed_ground_truth_22k(fixed_ground_truth_22k_wav): + """ + Load the fixed 22.05kHz ground truth file as a NumPy array. + """ + return load_wav_as_array(fixed_ground_truth_22k_wav, sample_rate=22050) + + +# ------------------------------- +# Test Functions +# ------------------------------- +@pytest.mark.skipif( + not is_pesq_available(), + reason="pesq is not available", +) +@pytest.mark.parametrize( + "sample_rate", + [8000, 16000], +) +def test_pesq_metric_basic(sample_rate, fixed_audio, fixed_ground_truth): + """ + Test the PESQ metric with basic configuration. + """ + config = {} + + metric = PesqMetric(config) + result = metric.compute( + fixed_audio, fixed_ground_truth, metadata={"sample_rate": sample_rate} + ) + + # Check that result contains pesq field + assert "pesq" in result + assert isinstance(result["pesq"], (int, float, np.number)) + assert not np.isnan(result["pesq"]) + # PESQ score should be between -0.5 and 4.5 + assert -0.5 <= result["pesq"] <= 4.5 + + +@pytest.mark.skipif( + not is_pesq_available(), + reason="pesq is not available", +) +def test_pesq_metric_8k_resampling(fixed_audio_8k, fixed_ground_truth_8k): + """ + Test the PESQ metric with 8kHz audio that needs resampling. + """ + config = {} + + metric = PesqMetric(config) + result = metric.compute( + fixed_audio_8k, fixed_ground_truth_8k, metadata={"sample_rate": 8000} + ) + + # Check that result contains pesq field + assert "pesq" in result + assert isinstance(result["pesq"], (int, float, np.number)) + assert not np.isnan(result["pesq"]) + # PESQ score should be between -0.5 and 4.5 + assert -0.5 <= result["pesq"] <= 4.5 + + +@pytest.mark.skipif( + not is_pesq_available(), + reason="pesq is not available", +) +def test_pesq_metric_22k_resampling(fixed_audio_22k, fixed_ground_truth_22k): + """ + Test the PESQ metric with 22.05kHz audio that needs resampling. + """ + config = {} + + metric = PesqMetric(config) + result = metric.compute( + fixed_audio_22k, fixed_ground_truth_22k, metadata={"sample_rate": 22050} + ) + + # Check that result contains pesq field + assert "pesq" in result + assert isinstance(result["pesq"], (int, float, np.number)) + assert not np.isnan(result["pesq"]) + # PESQ score should be between -0.5 and 4.5 + assert -0.5 <= result["pesq"] <= 4.5 + + +@pytest.mark.skipif( + not is_pesq_available(), + reason="pesq is not available", +) +def test_pesq_metric_invalid_input(): + """ + Test the PESQ metric with invalid input. + """ + config = {} + + metric = PesqMetric(config) + + # Test with None predictions + with pytest.raises(ValueError, match="Predicted signal must be provided"): + metric.compute(None, np.random.random(16000), metadata={"sample_rate": 16000}) + + # Test with None references + with pytest.raises(ValueError, match="Reference signal must be provided"): + metric.compute(np.random.random(16000), None, metadata={"sample_rate": 16000}) + + +@pytest.mark.skipif( + not is_pesq_available(), + reason="pesq is not available", +) +def test_pesq_metric_metadata(): + """ + Test the PESQ metric metadata. + """ + config = {} + + metric = PesqMetric(config) + metadata = metric.get_metadata() + + assert metadata.name == "pesq" + assert metadata.category.value == "dependent" + assert metadata.metric_type.value == "float" + assert metadata.requires_reference is True + assert metadata.requires_text is False + assert metadata.gpu_compatible is False + assert metadata.auto_install is False + assert "pesq" in metadata.dependencies + assert "librosa" in metadata.dependencies + assert "numpy" in metadata.dependencies + + +def test_pesq_metric_not_available(): + """ + Test the PESQ metric when pesq is not available. + """ + # This test should be skipped if pesq is available + if is_pesq_available(): + pytest.skip("pesq is available, skipping this test") + + config = {} + + with pytest.raises(ImportError, match="pesq is not properly installed"): + PesqMetric(config) + + +@pytest.mark.skipif( + not is_pesq_available(), + reason="pesq is not available", +) +def test_pesq_metric_same_audio(): + """ + Test the PESQ metric with identical audio (should give high score). + """ + config = {} + + metric = PesqMetric(config) + # Use the same audio for both prediction and reference + audio = np.random.random(16000) + result = metric.compute(audio, audio, metadata={"sample_rate": 16000}) + + # Check that result contains pesq field + assert "pesq" in result + assert isinstance(result["pesq"], (int, float, np.number)) + assert not np.isnan(result["pesq"]) + # PESQ score should be between -0.5 and 5 + assert -0.5 <= result["pesq"] <= 5 + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist( + fixed_audio_wav, + fixed_ground_truth_wav, + fixed_audio_8k_wav, + fixed_ground_truth_8k_wav, + fixed_audio_22k_wav, + fixed_ground_truth_22k_wav, +): + """ + Verify that the fixed WAV files were created. + """ + assert Path(fixed_audio_wav).exists() + assert Path(fixed_ground_truth_wav).exists() + assert Path(fixed_audio_8k_wav).exists() + assert Path(fixed_ground_truth_8k_wav).exists() + assert Path(fixed_audio_22k_wav).exists() + assert Path(fixed_ground_truth_22k_wav).exists() diff --git a/test/test_pipeline/test_asr_match.py b/test/test_pipeline/test_asr_match.py index 2388eae..19ac420 100755 --- a/test/test_pipeline/test_asr_match.py +++ b/test/test_pipeline/test_asr_match.py @@ -4,25 +4,21 @@ import yaml -from versa.scorer_shared import ( - find_files, - list_scoring, - load_score_modules, - load_summary, -) +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.asr_matching import register_asr_match_metric -TEST_INFO = { - "asr_match_error_rate": 0.0, -} +TEST_INFO = {"asr_match_error_rate": 0.0} def info_update(): - # find files if os.path.isdir("test/test_samples/test2"): gen_files = find_files("test/test_samples/test2") # find reference file + gt_files = None if os.path.isdir("test/test_samples/test1"): gt_files = find_files("test/test_samples/test1") @@ -31,7 +27,15 @@ def info_update(): with open("egs/separate_metrics/asr_match.yaml", "r", encoding="utf-8") as f: score_config = yaml.full_load(f) - score_modules = load_score_modules( + # Create registry and register ASR-Match metric + registry = MetricRegistry() + register_asr_match_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( score_config, use_gt=(True if gt_files is not None else False), use_gpu=False, @@ -39,17 +43,17 @@ def info_update(): assert len(score_config) > 0, "no scoring function is provided" - score_info = list_scoring( - gen_files, score_modules, gt_files, output_file=None, io="soundfile" + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files, output_file=None, io="soundfile" ) - summary = load_summary(score_info) - print("Summary: {}".format(load_summary(score_info)), flush=True) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) for key in summary: if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): - # for sir" continue - # the plc mos is undeterministic if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != "plcmos": raise ValueError( "Value issue in the test case, might be some issue in scorer {}".format( diff --git a/test/test_pipeline/test_asvspoof.py b/test/test_pipeline/test_asvspoof.py index 2f76b39..e5ed31f 100755 --- a/test/test_pipeline/test_asvspoof.py +++ b/test/test_pipeline/test_asvspoof.py @@ -4,12 +4,10 @@ import yaml -from versa.scorer_shared import ( - find_files, - list_scoring, - load_score_modules, - load_summary, -) +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.asvspoof_score import register_asvspoof_metric TEST_INFO = { "asvspoof_score": 8.472739e-08, @@ -23,6 +21,7 @@ def info_update(): gen_files = find_files("test/test_samples/test2") # find reference file + gt_files = None if os.path.isdir("test/test_samples/test1"): gt_files = find_files("test/test_samples/test1") @@ -31,7 +30,15 @@ def info_update(): with open("egs/separate_metrics/asvspoof.yaml", "r", encoding="utf-8") as f: score_config = yaml.full_load(f) - score_modules = load_score_modules( + # Create registry and register ASVspoof metric + registry = MetricRegistry() + register_asvspoof_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( score_config, use_gt=(True if gt_files is not None else False), use_gpu=False, @@ -39,12 +46,13 @@ def info_update(): assert len(score_config) > 0, "no scoring function is provided" - score_info = list_scoring( - gen_files, score_modules, gt_files, output_file=None, io="soundfile" + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files, output_file=None, io="soundfile" ) - summary = load_summary(score_info) - print("Summary: {}".format(load_summary(score_info)), flush=True) + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) for key in summary: if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): diff --git a/test/test_pipeline/test_audiobox_aesthetics.py b/test/test_pipeline/test_audiobox_aesthetics.py index a1c8f45..6a06d63 100755 --- a/test/test_pipeline/test_audiobox_aesthetics.py +++ b/test/test_pipeline/test_audiobox_aesthetics.py @@ -4,11 +4,11 @@ import yaml -from versa.scorer_shared import ( - find_files, - list_scoring, - load_score_modules, - load_summary, +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.audiobox_aesthetics_score import ( + register_audiobox_aesthetics_metric, ) TEST_INFO = { @@ -32,7 +32,15 @@ def info_update(): ) as f: score_config = yaml.full_load(f) - score_modules = load_score_modules( + # Create registry and register AudioBox Aesthetics metric + registry = MetricRegistry() + register_audiobox_aesthetics_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( score_config, use_gt=False, use_gpu=False, @@ -40,11 +48,13 @@ def info_update(): assert len(score_config) > 0, "no scoring function is provided" - score_info = list_scoring( - gen_files, score_modules, gt_files=None, output_file=None, io="soundfile" + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files=None, output_file=None, io="soundfile" ) - summary = load_summary(score_info) - print("Summary: {}".format(load_summary(score_info)), flush=True) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) for key in summary: # the plc mos is undeterministic diff --git a/test/test_pipeline/test_cdpam_distance.py b/test/test_pipeline/test_cdpam_distance.py new file mode 100644 index 0000000..2151c22 --- /dev/null +++ b/test/test_pipeline/test_cdpam_distance.py @@ -0,0 +1,71 @@ +import logging +import math +import os + +import yaml + +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.cdpam_distance import register_cdpam_distance_metric + +TEST_INFO = { + "cdpam_distance": 0.051460444927215576, +} + + +def info_update(): + + # find files + if os.path.isdir("test/test_samples/test2"): + gen_files = find_files("test/test_samples/test2") + + # find reference file + if os.path.isdir("test/test_samples/test1"): + gt_files = find_files("test/test_samples/test1") + + logging.info("The number of utterances = %d" % len(gen_files)) + + with open("egs/separate_metrics/cdpam_distance.yaml", "r", encoding="utf-8") as f: + score_config = yaml.full_load(f) + + # Create registry and register CDPAM distance metric + registry = MetricRegistry() + register_cdpam_distance_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( + score_config, + use_gt=(True if gt_files is not None else False), + use_gpu=False, + ) + + assert len(score_config) > 0, "no scoring function is provided" + + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files=gt_files, output_file=None, io="soundfile" + ) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) + + for key in summary: + if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): + # for sir" + continue + # the plc mos is undeterministic + if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != "plcmos": + raise ValueError( + "Value issue in the test case, might be some issue in scorer {}".format( + key + ) + ) + print("check successful", flush=True) + + +if __name__ == "__main__": + info_update() diff --git a/test/test_pipeline/test_chroma_alignment.py b/test/test_pipeline/test_chroma_alignment.py new file mode 100644 index 0000000..2ee3743 --- /dev/null +++ b/test/test_pipeline/test_chroma_alignment.py @@ -0,0 +1,79 @@ +import logging +import math +import os + +import yaml + +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.chroma_alignment import register_chroma_alignment_metric + +TEST_INFO = { + "chroma_stft_cosine_dtw": 0.8895886718828439, + "chroma_stft_euclidean_dtw": 45.091055545199, + "chroma_cqt_cosine_dtw": 1.1888872845493323, + "chroma_cqt_euclidean_dtw": 56.16051355647546, + "chroma_cens_cosine_dtw": 0.6962623125421354, + "chroma_cens_euclidean_dtw": 38.38994047138499, + "chroma_stft_cosine_dtw_raw": 8.895886718828438, + "chroma_stft_cosine_dtw_log": 20.508107511971517, +} + + +def info_update(): + + # find files + if os.path.isdir("test/test_samples/test2"): + gen_files = find_files("test/test_samples/test2") + + # find reference file + gt_files = None + if os.path.isdir("test/test_samples/test1"): + gt_files = find_files("test/test_samples/test1") + + logging.info("The number of utterances = %d" % len(gen_files)) + + with open("egs/separate_metrics/chroma_alignment.yaml", "r", encoding="utf-8") as f: + score_config = yaml.full_load(f) + + # Create registry and register Chroma Alignment metric + registry = MetricRegistry() + register_chroma_alignment_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( + score_config, + use_gt=(True if gt_files is not None else False), + use_gpu=False, + ) + + assert len(score_config) > 0, "no scoring function is provided" + + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files, output_file=None, io="soundfile" + ) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) + + for key in summary: + if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): + # for sir" + continue + # the plc mos is undeterministic + if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != "plcmos": + raise ValueError( + "Value issue in the test case, might be some issue in scorer {}".format( + key + ) + ) + print("check successful", flush=True) + + +if __name__ == "__main__": + info_update() diff --git a/test/test_pipeline/test_discrete_speech.py b/test/test_pipeline/test_discrete_speech.py new file mode 100644 index 0000000..9d4357e --- /dev/null +++ b/test/test_pipeline/test_discrete_speech.py @@ -0,0 +1,74 @@ +import logging +import math +import os + +import yaml + +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.discrete_speech import register_discrete_speech_metric + +TEST_INFO = { + "speech_bert": 0.9727544784545898, + "speech_bleu": 0.6699938983346256, + "speech_token_distance": 0.850506056080969, +} + + +def info_update(): + + # find files + if os.path.isdir("test/test_samples/test2"): + gen_files = find_files("test/test_samples/test2") + + # find reference file + gt_files = None + if os.path.isdir("test/test_samples/test1"): + gt_files = find_files("test/test_samples/test1") + + logging.info("The number of utterances = %d" % len(gen_files)) + + with open("egs/separate_metrics/discrete_speech.yaml", "r", encoding="utf-8") as f: + score_config = yaml.full_load(f) + + # Create registry and register Discrete Speech metric + registry = MetricRegistry() + register_discrete_speech_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( + score_config, + use_gt=(True if gt_files is not None else False), + use_gpu=False, + ) + + assert len(score_config) > 0, "no scoring function is provided" + + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files, output_file=None, io="soundfile" + ) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) + + for key in summary: + if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): + # for sir" + continue + # the plc mos is undeterministic + if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != "plcmos": + raise ValueError( + "Value issue in the test case, might be some issue in scorer {}".format( + key + ) + ) + print("check successful", flush=True) + + +if __name__ == "__main__": + info_update() diff --git a/test/test_pipeline/test_dpam_distance.py b/test/test_pipeline/test_dpam_distance.py new file mode 100644 index 0000000..9749cdc --- /dev/null +++ b/test/test_pipeline/test_dpam_distance.py @@ -0,0 +1,71 @@ +import logging +import math +import os + +import yaml + +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.dpam_distance import register_dpam_distance_metric + +TEST_INFO = { + "dpam_distance": 0.1500423550605774, +} + + +def info_update(): + + # find files + if os.path.isdir("test/test_samples/test2"): + gen_files = find_files("test/test_samples/test2") + + # find reference file + if os.path.isdir("test/test_samples/test1"): + gt_files = find_files("test/test_samples/test1") + + logging.info("The number of utterances = %d" % len(gen_files)) + + with open("egs/separate_metrics/dpam_distance.yaml", "r", encoding="utf-8") as f: + score_config = yaml.full_load(f) + + # Create registry and register DPAM distance metric + registry = MetricRegistry() + register_dpam_distance_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( + score_config, + use_gt=(True if gt_files is not None else False), + use_gpu=False, + ) + + assert len(score_config) > 0, "no scoring function is provided" + + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files=gt_files, output_file=None, io="soundfile" + ) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) + + for key in summary: + if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): + # for sir" + continue + # the plc mos is undeterministic + if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != "plcmos": + raise ValueError( + "Value issue in the test case, might be some issue in scorer {}".format( + key + ) + ) + print("check successful", flush=True) + + +if __name__ == "__main__": + info_update() diff --git a/test/test_pipeline/test_emo_similarity.py b/test/test_pipeline/test_emo_similarity.py index afa5cf5..fac8ba8 100755 --- a/test/test_pipeline/test_emo_similarity.py +++ b/test/test_pipeline/test_emo_similarity.py @@ -4,12 +4,10 @@ import yaml -from versa.scorer_shared import ( - find_files, - list_scoring, - load_score_modules, - load_summary, -) +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.emo_similarity import register_emo2vec_metric TEST_INFO = { "emotion_similarity": 0.9984976053237915, @@ -31,7 +29,15 @@ def info_update(): with open("egs/separate_metrics/emo_similarity.yaml", "r", encoding="utf-8") as f: score_config = yaml.full_load(f) - score_modules = load_score_modules( + # Create registry and register Emotion metric + registry = MetricRegistry() + register_emo2vec_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( score_config, use_gt=(True if gt_files is not None else False), use_gpu=False, @@ -39,11 +45,13 @@ def info_update(): assert len(score_config) > 0, "no scoring function is provided" - score_info = list_scoring( - gen_files, score_modules, gt_files, output_file=None, io="soundfile" + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files=gt_files, output_file=None, io="soundfile" ) - summary = load_summary(score_info) - print("Summary: {}".format(load_summary(score_info)), flush=True) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) for key in summary: if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): diff --git a/test/test_pipeline/test_emo_vad.py b/test/test_pipeline/test_emo_vad.py new file mode 100644 index 0000000..45a4641 --- /dev/null +++ b/test/test_pipeline/test_emo_vad.py @@ -0,0 +1,69 @@ +import logging +import math +import os + +import yaml + +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.emo_vad import register_emo_vad_metric + +TEST_INFO = { + "arousal_emo_vad": 0.663333535194397, + "valence_emo_vad": 0.5060539245605469, + "dominance_emo_vad": 0.6355133056640625, +} + + +def info_update(): + + # find files + if os.path.isdir("test/test_samples/test2"): + gen_files = find_files("test/test_samples/test2") + + logging.info("The number of utterances = %d" % len(gen_files)) + + with open("egs/separate_metrics/emo_vad.yaml", "r", encoding="utf-8") as f: + score_config = yaml.full_load(f) + + # Create registry and register EmoVad metric + registry = MetricRegistry() + register_emo_vad_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( + score_config, + use_gt=False, + use_gpu=False, + ) + + assert len(score_config) > 0, "no scoring function is provided" + + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files=None, output_file=None, io="soundfile" + ) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) + + for key in summary: + if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): + # for sir" + continue + # the plc mos is undeterministic + if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != "plcmos": + raise ValueError( + "Value issue in the test case, might be some issue in scorer {}".format( + key + ) + ) + print("check successful", flush=True) + + +if __name__ == "__main__": + info_update() diff --git a/test/test_pipeline/test_lid.py b/test/test_pipeline/test_lid.py deleted file mode 100755 index 83ec124..0000000 --- a/test/test_pipeline/test_lid.py +++ /dev/null @@ -1,52 +0,0 @@ -import logging -import math -import os - -import yaml - -from versa.scorer_shared import ( - find_files, - list_scoring, - load_score_modules, - load_summary, -) - - -def info_update(): - - # find files - if os.path.isdir("test/test_samples/test2"): - gen_files = find_files("test/test_samples/test2") - - # find reference file - if os.path.isdir("test/test_samples/test1"): - gt_files = find_files("test/test_samples/test1") - - logging.info("The number of utterances = %d" % len(gen_files)) - - with open("egs/separate_metrics/lid.yaml", "r", encoding="utf-8") as f: - score_config = yaml.full_load(f) - - score_modules = load_score_modules( - score_config, - use_gt=(True if gt_files is not None else False), - use_gpu=False, - ) - - assert len(score_config) > 0, "no scoring function is provided" - - score_info = list_scoring( - gen_files, score_modules, gt_files, output_file=None, io="soundfile" - ) - print("Summary: {}".format((score_info), flush=True)) - - if abs(score_info[0]["language"][0][1] - 0.8865218162536621) > 1e-4: - raise ValueError( - "Value issue in the test case, might be some issue in scorer lanugage" - ) - - print("check successful", flush=True) - - -if __name__ == "__main__": - info_update() diff --git a/test/test_pipeline/test_nisqa.py b/test/test_pipeline/test_nisqa.py index 8e212d4..6964a40 100755 --- a/test/test_pipeline/test_nisqa.py +++ b/test/test_pipeline/test_nisqa.py @@ -1,15 +1,20 @@ +#!/usr/bin/env python3 + +# Copyright 2024 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Test pipeline for NISQA metric using the VersaScorer API.""" + import logging import math import os import yaml -from versa.scorer_shared import ( - find_files, - list_scoring, - load_score_modules, - load_summary, -) +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.nisqa import register_nisqa_metric TEST_INFO = { "nisqa_mos_pred": 0.4359583258628845, @@ -21,12 +26,12 @@ def info_update(): - # find files if os.path.isdir("test/test_samples/test2"): gen_files = find_files("test/test_samples/test2") # find reference file + gt_files = None if os.path.isdir("test/test_samples/test1"): gt_files = find_files("test/test_samples/test1") @@ -35,7 +40,15 @@ def info_update(): with open("egs/separate_metrics/nisqa.yaml", "r", encoding="utf-8") as f: score_config = yaml.full_load(f) - score_modules = load_score_modules( + # Create registry and register NISQA metric + registry = MetricRegistry() + register_nisqa_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( score_config, use_gt=(True if gt_files is not None else False), use_gpu=False, @@ -43,14 +56,18 @@ def info_update(): assert len(score_config) > 0, "no scoring function is provided" - score_info = list_scoring( - gen_files, score_modules, gt_files, output_file=None, io="soundfile" + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files, output_file=None, io="soundfile" ) - summary = load_summary(score_info) - print("Summary: {}".format(load_summary(score_info)), flush=True) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) for key in summary: - if abs(TEST_INFO[key] - summary[key]) > 1e-4: + if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): + continue + if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != "plcmos": raise ValueError( "Value issue in the test case, might be some issue in scorer {}".format( key diff --git a/test/test_pipeline/test_nomad.py b/test/test_pipeline/test_nomad.py index 20e741c..838d859 100755 --- a/test/test_pipeline/test_nomad.py +++ b/test/test_pipeline/test_nomad.py @@ -1,26 +1,31 @@ +#!/usr/bin/env python3 + +# Copyright 2024 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Test pipeline for NOMAD metric using the VersaScorer API.""" + import logging import math import os import yaml -from versa.scorer_shared import ( - find_files, - list_scoring, - load_score_modules, - load_summary, -) +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.nomad import register_nomad_metric TEST_INFO = {"nomad": 0.0336} def info_update(): - # find files if os.path.isdir("test/test_samples/test2"): gen_files = find_files("test/test_samples/test2") # find reference file + gt_files = None if os.path.isdir("test/test_samples/test1"): gt_files = find_files("test/test_samples/test1") @@ -29,7 +34,15 @@ def info_update(): with open("egs/separate_metrics/nomad.yaml", "r", encoding="utf-8") as f: score_config = yaml.full_load(f) - score_modules = load_score_modules( + # Create registry and register NOMAD metric + registry = MetricRegistry() + register_nomad_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( score_config, use_gt=(True if gt_files is not None else False), use_gpu=False, @@ -37,17 +50,17 @@ def info_update(): assert len(score_config) > 0, "no scoring function is provided" - score_info = list_scoring( - gen_files, score_modules, gt_files, output_file=None, io="soundfile" + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files, output_file=None, io="soundfile" ) - summary = load_summary(score_info) - print("Summary: {}".format(load_summary(score_info)), flush=True) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) for key in summary: if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): - # for sir" continue - # the plc mos is undeterministic if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != "plcmos": raise ValueError( "Value issue in the test case, might be some issue in scorer {}".format( diff --git a/test/test_pipeline/test_noresqa.py b/test/test_pipeline/test_noresqa.py index a3ebb37..8b78719 100755 --- a/test/test_pipeline/test_noresqa.py +++ b/test/test_pipeline/test_noresqa.py @@ -1,28 +1,31 @@ +#!/usr/bin/env python3 + +# Copyright 2024 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Test pipeline for NORESQA metric using the VersaScorer API.""" + import logging import math import os import yaml -from versa.scorer_shared import ( - find_files, - list_scoring, - load_score_modules, - load_summary, -) +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.noresqa import register_noresqa_metric -TEST_INFO = { - "noresqa": 12.010879979211092, # need to be updated -} +TEST_INFO = {"noresqa_mos": 1.051746129989624} def info_update(): - # find files if os.path.isdir("test/test_samples/test2"): gen_files = find_files("test/test_samples/test2") # find reference file + gt_files = None if os.path.isdir("test/test_samples/test1"): gt_files = find_files("test/test_samples/test1") @@ -31,7 +34,15 @@ def info_update(): with open("egs/separate_metrics/noresqa.yaml", "r", encoding="utf-8") as f: score_config = yaml.full_load(f) - score_modules = load_score_modules( + # Create registry and register NORESQA metric + registry = MetricRegistry() + register_noresqa_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( score_config, use_gt=(True if gt_files is not None else False), use_gpu=False, @@ -39,17 +50,17 @@ def info_update(): assert len(score_config) > 0, "no scoring function is provided" - score_info = list_scoring( - gen_files, score_modules, gt_files, output_file=None, io="soundfile" + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files, output_file=None, io="soundfile" ) - summary = load_summary(score_info) - print("Summary: {}".format(load_summary(score_info)), flush=True) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) for key in summary: if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): - # for sir" continue - # the plc mos is undeterministic if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != "plcmos": raise ValueError( "Value issue in the test case, might be some issue in scorer {}".format( diff --git a/test/test_pipeline/test_owsm_lid.py b/test/test_pipeline/test_owsm_lid.py new file mode 100755 index 0000000..0588099 --- /dev/null +++ b/test/test_pipeline/test_owsm_lid.py @@ -0,0 +1,64 @@ +import logging +import math +import os + +import yaml + +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.owsm_lid import register_owsm_lid_metric + +TEST_INFO = {"language": 0.8865218162536621} + + +def info_update(): + # find files + if os.path.isdir("test/test_samples/test2"): + gen_files = find_files("test/test_samples/test2") + + # find reference file + gt_files = None + if os.path.isdir("test/test_samples/test1"): + gt_files = find_files("test/test_samples/test1") + + logging.info("The number of utterances = %d" % len(gen_files)) + + with open("egs/separate_metrics/lid.yaml", "r", encoding="utf-8") as f: + score_config = yaml.full_load(f) + + # Create registry and register OWSM LID metric + registry = MetricRegistry() + register_owsm_lid_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( + score_config, + use_gt=(True if gt_files is not None else False), + use_gpu=False, + ) + + assert len(score_config) > 0, "no scoring function is provided" + + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files, output_file=None, io="soundfile" + ) + + print("Scorer score_info: {}".format(score_info)) + + best_hyper = score_info[0]["language"][0][1] + if abs(best_hyper - TEST_INFO["language"]) > 1e-4: + raise ValueError( + "Value issue in the test case, might be some issue in scorer {}".format( + "language" + ) + ) + print("check successful", flush=True) + + +if __name__ == "__main__": + info_update() diff --git a/test/test_pipeline/test_pam.py b/test/test_pipeline/test_pam.py index 315a939..a3dd9f7 100755 --- a/test/test_pipeline/test_pam.py +++ b/test/test_pipeline/test_pam.py @@ -4,46 +4,56 @@ import yaml -from versa.scorer_shared import ( - find_files, - list_scoring, - load_score_modules, - load_summary, -) +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.pam import register_pam_metric -TEST_INFO = {"pam_score": 0.01386283989995718} +TEST_INFO = {"pam_score": 0.3535262942314148} def info_update(): - # find files if os.path.isdir("test/test_samples/test2"): gen_files = find_files("test/test_samples/test2") + # find reference file + gt_files = None + if os.path.isdir("test/test_samples/test1"): + gt_files = find_files("test/test_samples/test1") + logging.info("The number of utterances = %d" % len(gen_files)) with open("egs/separate_metrics/pam.yaml", "r", encoding="utf-8") as f: score_config = yaml.full_load(f) - score_modules = load_score_modules( + # Create registry and register PAM metric + registry = MetricRegistry() + register_pam_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( score_config, - use_gt=False, + use_gt=(True if gt_files is not None else False), use_gpu=False, ) assert len(score_config) > 0, "no scoring function is provided" - score_info = list_scoring( - gen_files, score_modules, output_file=None, io="soundfile" + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files, output_file=None, io="soundfile" ) - summary = load_summary(score_info) - print("Summary: {}".format(load_summary(score_info)), flush=True) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) for key in summary: if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): - # for sir" continue - # the plc mos is undeterministic if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != "plcmos": raise ValueError( "Value issue in the test case, might be some issue in scorer {}".format( diff --git a/test/test_pipeline/test_pesq_score.py b/test/test_pipeline/test_pesq_score.py new file mode 100644 index 0000000..f658f7d --- /dev/null +++ b/test/test_pipeline/test_pesq_score.py @@ -0,0 +1,65 @@ +import logging +import math +import os + +import yaml + +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.pesq_score import register_pesq_metric + +TEST_INFO = {"pesq": 1.5722705125808716} # Expected PESQ score for test audio + + +def info_update(): + # find files + if os.path.isdir("test/test_samples/test2"): + gen_files = find_files("test/test_samples/test2") + + # find reference file + gt_files = None + if os.path.isdir("test/test_samples/test1"): + gt_files = find_files("test/test_samples/test1") + + logging.info("The number of utterances = %d" % len(gen_files)) + + with open("egs/separate_metrics/pesq.yaml", "r", encoding="utf-8") as f: + score_config = yaml.full_load(f) + + # Create registry and register PESQ metric + registry = MetricRegistry() + register_pesq_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( + score_config, + use_gt=(True if gt_files is not None else False), + use_gpu=False, + ) + + assert len(score_config) > 0, "no scoring function is provided" + + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files, output_file=None, io="soundfile" + ) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) + + for key in summary: + if abs(TEST_INFO[key] - summary[key]) > 1e-4: + raise ValueError( + "Value issue in the test case, might be some issue in scorer {}".format( + key + ) + ) + print("check successful", flush=True) + + +if __name__ == "__main__": + info_update() diff --git a/test/test_pipeline/test_srmr.py b/test/test_pipeline/test_srmr.py index b97866e..184e39c 100755 --- a/test/test_pipeline/test_srmr.py +++ b/test/test_pipeline/test_srmr.py @@ -4,16 +4,12 @@ import yaml -from versa.scorer_shared import ( - find_files, - list_scoring, - load_score_modules, - load_summary, -) +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.srmr import register_srmr_metric -TEST_INFO = { - "srmr": 0.6123816687905584, -} +TEST_INFO = {"srmr": 0.6123816687905584} def info_update(): @@ -23,6 +19,7 @@ def info_update(): gen_files = find_files("test/test_samples/test2") # find reference file + gt_files = None if os.path.isdir("test/test_samples/test1"): gt_files = find_files("test/test_samples/test1") @@ -31,7 +28,15 @@ def info_update(): with open("egs/separate_metrics/srmr.yaml", "r", encoding="utf-8") as f: score_config = yaml.full_load(f) - score_modules = load_score_modules( + # Create registry and register SRMR metric + registry = MetricRegistry() + register_srmr_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( score_config, use_gt=(True if gt_files is not None else False), use_gpu=False, @@ -39,18 +44,16 @@ def info_update(): assert len(score_config) > 0, "no scoring function is provided" - score_info = list_scoring( - gen_files, score_modules, gt_files, output_file=None, io="soundfile" + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files, output_file=None, io="soundfile" ) - summary = load_summary(score_info) - print("Summary: {}".format(load_summary(score_info)), flush=True) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) for key in summary: - if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): - # for sir" - continue - # the plc mos is undeterministic - if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != "plcmos": + if abs(TEST_INFO[key] - summary[key]) > 1e-4: raise ValueError( "Value issue in the test case, might be some issue in scorer {}".format( key diff --git a/tools/install_fairseq.sh b/tools/install_fairseq.sh index 005f3aa..930976f 100755 --- a/tools/install_fairseq.sh +++ b/tools/install_fairseq.sh @@ -5,7 +5,7 @@ REPO_OWNER="ftshijt" REPO_NAME="fairseq" REPO_PATH="$REPO_OWNER/$REPO_NAME" BRANCH="versa" -EXPECTED_COMMIT_ID="612be207e0afe60859ec393608ef89bba0e5246c" +EXPECTED_COMMIT_ID="7c814e9580e24f69bd6198b400ec12bc3f90fd51" # Old version: EXPECTED_COMMIT_ID="0e35caead74528f04e741986b78ff0b4b543dbe6" # Function to check if repository exists diff --git a/versa/__init__.py b/versa/__init__.py index 69eaa3e..84fec4f 100644 --- a/versa/__init__.py +++ b/versa/__init__.py @@ -2,13 +2,13 @@ __version__ = "0.0.1" # noqa: F401 -from versa.sequence_metrics.mcd_f0 import mcd_f0 -from versa.sequence_metrics.signal_metric import signal_metric +# from versa.sequence_metrics.mcd_f0 import McdF0Metric, register_mcd_f0_metric +# from versa.sequence_metrics.signal_metric import SignalMetric, register_signal_metric try: from versa.utterance_metrics.discrete_speech import ( - discrete_speech_metric, - discrete_speech_setup, + DiscreteSpeechMetric, + register_discrete_speech_metric, ) except ImportError: logging.info( @@ -19,96 +19,113 @@ "Issues detected in discrete speech metrics, please double check the environment." ) -from versa.utterance_metrics.pseudo_mos import pseudo_mos_metric, pseudo_mos_setup +# from versa.utterance_metrics.pseudo_mos import PseudoMosMetric, register_pseudo_mos_metric -try: - from versa.utterance_metrics.pesq_score import pesq_metric -except ImportError: - logging.info("Please install pesq with `pip install pesq` and retry") +# try: +# from versa.utterance_metrics.pesq_score import PesqMetric, register_pesq_metric +# except ImportError: +# logging.info("Please install pesq with `pip install pesq` and retry") -try: - from versa.utterance_metrics.stoi import stoi_metric, estoi_metric -except ImportError: - logging.info("Please install pystoi with `pip install pystoi` and retry") +# try: +# from versa.utterance_metrics.stoi import StoiMetric, register_stoi_metric +# except ImportError: +# logging.info("Please install pystoi with `pip install pystoi` and retry") -try: - from versa.utterance_metrics.speaker import speaker_metric, speaker_model_setup -except ImportError: - logging.info("Please install espnet with `pip install espnet` and retry") +# try: +# from versa.utterance_metrics.speaker import SpeakerMetric, register_speaker_metric +# except ImportError: +# logging.info("Please install espnet with `pip install espnet` and retry") -try: - from versa.utterance_metrics.singer import singer_metric, singer_model_setup -except ImportError: - logging.info("Please install ...") +# try: +# from versa.utterance_metrics.singer import SingerMetric, register_singer_metric +# except ImportError: +# logging.info("Please install ...") -try: - from versa.utterance_metrics.visqol_score import visqol_metric, visqol_setup -except ImportError: - logging.info( - "Please install visqol follow https://github.com/google/visqol and retry" - ) +# try: +# from versa.utterance_metrics.visqol_score import VisqolMetric, register_visqol_metric +# except ImportError: +# logging.info( +# "Please install visqol follow https://github.com/google/visqol and retry" +# ) -from versa.corpus_metrics.espnet_wer import espnet_levenshtein_metric, espnet_wer_setup -from versa.corpus_metrics.fad import fad_scoring, fad_setup -from versa.corpus_metrics.owsm_wer import owsm_levenshtein_metric, owsm_wer_setup -from versa.corpus_metrics.whisper_wer import ( - whisper_levenshtein_metric, - whisper_wer_setup, +# from versa.corpus_metrics.espnet_wer import EspnetWerMetric, register_espnet_wer_metric +# from versa.corpus_metrics.fad import FadMetric, register_fad_metric +# from versa.corpus_metrics.owsm_wer import OwsmWerMetric, register_owsm_wer_metric +# from versa.corpus_metrics.whisper_wer import ( +# WhisperWerMetric, +# register_whisper_wer_metric +# ) +from versa.utterance_metrics.asr_matching import ( + ASRMatchMetric, + register_asr_match_metric, ) -from versa.utterance_metrics.asr_matching import asr_match_metric, asr_match_setup from versa.utterance_metrics.audiobox_aesthetics_score import ( - audiobox_aesthetics_score, - audiobox_aesthetics_setup, + AudioBoxAestheticsMetric, + register_audiobox_aesthetics_metric, ) -from versa.utterance_metrics.emotion import emo2vec_setup, emo_sim -from versa.utterance_metrics.nomad import nomad, nomad_setup -from versa.utterance_metrics.noresqa import noresqa_metric, noresqa_model_setup -from versa.utterance_metrics.owsm_lid import language_id, owsm_lid_model_setup -from versa.utterance_metrics.pysepm import pysepm_metric -from versa.utterance_metrics.qwen2_audio import ( - qwen2_channel_type_metric, - qwen2_language_metric, - qwen2_laughter_crying_metric, - qwen2_model_setup, - qwen2_overlapping_speech_metric, - qwen2_pitch_range_metric, - qwen2_recording_quality_metric, - qwen2_speaker_age_metric, - qwen2_speaker_count_metric, - qwen2_speaker_gender_metric, - qwen2_speaking_style_metric, - qwen2_speech_background_environment_metric, - qwen2_speech_clarity_metric, - qwen2_speech_emotion_metric, - qwen2_speech_impairment_metric, - qwen2_speech_purpose_metric, - qwen2_speech_rate_metric, - qwen2_speech_register_metric, - qwen2_speech_volume_level_metric, - qwen2_vocabulary_complexity_metric, - qwen2_voice_pitch_metric, - qwen2_voice_type_metric, - qwen2_singing_technique_metric, +from versa.utterance_metrics.emo_similarity import ( + Emo2vecMetric, + register_emo2vec_metric, ) -from versa.utterance_metrics.qwen_omni import ( - qwen_omni_model_setup, - qwen_omni_singing_technique_metric, +from versa.utterance_metrics.nomad import NomadMetric, register_nomad_metric +from versa.utterance_metrics.noresqa import NoresqaMetric, register_noresqa_metric +from versa.utterance_metrics.owsm_lid import OwsmLidMetric, register_owsm_lid_metric + +# from versa.utterance_metrics.pysepm import PysepmMetric, register_pysepm_metric +# from versa.utterance_metrics.qwen2_audio import ( +# Qwen2ChannelTypeMetric, +# Qwen2LanguageMetric, +# Qwen2LaughterCryingMetric, +# Qwen2ModelSetup, +# Qwen2OverlappingSpeechMetric, +# Qwen2PitchRangeMetric, +# Qwen2RecordingQualityMetric, +# Qwen2SpeakerAgeMetric, +# Qwen2SpeakerCountMetric, +# Qwen2SpeakerGenderMetric, +# Qwen2SpeakingStyleMetric, +# Qwen2SpeechBackgroundEnvironmentMetric, +# Qwen2SpeechClarityMetric, +# Qwen2SpeechEmotionMetric, +# Qwen2SpeechImpairmentMetric, +# Qwen2SpeechPurposeMetric, +# Qwen2SpeechRateMetric, +# Qwen2SpeechRegisterMetric, +# Qwen2SpeechVolumeLevelMetric, +# Qwen2VocabularyComplexityMetric, +# Qwen2VoicePitchMetric, +# Qwen2VoiceTypeMetric, +# Qwen2SingingTechniqueMetric, +# ) +# from versa.utterance_metrics.qwen_omni import ( +# QwenOmniMetric, +# register_qwen_omni_metric +# ) +# from versa.utterance_metrics.scoreq import ( +# ScoreqMetric, +# register_scoreq_metric +# ) +# from versa.utterance_metrics.se_snr import SeSnrMetric, register_se_snr_metric +# from versa.utterance_metrics.sheet_ssqa import SheetSsqaMetric, register_sheet_ssqa_metric +# from versa.utterance_metrics.speaking_rate import ( +# SpeakingRateMetric, +# register_speaking_rate_metric +# ) +# from versa.utterance_metrics.squim import SquimMetric, register_squim_metric +from versa.utterance_metrics.srmr import SRMRMetric, register_srmr_metric +from versa.utterance_metrics.chroma_alignment import ( + ChromaAlignmentMetric, + register_chroma_alignment_metric, ) -from versa.utterance_metrics.scoreq import ( - scoreq_nr, - scoreq_nr_setup, - scoreq_ref, - scoreq_ref_setup, +from versa.utterance_metrics.dpam_distance import ( + DpamDistanceMetric, + register_dpam_distance_metric, ) -from versa.utterance_metrics.se_snr import se_snr, se_snr_setup -from versa.utterance_metrics.sheet_ssqa import sheet_ssqa, sheet_ssqa_setup -from versa.utterance_metrics.speaking_rate import ( - speaking_rate_metric, - speaking_rate_model_setup, +from versa.utterance_metrics.cdpam_distance import ( + CdpamDistanceMetric, + register_cdpam_distance_metric, ) -from versa.utterance_metrics.squim import squim_metric, squim_metric_no_ref -from versa.utterance_metrics.srmr import srmr_metric -from versa.utterance_metrics.chroma_alignment import chroma_metric -from versa.utterance_metrics.dpam_distance import dpam_metric, dpam_model_setup -from versa.utterance_metrics.cdpam_distance import cdpam_metric, cdpam_model_setup -from versa.utterance_metrics.vqscore import vqscore_metric, vqscore_setup + +# from versa.utterance_metrics.vqscore import VqscoreMetric, register_vqscore_metric +from versa.utterance_metrics.nisqa import NisqaMetric, register_nisqa_metric +from versa.utterance_metrics.pam import PamMetric, register_pam_metric diff --git a/versa/bin/scorer.py b/versa/bin/scorer.py index 2431006..99392df 100644 --- a/versa/bin/scorer.py +++ b/versa/bin/scorer.py @@ -13,11 +13,8 @@ from versa.scorer_shared import ( audio_loader_setup, - corpus_scoring, - list_scoring, - load_corpus_modules, - load_score_modules, - load_summary, + VersaScorer, + compute_summary, ) @@ -141,47 +138,58 @@ def main(): with open(args.score_config, "r", encoding="utf-8") as f: score_config = yaml.full_load(f) - score_modules = load_score_modules( + # Initialize VersaScorer + scorer = VersaScorer() + + # Load utterance-level metrics + utterance_metrics = scorer.load_metrics( score_config, use_gt=(True if gt_files is not None else False), use_gt_text=(True if text_info is not None else False), use_gpu=args.use_gpu, ) - if len(score_modules) > 0: - score_info = list_scoring( + # Perform utterance-level scoring + if len(utterance_metrics.metrics) > 0: + score_info = scorer.score_utterances( gen_files, - score_modules, + utterance_metrics, gt_files, text_info, output_file=args.output_file, io=args.io, ) - logging.info("Summary: {}".format(load_summary(score_info))) + logging.info("Summary: {}".format(compute_summary(score_info))) else: logging.info("No utterance-level scoring function is provided.") - corpus_score_modules = load_corpus_modules( + # Load corpus-level metrics (distributional metrics) + corpus_metrics = scorer.load_metrics( score_config, + use_gt=(True if gt_files is not None else False), + use_gt_text=(True if text_info is not None else False), use_gpu=args.use_gpu, - cache_folder=args.cache_folder, - io=args.io, ) - assert ( - len(corpus_score_modules) > 0 or len(score_modules) > 0 - ), "no scoring function is provided" - if len(corpus_score_modules) > 0: - corpus_score_info = corpus_scoring( - args.pred, - corpus_score_modules, - args.gt, + + # Filter for corpus-level metrics and perform corpus scoring + from versa.definition import MetricCategory + + corpus_suite = corpus_metrics.filter_by_category(MetricCategory.DISTRIBUTIONAL) + if len(corpus_suite.metrics) > 0: + corpus_score_info = scorer.score_corpus( + gen_files, + corpus_suite, + gt_files, text_info, - output_file=args.output_file + ".corpus", + output_file=args.output_file + ".corpus" if args.output_file else None, ) logging.info("Corpus Summary: {}".format(corpus_score_info)) else: logging.info("No corpus-level scoring function is provided.") - return + + # Ensure at least one scoring function is provided + if len(utterance_metrics.metrics) == 0 and len(corpus_suite.metrics) == 0: + raise ValueError("No scoring function is provided") if __name__ == "__main__": diff --git a/versa/definition.py b/versa/definition.py new file mode 100644 index 0000000..1877193 --- /dev/null +++ b/versa/definition.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python3 + +# Copyright 2025 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + + +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Any, Union +from dataclasses import dataclass +from enum import Enum +import logging + + +class MetricCategory(Enum): + INDEPENDENT = "independent" + DEPENDENT = "dependent" + NON_MATCH = "non_match" + DISTRIBUTIONAL = "distributional" + + +class MetricType(Enum): + STRING = "string" + FLOAT = "float" + INT = "int" + BOOL = "bool" + LIST = "list" + DICT = "dict" + TUPLE = "tuple" + ARRAY = "array" + TIME = "time" + + +@dataclass +class MetricMetadata: + name: str + category: MetricCategory + metric_type: MetricType + requires_reference: bool + requires_text: bool + gpu_compatible: bool + auto_install: bool + dependencies: List[str] + description: str + paper_reference: Optional[str] = None + implementation_source: Optional[str] = None + + +class MetricRegistry: + """Centralized registry for all metrics with automatic discovery.""" + + def __init__(self): + self._metrics: Dict[str, type] = {} + self._metadata: Dict[str, MetricMetadata] = {} + self._aliases: Dict[str, str] = {} + + def register( + self, metric_class: type, metadata: MetricMetadata, aliases: List[str] = None + ): + """Register a metric with its metadata.""" + self._metrics[metadata.name] = metric_class + self._metadata[metadata.name] = metadata + + # Register aliases + if aliases: + for alias in aliases: + self._aliases[alias] = metadata.name + + def get_metric(self, name: str) -> type: + """Get metric class by name or alias.""" + real_name = self._aliases.get(name, name) + return self._metrics.get(real_name) + + def get_metadata(self, name: str) -> MetricMetadata: + """Get metric metadata by name or alias.""" + real_name = self._aliases.get(name, name) + return self._metadata.get(real_name) + + def list_metrics( + self, category: MetricCategory = None, metric_type: MetricType = None + ) -> List[str]: + """List available metrics with optional filtering.""" + metrics = [] + for name, metadata in self._metadata.items(): + if category and metadata.category != category: + continue + if metric_type and metadata.metric_type != metric_type: + continue + metrics.append(name) + return sorted(metrics) + + +class BaseMetric(ABC): + """Abstract base class for all metrics.""" + + def __init__(self, config: Dict[str, Any] = None): + self.config = config or {} + self.logger = logging.getLogger(self.__class__.__name__) + self._setup() + + @abstractmethod + def _setup(self): + """Initialize metric-specific components.""" + pass + + @abstractmethod + def compute( + self, predictions: Any, references: Any = None, metadata: Dict[str, Any] = None + ) -> Any: + """Compute the metric score.""" + pass + + @abstractmethod + def get_metadata(self) -> MetricMetadata: + """Return metric metadata.""" + pass + + def validate_inputs(self, predictions: Any, references: Any = None) -> bool: + """Validate input data before computation.""" + return True + + def preprocess(self, data: Any) -> Any: + """Preprocess data before metric computation.""" + return data + + def postprocess(self, scores: Any) -> Any: + """Postprocess scores after computation.""" + return scores + + +class GPUMetric(BaseMetric): + """Base class for GPU-compatible metrics.""" + + def __init__(self, config: Dict[str, Any] = None, device: str = "cuda"): + self.device = device + super().__init__(config) + + def to_device(self, data: Any) -> Any: + """Move data to specified device.""" + if hasattr(data, "to"): + return data.to(self.device) + return data + + +class MetricFactory: + """Factory for creating metric instances with dependency management.""" + + def __init__(self, registry: MetricRegistry): + self.registry = registry + self._dependency_cache = {} + + def create_metric(self, name: str, config: Dict[str, Any] = None) -> BaseMetric: + """Create a metric instance with proper dependency resolution.""" + metadata = self.registry.get_metadata(name) + metric_class = self.registry.get_metric(name) + + if not metric_class: + raise ValueError(f"Metric '{name}' not found in registry") + + # Check and install dependencies + self._ensure_dependencies(metadata.dependencies) + + return metric_class(config) + + def create_metric_suite( + self, metric_names: List[str], config: Dict[str, Any] = None + ) -> "MetricSuite": + """Create a suite of metrics.""" + metrics = {} + for name in metric_names: + metrics[name] = self.create_metric(name, config.get(name, {})) + return MetricSuite(metrics) + + def _ensure_dependencies(self, dependencies: List[str]): + """Ensure all dependencies are available.""" + for dep in dependencies: + if dep not in self._dependency_cache: + try: + __import__(dep) + self._dependency_cache[dep] = True + except ImportError: + self.logger.warning(f"Dependency '{dep}' not available") + self._dependency_cache[dep] = False + + +class MetricSuite: + """Container for multiple metrics with batch processing capabilities.""" + + def __init__(self, metrics: Dict[str, BaseMetric]): + self.metrics = metrics + self.logger = logging.getLogger(self.__class__.__name__) + + def compute_all( + self, predictions: Any, references: Any = None, metadata: Dict[str, Any] = None + ) -> Dict[str, Any]: + """Compute all metrics in the suite.""" + results = {} + for name, metric in self.metrics.items(): + try: + results[name] = metric.compute(predictions, references, metadata) + except Exception as e: + self.logger.error(f"Error computing metric '{name}': {e}") + results[name] = None + return results + + def compute_parallel( + self, + predictions: Any, + references: Any = None, + metadata: Dict[str, Any] = None, + n_workers: int = 4, + ) -> Dict[str, Any]: + """Compute metrics in parallel.""" + # Implementation for parallel metric computation + pass + + def filter_by_category(self, category: MetricCategory) -> "MetricSuite": + """Filter metrics by category.""" + filtered_metrics = { + name: metric + for name, metric in self.metrics.items() + if metric.get_metadata().category == category + } + return MetricSuite(filtered_metrics) diff --git a/versa/metrics.py b/versa/metrics.py index 067b6d2..377383f 100644 --- a/versa/metrics.py +++ b/versa/metrics.py @@ -3,6 +3,12 @@ # Copyright 2025 Jiatong Shi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +DICT_METRIC = [ + "match_details", + "language", +] + STR_METRIC = [ "vad_info", "language", @@ -58,8 +64,8 @@ "audiobox_aesthetics_CU", "audiobox_aesthetics_PC", "audiobox_aesthetics_PQ", - "cdpam", - "dpam", + "cdpam_distance", + "dpam_distance", "mcd", "f0_corr", "f0_rmse", @@ -126,7 +132,29 @@ "clap_score", "apa", "pysepm_llr", + "chroma_stft_cosine_dtw", + "chroma_stft_euclidean_dtw", + "chroma_cqt_cosine_dtw", + "chroma_cqt_euclidean_dtw", + "chroma_cens_cosine_dtw", + "chroma_cens_euclidean_dtw", + "chroma_stft_cosine_dtw_raw", + "chroma_stft_cosine_dtw_log", + "speech_bert", + "speech_bleu", + "speech_token_distance", + "arousal_emo_vad", + "valence_emo_vad", + "dominance_emo_vad", "dnsmos_pro_bvcc", "dnsmos_pro_nisqa", "dnsmos_pro_vcc2018", + "nisqa_mos_pred", + "nisqa_noi_pred", + "nisqa_dis_pred", + "nisqa_col_pred", + "nisqa_loud_pred", + "noresqa_mos", + "noresqa_score", + "pam_score", ] diff --git a/versa/scorer_shared.py b/versa/scorer_shared.py index 2db82a5..04403bd 100644 --- a/versa/scorer_shared.py +++ b/versa/scorer_shared.py @@ -2,15 +2,26 @@ # Copyright 2024 Jiatong Shi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) -import logging +import logging import json import kaldiio import librosa import soundfile as sf import yaml +from typing import Dict, List, Optional, Any, Union from tqdm import tqdm +from versa.definition import ( + BaseMetric, + GPUMetric, + MetricRegistry, + MetricFactory, + MetricSuite, + MetricCategory, + MetricType, + MetricMetadata, +) from versa.metrics import STR_METRIC, NUM_METRIC from versa.utils_shared import ( check_all_same, @@ -41,1292 +52,286 @@ def audio_loader_setup(audio, io): return audio_files -def load_score_modules(score_config, use_gt=True, use_gt_text=False, use_gpu=False): - assert score_config, "no scoring function is provided" - score_modules = {} - for config in score_config: - print(config, flush=True) - if config["name"] == "mcd_f0": - if not use_gt: - logging.warning( - "Cannot use mcd/f0 metrics because no gt audio is provided" - ) - continue - - logging.info("Loading MCD & F0 evaluation...") - from versa import mcd_f0 - - score_modules["mcd_f0"] = { - "module": mcd_f0, - "args": { - "f0min": config.get("f0min", 0), - "f0max": config.get("f0max", 24000), - "mcep_shift": config.get("mcep_shift", 5), - "mcep_fftl": config.get("mcep_fftl", 1024), - "mcep_dim": config.get("mcep_dim", 39), - "mcep_alpha": config.get("mcep_alpha", 0.466), - "seq_mismatch_tolerance": config.get("seq_mismatch_tolerance", 0.1), - "power_threshold": config.get("power_threshold", -20), - "dtw": config.get("dtw", False), - }, - } - logging.info("Initiate MCD & F0 evaluation successfully.") - - elif config["name"] == "signal_metric": - if not use_gt: - logging.warning( - "Cannot use signal metric because no gt audio is provided" - ) - continue +class ScoreProcessor: + """Handles batch processing and caching of scores.""" - logging.info("Loading signal metric evaluation...") - from versa import signal_metric + def __init__(self, metric_suite: MetricSuite, output_file: Optional[str] = None): + self.metric_suite = metric_suite + self.output_file = output_file + self.logger = logging.getLogger(self.__class__.__name__) - score_modules["signal_metric"] = {"module": signal_metric} - logging.info("Initiate signal metric evaluation successfully.") - - elif config["name"] == "warpq": - if not use_gt: - logging.warning("Cannot use warpq because no gt audio is provided") - continue - - logging.info("Loading WARPQ metric evaluation...") - from versa.sequence_metrics.warpq import warpq, warpq_setup - - score_modules["warpq"] = {"model": warpq_setup(), "module": warpq} - logging.info("Initiate WARP-Q metric...") + if output_file: + self.file_handle = open(output_file, "w", encoding="utf-8") + else: + self.file_handle = None - elif config["name"] == "nisqa": + def process_batch(self, cache_info: List[tuple]) -> List[Dict[str, Any]]: + """Process a batch of cached utterance information.""" + batch_score_info = [] + for utt_info in cache_info: + key, gen_wav, gt_wav, gen_sr, text = utt_info + utt_score = {"key": key} - logging.info("Loading NISQA evaluation...") - from versa.utterance_metrics.nisqa import nisqa_metric, nisqa_model_setup + try: + # Prepare metadata for metric computation + metadata = { + "key": key, + "sample_rate": gen_sr, + "text": text, + "general_cache": {"whisper_hyp_text": None}, + } - # Load the NISQA model - nisqa_model = nisqa_model_setup( - nisqa_model_path=config.get( - "model_path", "./tools/NISQA/weights/nisqa.tar" - ), - use_gpu=use_gpu, - ) - score_modules["nisqa"] = { - "module": nisqa_metric, - "model": nisqa_model, - } - logging.info("Initiate NISQA evaluation successfully.") - - elif config["name"] == "discrete_speech": - if not use_gt: - logging.warning( - "Cannot use discrete speech metric because no gt audio is provided" + # Compute all metrics + scores = self.metric_suite.compute_all( + predictions=gen_wav, references=gt_wav, metadata=metadata ) - continue - logging.info("Loading discrete speech evaluation...") - from versa import discrete_speech_metric, discrete_speech_setup + # Flatten the metric results + for metric_name, metric_results in scores.items(): + if isinstance(metric_results, dict): + utt_score.update(metric_results) + else: + utt_score[metric_name] = metric_results - score_modules["discrete_speech"] = { - "module": discrete_speech_metric, - "model": discrete_speech_setup(use_gpu=use_gpu), - } - logging.info("Initiate discrete speech evaluation successfully.") + except Exception as e: + self.logger.error(f"Error processing file: {key} with error {e}") - elif config["name"] == "pseudo_mos": - logging.info("Loading pseudo MOS evaluation...") - from versa import pseudo_mos_metric, pseudo_mos_setup + batch_score_info.append(utt_score) - predictor_dict, predictor_fs = pseudo_mos_setup( - use_gpu=use_gpu, - predictor_types=config.get("predictor_types", ["utmos"]), - predictor_args=config.get("predictor_args", {}), - ) - score_modules["pseudo_mos"] = { - "module": pseudo_mos_metric, - "args": { - "predictor_dict": predictor_dict, - "predictor_fs": predictor_fs, - "use_gpu": use_gpu, - }, - } - logging.info("Initiate pseudo MOS evaluation successfully.") - - elif config["name"] == "pesq": - if not use_gt: - logging.warning( - "Cannot use pesq metric because no gt audio is provided" + if self.file_handle: + printable_result = json.dumps( + utt_score, default=default_numpy_serializer ) - continue + self.file_handle.write(f"{printable_result}\n") - logging.info("Loading pesq evaluation...") - from versa import pesq_metric + return batch_score_info - score_modules["pesq"] = {"module": pesq_metric} - logging.info("Initiate pesq evaluation successfully.") - - elif config["name"] == "stoi": - if not use_gt: - logging.warning( - "Cannot use stoi metric because no gt audio is provided" - ) - continue + def close(self): + """Close file handle if open.""" + if self.file_handle: + self.file_handle.close() - logging.info("Loading stoi evaluation...") - from versa import stoi_metric - score_modules["stoi"] = {"module": stoi_metric} - logging.info("Initiate stoi evaluation successfully.") +class VersaScorer: + """Main scorer class that orchestrates the scoring process.""" - elif config["name"] == "estoi": - if not use_gt: - logging.warning( - "Cannot use estoi metric because no gt audio is provided" - ) - continue + def __init__(self, registry: MetricRegistry = None): + self.registry = registry or self._create_default_registry() + self.factory = MetricFactory(self.registry) + self.logger = logging.getLogger(self.__class__.__name__) - logging.info("Loading stoi evaluation...") - from versa import estoi_metric + def _create_default_registry(self) -> MetricRegistry: + """Create and populate the default metric registry.""" + registry = MetricRegistry() + # This would be populated by importing all metric modules + # and having them auto-register themselves + return registry - score_modules["estoi"] = {"module": estoi_metric} - logging.info("Initiate stoi evaluation successfully.") + def load_metrics( + self, + score_config: List[Dict[str, Any]], + use_gt: bool = True, + use_gt_text: bool = False, + use_gpu: bool = False, + ) -> MetricSuite: + """Load and configure metrics based on configuration.""" + metrics = {} - elif config["name"] == "visqol": - if not use_gt: - logging.warning( - "Cannot use visqol metric because no gt audio is provided" - ) - continue + for config in score_config: + metric_name = config["name"] - logging.info("Loading visqol evaluation...") try: - from versa import visqol_metric, visqol_setup - except ImportError: - logging.warning( - "VISQOL not installed, please check `tools` for installation guideline" - ) - continue - - api, fs = visqol_setup(model=config.get("model", "default")) - score_modules["visqol"] = { - "module": visqol_metric, - "args": {"api": api, "api_fs": fs}, - } - logging.info("Initiate visqol evaluation successfully.") - - elif config["name"] == "speaker": - if not use_gt: - logging.warning( - "Cannot use speaker metric because no gt audio is provided" - ) - continue - - logging.info("Loading speaker evaluation...") - from versa import speaker_metric, speaker_model_setup - - spk_model = speaker_model_setup( - model_tag=config.get("model_tag", "default"), - model_path=config.get("model_path", None), - model_config=config.get("model_config", None), - use_gpu=use_gpu, - ) - score_modules["speaker"] = { - "module": speaker_metric, - "args": {"model": spk_model}, - } - logging.info("Initiate speaker evaluation successfully.") - - elif config["name"] == "singer": - if not use_gt: - logging.warning( - "Cannot use singer metric because no gt audio is provided" - ) - continue - - logging.info("Loading singer evaluation...") - from versa import singer_metric, singer_model_setup - - singer_model = singer_model_setup( - model_name=config.get("model_name", "byol"), - model_path=config.get("model_path", None), - use_gpu=use_gpu, - torchscript=config.get("torchscript", False), - ) - - score_modules["singer"] = { - "module": singer_metric, - "args": {"model": singer_model}, - } - logging.info("Initiate singer evaluation successfully.") - - elif config["name"] == "sheet_ssqa": - - logging.info("Loading Sheet SSQA models for evaluation...") - from versa import sheet_ssqa, sheet_ssqa_setup - - sheet_model = sheet_ssqa_setup( - model_tag=config.get("model_tag", "default"), - model_path=config.get("model_path", None), - model_config=config.get("model_config", None), - use_gpu=use_gpu, - ) - score_modules["sheet_ssqa"] = { - "module": sheet_ssqa, - "args": {"model": sheet_model, "use_gpu": use_gpu}, - } - logging.info("Initiate Sheet SSQA evaluation successfully.") - - elif config["name"] == "squim_ref": - if not use_gt: - logging.warning("Cannot use squim_ref because no gt audio is provided") - continue - - logging.info("Loading squim metrics with reference") - from versa import squim_metric - - score_modules["squim_ref"] = { - "module": squim_metric, - } - logging.info("Initiate torch squim (with reference) successfully") - - elif config["name"] == "squim_no_ref": - - logging.info("Loading squim metrics with reference") - from versa import squim_metric_no_ref - - score_modules["squim_no_ref"] = { - "module": squim_metric_no_ref, - } - logging.info("Initiate torch squim (without reference) successfully") - - elif config["name"] == "espnet_wer": - if not use_gt_text: - logging.warning("Cannot use espnet_wer because no gt text is provided") - continue - - logging.info("Loading espnet_wer metric with reference text") - from versa import espnet_levenshtein_metric, espnet_wer_setup - - score_modules["espnet_wer"] = { - "module": espnet_levenshtein_metric, - "args": espnet_wer_setup( - model_tag=config.get("model_tag", "default"), - beam_size=config.get("beam_size", 1), - text_cleaner=config.get("text_cleaner", "whisper_basic"), - use_gpu=use_gpu, - ), - } - logging.info("Initiate ESPnet WER calculation successfully") - - elif config["name"] == "owsm_wer": - if not use_gt_text: - logging.warning("Cannot use owsm_wer because no gt text is provided") - continue - - logging.info("Loading owsm_wer metric with reference text") - from versa import owsm_levenshtein_metric, owsm_wer_setup - - score_modules["owsm_wer"] = { - "module": owsm_levenshtein_metric, - "args": owsm_wer_setup( - model_tag=config.get("model_tag", "default"), - beam_size=config.get("beam_size", 1), - text_cleaner=config.get("text_cleaner", "whisper_basic"), - use_gpu=use_gpu, - ), - } - logging.info("Initiate ESPnet-OWSM WER calculation successfully") - - elif config["name"] == "whisper_wer": - if not use_gt_text: - logging.warning("Cannot use whisper_wer because no gt text is provided") - continue - - logging.info("Loading whisper_wer metric with reference text") - from versa import whisper_levenshtein_metric, whisper_wer_setup - - # Load whisper model if it is already loaded - if ( - "speaking_rate" in score_modules.keys() - or "asr_matching" in score_modules.keys() - ): - args_cache = score_modules["speaking_rate"]["args"] - else: - args_cache = whisper_wer_setup( - model_tag=config.get("model_tag", "default"), - beam_size=config.get("beam_size", 1), - text_cleaner=config.get("text_cleaner", "whisper_basic"), - use_gpu=use_gpu, - ) - - score_modules["whisper_wer"] = { - "module": whisper_levenshtein_metric, - "args": args_cache, - } - logging.info("Initiate Whisper WER calculation successfully") - - elif config["name"] == "scoreq_ref": - if not use_gt: - logging.warning("Cannot use scoreq_ref because no gt audio is provided") - continue - - logging.info("Loading scoreq metrics with reference") - from versa import scoreq_ref, scoreq_ref_setup - - model = scoreq_ref_setup( - data_domain=config.get("data_domain", "synthetic"), - cache_dir=config.get("model_cache", "versa_cache/scoreq_pt-models"), - use_gpu=use_gpu, - ) - - score_modules["scoreq_ref"] = { - "module": scoreq_ref, - "model": model, - } - logging.info("Initiate scoreq (with reference) successfully") + # Check if metric requires ground truth + metadata = self.registry.get_metadata(metric_name) + if metadata and metadata.requires_reference and not use_gt: + self.logger.warning( + f"Cannot use {metric_name} because no ground truth is provided" + ) + continue - elif config["name"] == "scoreq_nr": - logging.info("Loading scoreq metrics without reference") - from versa import scoreq_nr, scoreq_nr_setup + if metadata and metadata.requires_text and not use_gt_text: + self.logger.warning( + f"Cannot use {metric_name} because no ground truth text is provided" + ) + continue - model = scoreq_nr_setup( - data_domain=config.get("data_domain", "synthetic"), - cache_dir=config.get("model_cache", "versa_cache/scoreq_pt-models"), - use_gpu=use_gpu, - ) + # Create metric instance + metric_config = {**config, "use_gpu": use_gpu} + metric = self.factory.create_metric(metric_name, metric_config) + metrics[metric_name] = metric - score_modules["scoreq_nr"] = { - "module": scoreq_nr, - "model": model, - } - logging.info("Initiate scoreq (with reference) successfully") + self.logger.info(f"Loaded {metric_name} successfully") - elif config["name"] == "nomad": - if not use_gt: - logging.warning("Cannot use nomad because no gt audio is provided") + except Exception as e: + self.logger.error(f"Failed to load metric {metric_name}: {e}") continue - logging.info("Loading nomad metrics with reference") - from versa import nomad, nomad_setup + return MetricSuite(metrics) - model = nomad_setup( - cache_dir=config.get("model_cache", "versa_cache/nomad_pt-models"), - use_gpu=use_gpu, - ) + def score_utterances( + self, + gen_files: Dict[str, str], + metric_suite: MetricSuite, + gt_files: Optional[Dict[str, str]] = None, + text_info: Optional[Dict[str, str]] = None, + output_file: Optional[str] = None, + io: str = "kaldi", + batch_size: int = 1, + ) -> List[Dict[str, Any]]: + """Score individual utterances.""" - score_modules["nomad"] = { - "module": nomad, - "model": model, - } - logging.info("Initiate nomad successfully") + processor = ScoreProcessor(metric_suite, output_file) + score_info = [] + cache_info = [] - elif config["name"] == "emo2vec_similarity": - if not use_gt: - logging.warning( - "Cannot use emo2vec_similarity metric because no gt audio is provided" + try: + for key in tqdm(gen_files.keys()): + # Step1: Load and validate generated audio + gen_sr, gen_wav = load_audio(gen_files[key], io) + gen_wav = wav_normalize(gen_wav) + + if not self._validate_audio(gen_wav, gen_sr, key, "generated"): + continue + + # Step2: Load and validate ground truth audio + gt_wav, gt_sr = None, None + if gt_files is not None: + if key not in gt_files: + self.logger.warning( + f"Ground truth not found for key {key}, skipping" + ) + continue + + gt_sr, gt_wav = load_audio(gt_files[key], io) + gt_wav = wav_normalize(gt_wav) + + if not self._validate_audio(gt_wav, gt_sr, key, "ground truth"): + continue + + # Step3: Load text information + text = text_info.get(key) if text_info else None + if text_info and key not in text_info: + self.logger.warning(f"Text not found for key {key}, skipping") + continue + + # Step4: Resample if needed + gen_wav, gt_wav, gen_sr = self._align_sample_rates( + gen_wav, gt_wav, gen_sr, gt_sr ) - continue - logging.info("Loading emo2vec metrics with reference") - from versa import emo2vec_setup, emo_sim + # Step5: Cache for batch processing + utterance_info = (key, gen_wav, gt_wav, gen_sr, text) + cache_info.append(utterance_info) - model = emo2vec_setup( - model_tag=config.get("model_tag", "default"), - model_path=config.get("model_path", None), - use_gpu=use_gpu, - ) + if len(cache_info) >= batch_size: + score_info.extend(processor.process_batch(cache_info)) + cache_info = [] - score_modules["emotion"] = { - "module": emo_sim, - "model": model, - } - logging.info("Initiate emo2vec successfully") - - elif config["name"] == "w2v2_dimensional_emotion": - from versa import w2v2_emo_dim_setup, w2v2_emo_dim_metric - - args_cache = w2v2_emo_dim_setup() - score_modules["w2v2_dimensional_emotion"] = { - "module": w2v2_emo_dim_metric, - "args": args_cache, - } - logging.info("Initiate w2v2_dimensional_emotion successfully") - - elif config["name"] == "se_snr": - logging.info("Loading se_snr metrics with reference") - from versa import se_snr, se_snr_setup - - model = se_snr_setup( - model_tag=config.get("model_tag", "default"), - model_path=config.get("model_path", None), - use_gpu=use_gpu, - ) + # Process remaining items + if cache_info: + score_info.extend(processor.process_batch(cache_info)) - score_modules["se_snr"] = { - "module": se_snr, - "model": model, - } - logging.info("Initiate se_snr successfully") - - elif config["name"] == "pam": - - logging.info("Loading pam metric without reference...") - from versa.utterance_metrics.pam import pam_metric, pam_model_setup - - pam_model = pam_model_setup(model_config=config, use_gpu=use_gpu) - score_modules["pam"] = { - "module": pam_metric, - "model": pam_model, - } - logging.info("Initiate pam metric successfully.") - elif config["name"] == "vad": - logging.info("Loading vad metric without reference...") - from versa.utterance_metrics.vad import vad_metric, vad_model_setup - - vad_model = vad_model_setup( - threshold=config.get("threshold", 0.5), - min_speech_duration_ms=config.get("min_speech_duration_ms", 250), - max_speech_duration_s=config.get("max_speech_duration_s", float("inf")), - min_silence_duration_ms=config.get("min_silence_duration_ms", 100), - speech_pad_ms=config.get("speech_pad_ms", 30), - ) - score_modules["vad"] = { - "module": vad_metric, - "args": vad_model, - } - logging.info("Initiate vad metric successfully.") - - elif config["name"] == "asvspoof_score": - - logging.info("Loading asvspoof score metric without reference...") - from versa.utterance_metrics.asvspoof_score import ( - asvspoof_metric, - deepfake_detection_model_setup, - ) + finally: + processor.close() - deepfake_detection_model = deepfake_detection_model_setup(use_gpu=use_gpu) - score_modules["asvspoof_score"] = { - "module": asvspoof_metric, - "model": deepfake_detection_model, - } - logging.info("Initiate asvspoof score metric successfully.") + self.logger.info(f"Scoring completed. Results saved to {output_file}") + return score_info - elif config["name"] == "pysepm": - if not use_gt: - logging.warning("Cannot use pysepm because no gt audio is provided") - continue + def score_corpus( + self, + gen_files: Dict[str, str], + metric_suite: MetricSuite, + base_files: Optional[Dict[str, str]] = None, + text_info: Optional[Dict[str, str]] = None, + output_file: Optional[str] = None, + ) -> Dict[str, Any]: + """Score at corpus level (e.g., FAD, KID).""" - logging.info("Loading pysepm metrics with reference") - from versa import pysepm_metric - - score_modules["pysepm"] = { - "module": pysepm_metric, - "args": { - "frame_len": config.get("frame_len", 0.03), - "overlap": config.get("overlap", 0.75), - }, - } - logging.info("Initiate pysepm successfully") - - elif config["name"] == "srmr": - logging.info("Loading srmr metrics with reference") - from versa import srmr_metric - - score_modules["srmr"] = { - "module": srmr_metric, - "args": { - "n_cochlear_filters": config.get("n_cochlear_filters", 23), - "low_freq": config.get("low_freq", 125), - "min_cf": config.get("min_cf", 128), - "max_cf": config.get("max_cf", 128), - "fast": config.get("fast", True), - "norm": config.get("norm", False), - }, - } - logging.info("Initiate srmr successfully") - - elif config["name"] == "noresqa": - if not use_gt: - logging.warning("Cannot use noresqa because no gt audio is provided") - continue + score_info = {} - logging.info("Loading noresqa metrics with reference") + # Filter for distributional metrics + distributional_metrics = metric_suite.filter_by_category( + MetricCategory.DISTRIBUTIONAL + ) - from versa.utterance_metrics.noresqa import ( - noresqa_metric, - noresqa_model_setup, - ) + for name, metric in distributional_metrics.metrics.items(): + try: + metadata = {"baseline_files": base_files, "text_info": text_info} - noresqa_model = noresqa_model_setup( - metric_type=config.get("metric_type", 0), - cache_dir=config.get("cache_dir", "versa_cache/noresqa_model"), - use_gpu=use_gpu, - ) - score_modules["noresqa"] = { - "module": noresqa_metric, - "args": { - "metric_type": config.get("metric_type", 0), - "model": noresqa_model, - }, - } - logging.info("Initiate noresqa score metric successfully.") - - elif config["name"] == "speaking_rate": - logging.info("Loading speaking rate metrics without reference") - from versa import speaking_rate_metric, speaking_rate_model_setup - - # Load whisper model if it is already loaded - if "whisper_wer" in score_modules.keys(): - speaking_rate_model = score_modules["whisper_wer"]["args"] - else: - speaking_rate_model = speaking_rate_model_setup( - model_tag=config.get("model_tag", "default"), - beam_size=config.get("beam_size", 1), - text_cleaner=config.get("text_cleaner", "whisper_basic"), - use_gpu=use_gpu, + score_result = metric.compute( + predictions=gen_files, references=base_files, metadata=metadata ) + score_info.update({name: score_result}) - score_modules["speaking_rate"] = { - "module": speaking_rate_metric, - "args": speaking_rate_model, - } - logging.info("Initiate speaking rate metric successfully.") - - elif config["name"] == "asr_match": - if not use_gt: - logging.warning("Cannot use asr_match because no gt audio is provided") - continue - - logging.info("Loading asr_match metric with reference text") - from versa import asr_match_metric, asr_match_setup - - # Load whisper model if it is already loaded - if "whisper_wer" in score_modules.keys(): - asr_model = score_modules["whisper_wer"]["args"] - elif "speaking_rate" in score_modules.keys(): - asr_model = score_modules["speaking_rate"]["args"] - else: - asr_model = asr_match_setup( - model_tag=config.get("model_tag", "default"), - beam_size=config.get("beam_size", 1), - text_cleaner=config.get("text_cleaner", "whisper_basic"), - use_gpu=use_gpu, - ) + except Exception as e: + self.logger.error(f"Error computing corpus metric {name}: {e}") - score_modules["asr_match"] = { - "module": asr_match_metric, - "args": asr_model, - } - logging.info("Initiate asr_match metric successfully") + if output_file: + with open(output_file, "w") as f: + yaml.dump(score_info, f) - elif config["name"] == "lid": - logging.info("Loading language identification metric") - from versa import language_id, owsm_lid_model_setup + return score_info - owsm_model = owsm_lid_model_setup( - model_tag=config.get("model_tag", "default"), - nbest=config.get("nbest", 3), - use_gpu=use_gpu, + def _validate_audio(self, wav: Any, sr: int, key: str, audio_type: str) -> bool: + """Validate audio data.""" + # Length check + if not check_minimum_length( + wav.shape[0] / sr, [] + ): # Metric names would be passed here + self.logger.warning( + f"Audio {key} ({audio_type}, length {wav.shape[0] / sr}) is too short, skipping" ) + return False - score_modules["lid"] = { - "module": language_id, - "args": owsm_model, - } - - elif config["name"] == "audiobox_aesthetics": - logging.info("Loading audiobox aesthetics metric") - from versa import audiobox_aesthetics_score, audiobox_aesthetics_setup - - audiobox_model = audiobox_aesthetics_setup( - model_path=config.get("model_path", None), - batch_size=config.get("batch_size", 1), - precision=config.get("precision", "bf16"), - cache_dir=config.get("cache_dir", "versa_cache/audiobox"), - use_huggingface=config.get("use_huggingface", True), - use_gpu=use_gpu, + # Check for silent audio + if check_all_same(wav): + self.logger.warning( + f"Audio {key} ({audio_type}) has only the same value, skipping" ) + return False - score_modules["audiobox_aesthetics"] = { - "module": audiobox_aesthetics_score, - "args": {"model": audiobox_model}, - } - logging.info("Initiate audiobox aesthetics metric successfully") + return True - elif config["name"] == "cdpam": - if not use_gt: - logging.warning( - "Cannot use cdpam metrics because no gt audio is provided" - ) - continue - logging.info("Loading cdpam evaluation...") - from versa import cdpam_metric, cdpam_model_setup - - cdpam_model = cdpam_model_setup(use_gpu=use_gpu) - score_modules["cdpam"] = { - "module": cdpam_metric, - "args": {"model": cdpam_model}, - } - logging.info("Initiate cdpam evaluation successfully.") - - elif config["name"] == "dpam": - if not use_gt: - logging.warning( - "Cannot use dpam metrics because no gt audio is provided" - ) - continue - logging.info("Loading dpam evaluation...") - from versa import dpam_metric, dpam_model_setup - - dpam_model = dpam_model_setup(use_gpu=use_gpu) - score_modules["dpam"] = { - "module": dpam_metric, - "args": {"model": dpam_model}, - } - logging.info("Initiate dpam evaluation successfully.") - - elif "qwen_omni" in config["name"]: - logging.info("Loading qwen omni model") - from versa import qwen_omni_model_setup - - if "qwen_omni" not in score_modules.keys(): - qwen_omni_model = qwen_omni_model_setup( - model_tag=config.get("model_tag", "default"), - ) - score_modules["qwen_omni"] = { - "module": qwen_omni_model, - "start_prompt": config.get("start_prompt", None), - } - - if config["name"] == "qwen_omni_singing_technique": - from versa import qwen_omni_singing_technique_metric - - score_modules["qwen_omni_singing_technique"] = { - "module": qwen_omni_singing_technique_metric, - "prompt": config.get("prompt", None), - } - # To add qwen-omni modules for others - - elif "qwen2_audio" in config["name"]: - logging.info("Loading qwen2-audio model") - from versa import qwen2_model_setup - - if "qwen2_audio" not in score_modules.keys(): - qwen_model = qwen2_model_setup( - model_tag=config.get("model_tag", "default"), - ) - score_modules["qwen2_audio"] = { - "module": qwen_model, - "start_prompt": config.get("start_prompt", None), - } - - # 1. Speaker Characteristics - if config["name"] == "qwen2_audio_speaker_count": - from versa import qwen2_speaker_count_metric + def _align_sample_rates( + self, gen_wav: Any, gt_wav: Any, gen_sr: int, gt_sr: Optional[int] + ) -> tuple: + """Align sample rates between generated and ground truth audio.""" + if gt_sr is None: + return gen_wav, gt_wav, gen_sr - score_modules["qwen2_audio_speaker_count"] = { - "module": qwen2_speaker_count_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_speaker_gender": - from versa import qwen2_speaker_gender_metric - - score_modules["qwen2_audio_speaker_gender"] = { - "module": qwen2_speaker_gender_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_speaker_age": - from versa import qwen2_speaker_age_metric - - score_modules["qwen2_audio_speaker_age"] = { - "module": qwen2_speaker_age_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_speech_impairment": - from versa import qwen2_speech_impairment_metric - - score_modules["qwen2_audio_speech_impairment"] = { - "module": qwen2_speech_impairment_metric, - "prompt": config.get("prompt", None), - } - - # 2. Voice Properties - elif config["name"] == "qwen2_audio_voice_pitch": - from versa import qwen2_voice_pitch_metric - - score_modules["qwen2_audio_voice_pitch"] = { - "module": qwen2_voice_pitch_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_pitch_range": - from versa import qwen2_pitch_range_metric - - score_modules["qwen2_audio_pitch_range"] = { - "module": qwen2_pitch_range_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_voice_type": - from versa import qwen2_voice_type_metric - - score_modules["qwen2_audio_voice_type"] = { - "module": qwen2_voice_type_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_speech_volume_level": - from versa import qwen2_speech_volume_level_metric - - score_modules["qwen2_audio_speech_volume_level"] = { - "module": qwen2_speech_volume_level_metric, - "prompt": config.get("prompt", None), - } - - # 3. Speech Content - elif config["name"] == "qwen2_audio_language": - from versa import qwen2_language_metric - - score_modules["qwen2_audio_language"] = { - "module": qwen2_language_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_speech_register": - from versa import qwen2_speech_register_metric - - score_modules["qwen2_audio_speech_register"] = { - "module": qwen2_speech_register_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_vocabulary_complexity": - from versa import qwen2_vocabulary_complexity_metric - - score_modules["qwen2_audio_vocabulary_complexity"] = { - "module": qwen2_vocabulary_complexity_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_speech_purpose": - from versa import qwen2_speech_purpose_metric - - score_modules["qwen2_audio_speech_purpose"] = { - "module": qwen2_speech_purpose_metric, - "prompt": config.get("prompt", None), - } - - # 4. Speech Delivery - elif config["name"] == "qwen2_audio_speech_emotion": - from versa import qwen2_speech_emotion_metric - - score_modules["qwen2_audio_speech_emotion"] = { - "module": qwen2_speech_emotion_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_speech_clarity": - from versa import qwen2_speech_clarity_metric - - score_modules["qwen2_audio_speech_clarity"] = { - "module": qwen2_speech_clarity_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_speech_rate": - from versa import qwen2_speech_rate_metric - - score_modules["qwen2_audio_speech_rate"] = { - "module": qwen2_speech_rate_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_speaking_style": - from versa import qwen2_speaking_style_metric - - score_modules["qwen2_audio_speaking_style"] = { - "module": qwen2_speaking_style_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_laughter_crying": - from versa import qwen2_laughter_crying_metric - - score_modules["qwen2_audio_laughter_crying"] = { - "module": qwen2_laughter_crying_metric, - "prompt": config.get("prompt", None), - } - - # 5. Interaction Patterns - elif config["name"] == "qwen2_audio_overlapping_speech": - from versa import qwen2_overlapping_speech_metric - - score_modules["qwen2_audio_overlapping_speech"] = { - "module": qwen2_overlapping_speech_metric, - "prompt": config.get("prompt", None), - } - - # 6. Recording Environment - elif config["name"] == "qwen2_audio_speech_background_environment": - from versa import qwen2_speech_background_environment_metric - - score_modules["qwen2_audio_speech_background_environment"] = { - "module": qwen2_speech_background_environment_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_recording_quality": - from versa import qwen2_recording_quality_metric - - score_modules["qwen2_audio_recording_quality"] = { - "module": qwen2_recording_quality_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_channel_type": - from versa import qwen2_channel_type_metric - - score_modules["qwen2_audio_channel_type"] = { - "module": qwen2_channel_type_metric, - "prompt": config.get("prompt", None), - } - - # 7. Vocal Evaluation - elif config["name"] == "qwen2_audio_singing_technique": - from versa import qwen2_singing_technique_metric - - score_modules["qwen2_audio_singing_technique"] = { - "module": qwen2_singing_technique_metric, - "prompt": config.get("prompt", None), - } - - logging.info( - "Initiate qwen2 audio metric: {} successfully".format(config["name"]) - ) - elif "chroma_alignment" in config["name"]: - from versa import chroma_metric - - score_modules["chroma_alignment"] = { - "module": chroma_metric, - "args": { - "scale_factor": config.get("scale_factor", 100), - }, - } - elif "vqscore" in config["name"]: - logging.info("Loading VQScore model") - from versa import vqscore_metric, vqscore_setup - - vqscore_model = vqscore_setup(use_gpu=use_gpu) - score_modules["vqscore"] = { - "module": vqscore_metric, - "args": {"model": vqscore_model}, - } - logging.info("Initiate VQScore evaluation successfully.") - return score_modules - - -def process_cache_info(cache_info, score_modules, output_file): - batch_score_info = [] - for utt_info in cache_info: - key, gen_wav, gt_wav, gen_sr, text = utt_info - utt_score = {"key": key} - # try: - # utt_score.update( - # use_score_modules(score_modules, gen_wav, gt_wav, gen_sr, text) - # ) - # except Exception as e: - # print("error processing file: {} with error {}".format(key, e)) - utt_score.update( - use_score_modules(score_modules, gen_wav, gt_wav, gen_sr, text) - ) - batch_score_info.append(utt_score) - if output_file is not None: - printable_result = json.dumps(utt_score, default=default_numpy_serializer) - output_file.write(f"{printable_result}\n") - return batch_score_info - - -def use_score_modules(score_modules, gen_wav, gt_wav, gen_sr, text=None): - utt_score = {} - - # general cache information to reduce recaculation - general_cache = { - "whisper_hyp_text": None, - } - for key in score_modules.keys(): - if key == "mcd_f0" or key == "chroma_alignment": - score = score_modules[key]["module"]( - gen_wav, gt_wav, gen_sr, **score_modules[key]["args"] - ) - elif key == "signal_metric": - try: - score = score_modules[key]["module"](gen_wav, gt_wav) - except ValueError as e: - logging.warning( - "Value error in signal metric. Usually due to silence audio: {}".format( - e - ) - ) - continue - elif key == "warpq": - score = score_modules[key]["module"]( - score_modules[key]["model"], gen_wav, gt_wav, gen_sr - ) - elif key == "nisqa": - try: - score = score_modules[key]["module"]( - score_modules[key]["model"], - gen_wav, - gen_sr, - ) - except ValueError as e: - logging.warning( - "Value error in NISQA metric. Usually due to silence audio: {}".format( - e - ) - ) - continue - elif key == "discrete_speech": - score = score_modules[key]["module"]( - score_modules[key]["model"], - gen_wav, - gt_wav, - gen_sr, - ) - elif key == "pseudo_mos": - score = score_modules[key]["module"]( - gen_wav, gen_sr, **score_modules[key]["args"] - ) - elif key in ["pesq", "stoi", "estoi"]: - score = score_modules[key]["module"](gen_wav, gt_wav, gen_sr) - elif key == "visqol": - score = score_modules[key]["module"]( - score_modules[key]["args"]["api"], - score_modules[key]["args"]["api_fs"], - gen_wav, - gt_wav, - gen_sr, - ) - elif key == "speaker" or key == "singer": - score = score_modules[key]["module"]( - score_modules[key]["args"]["model"], gen_wav, gt_wav, gen_sr - ) - elif key == "sheet_ssqa": - score = score_modules[key]["module"]( - score_modules[key]["args"]["model"], - gen_wav, - gen_sr, - use_gpu=score_modules[key]["args"]["use_gpu"], - ) - elif key == "squim_ref": - score = score_modules[key]["module"](gen_wav, gt_wav, gen_sr) - elif key == "squim_no_ref": - score = score_modules[key]["module"](gen_wav, gen_sr) - elif key == "nomad": - score = score_modules[key]["module"]( - score_modules[key]["model"], gen_wav, gt_wav, gen_sr - ) - elif key == "espnet_wer" or key == "owsm_wer" or key == "whisper_wer": - score = score_modules[key]["module"]( - score_modules[key]["args"], - gen_wav, - text, - gen_sr, - ) - if key == "whisper_wer": - general_cache["whisper_hyp_text"] = score["whisper_hyp_text"] - elif key in ["scoreq_ref", "emotion"]: - score = score_modules[key]["module"]( - score_modules[key]["model"], gen_wav, gt_wav, gen_sr - ) - elif key in ["scoreq_nr", "se_snr"]: - score = score_modules[key]["module"]( - score_modules[key]["model"], gen_wav, gen_sr - ) - elif key in ["pam", "asvspoof_score"]: - score = score_modules[key]["module"]( - score_modules[key]["model"], gen_wav, fs=gen_sr - ) - elif key in ["vad", "lid", "w2v2_dimensional_emotion"]: - score = score_modules[key]["module"]( - score_modules[key]["args"], - gen_wav, - gen_sr, - ) - elif key == "pysepm": - score = score_modules[key]["module"](gen_wav, gt_wav, fs=gen_sr) - elif key == "srmr": - score = score_modules[key]["module"](gen_wav, fs=gen_sr) - elif key == "noresqa": - score = score_modules[key]["module"]( - score_modules[key]["args"]["model"], - gen_wav, - gt_wav, - fs=gen_sr, - metric_type=score_modules[key]["args"]["metric_type"], - ) - elif key == "speaking_rate": - cache_text = None - if general_cache.get("whisper_hyp_text", None) is not None: - cache_text = utt_score["whisper_hyp_text"] - score = score_modules[key]["module"]( - score_modules[key]["args"], - gen_wav, - cache_text, - gen_sr, - ) - if cache_text is None: - general_cache["whisper_hyp_text"] = score["whisper_hyp_text"] - elif key == "asr_match": - cache_text = None - if general_cache.get("whisper_hyp_text", None) is not None: - cache_text = utt_score["whisper_hyp_text"] - score = score_modules[key]["module"]( - score_modules[key]["args"], - gen_wav, - gt_wav, - cache_text, - gen_sr, - ) - if cache_text is None: - general_cache["whisper_hyp_text"] = score["whisper_hyp_text"] - elif key == "audiobox_aesthetics": - score = score_modules[key]["module"]( - score_modules[key]["args"]["model"], - gen_wav, - gen_sr, - ) - elif key == "dpam" or key == "cdpam": - score = score_modules[key]["module"]( - score_modules[key]["args"]["model"], - gen_wav, - gt_wav, - gen_sr, - ) - - elif "qwen2_audio" in key: - if key == "qwen2_audio": - continue # skip the base model, only use the specific metrics - # Support qwen2_audio metrics - score = score_modules[key]["module"]( - score_modules["qwen2_audio"]["module"], - gen_wav, - gen_sr, - custom_prompt=score_modules[key]["prompt"], - ) - elif "qwen_omni" in key: - if key == "qwen_omni": - continue - score = score_modules[key]["module"]( - score_modules["qwen_omni"]["module"], - gen_wav, - gen_sr, - custom_prompt=score_modules[key]["prompt"], - ) - elif key == "vqscore": - score = score_modules[key]["module"]( - score_modules[key]["args"]["model"], gen_wav, gen_sr - ) - else: - raise NotImplementedError(f"Not supported {key}") - - logging.info(f"Score for {key} is {score}") - utt_score.update(score) - return utt_score - - -def list_scoring( - gen_files, - score_modules, - gt_files=None, - text_info=None, - output_file=None, - io="kaldi", - batch_size=1, -): - if output_file is not None: - f = open(output_file, "w", encoding="utf-8") - else: - f = None - - score_info = [] - cache_info = [] # for batch processing - for key in tqdm(gen_files.keys()): - try: - # Step1: load source speech and conduct basic checks - gen_sr, gen_wav = load_audio(gen_files[key], io) - gen_wav = wav_normalize(gen_wav) - except Exception as e: - print(f"Error loading audio file for key '{key}': {gen_files[key]}") - print(f"Error details: {e}") - continue # Skip this file and move to the next one - - # length check - if not check_minimum_length(gen_wav.shape[0] / gen_sr, score_modules.keys()): - logging.warning( - "audio {} (generated, length {}) is too short to be evaluated with some metric metrics, skipping".format( - key, gen_wav.shape[0] / gen_sr - ) - ) - continue - - # Step2: load reference (gt) speech and conduct basic checks - if gt_files is not None: - if key not in gen_files.keys(): - logging.warning( - "key {} not found in ground truth files though provided, skipping".format( - key - ) - ) - continue - - gt_sr, gt_wav = load_audio(gt_files[key], io) - gt_wav = wav_normalize(gt_wav) - - # check ground truth audio files - if check_all_same(gt_wav): - logging.warning( - "gt audio of key {} has only the same value, skipping".format(key) - ) - continue - - # length check - if not check_minimum_length(gt_wav.shape[0] / gt_sr, score_modules.keys()): - logging.warning( - "audio {} (ground truth, length {}) is too short to be evaluated with many metrics, skipping".format( - key, gt_wav.shape[0] / gt_sr - ) - ) - continue - else: - gt_wav = None - gt_sr = None - - # Step3: load text information if provided - text = None - if text_info is not None: - if key not in text_info.keys(): - logging.warning( - "key {} not found in ground truth transcription though provided, skipping".format( - key - ) - ) - continue - else: - text = text_info[key] - - # Step4: check if the sampling rate of generated and gt audio are the same - if gt_sr is not None and gen_sr > gt_sr: - logging.warning( - "Resampling the generated audio to match the ground truth audio" - ) + if gen_sr > gt_sr: + self.logger.warning("Resampling generated audio to match ground truth") gen_wav = librosa.resample(gen_wav, orig_sr=gen_sr, target_sr=gt_sr) gen_sr = gt_sr - elif gt_sr is not None and gen_sr < gt_sr: - logging.warning( - "Resampling the ground truth audio to match the generated audio" + elif gen_sr < gt_sr: + self.logger.warning( + "Resampling ground truth audio to match generated audio" ) gt_wav = librosa.resample(gt_wav, orig_sr=gt_sr, target_sr=gen_sr) - # Step5: cache for batch processing - utterance_info = (key, gen_wav, gt_wav, gen_sr, text) + return gen_wav, gt_wav, gen_sr - cache_info.append(utterance_info) - if len(cache_info) == batch_size: - # Process after a batch is collected - score_info.extend(process_cache_info(cache_info, score_modules, f)) - cache_info = [] - else: - # continue collect the batch - continue - # Process left-over batch - score_info.extend(process_cache_info(cache_info, score_modules, f)) +def compute_summary(score_info: List[Dict[str, Any]]) -> Dict[str, Any]: + """Compute summary statistics from individual scores.""" + if not score_info: + return {} - logging.info("Scoring completed and save score at {}".format(output_file)) - return score_info - - -def load_summary(score_info): summary = {} for key in score_info[0].keys(): - if key in STR_METRIC or key == "key": - # NOTE(jiatong): skip text cases + if key not in NUM_METRIC: continue - summary[key] = sum([score[key] for score in score_info]) - if "_wer" not in key and "_cer" not in key: - # Average for non-WER/CER metrics - summary[key] /= len(score_info) - return summary - - -def load_corpus_modules( - score_config, cache_folder="versa_cache", use_gpu=False, io="kaldi" -): - score_modules = {} - for config in score_config: - if config["name"] == "fad": - logging.info("Loading FAD evaluation with specific models...") - # TODO(jiatong): fad will automatically use cuda if detected - # need to sync to the same space - from versa import fad_scoring, fad_setup - - fad_info = fad_setup( - fad_embedding=config.get("fad_embedding", "default"), - baseline=config.get("baseline_audio", "missing"), - cache_dir=config.get("cache_dir", cache_folder), - use_inf=config.get("use_inf", False), - io=io, - ) - fad_key = "fad_{}".format(config.get("model", "default")) - - score_modules[fad_key] = { - "module": fad_scoring, - "args": fad_info, - } - logging.info( - "Initiate {} calculation evaluation successfully.".format(fad_key) - ) - elif config["name"] == "kid": - logging.info("Loading KID evaluation with specific models...") - from versa import kid_scoring, kid_setup - - kid_info = kid_setup( - model_tag=config.get("model_tag", "default"), - model_path=config.get("model_path", None), - model_config=config.get("model_config", None), - use_gpu=use_gpu, - ) - kid_key = "kid_{}".format(config.get("model", "default")) - score_modules[kid_key] = { - "module": kid_scoring, - "args": kid_info, - } - logging.info( - "Initiate {} calculation evaluation successfully.".format(kid_key) - ) + values = [ + score[key] + for score in score_info + if key in score and score[key] is not None + ] + if not values: + continue - return score_modules - - -def corpus_scoring( - gen_files, - score_modules, - base_files=None, - text_info=None, - output_file=None, -): - score_info = {} - for key in score_modules.keys(): - if key.startswith("fad"): - fad_info = score_modules[key]["args"] - if base_files is not None: - fad_info["baseline"] = base_files - elif fad_info["baseline"] == "missing": - raise ValueError("Baseline audio not provided for FAD") - score_result = score_modules[key]["module"]( - gen_files, fad_info, key_info=key - ) - elif key.startswith("kld"): - kid_info = score_modules[key]["args"] - if base_files is not None: - kid_info["baseline"] = base_files - elif kid_info["baseline"] == "missing": - raise ValueError("Baseline audio not provided for FAD") - score_result = score_modules[key]["module"]( - gen_files, kid_info, key_info=key - ) - else: - raise NotImplementedError("Not supported {}".format(key)) - score_info.update(score_result) + summary[key] = sum(values) + if "_wer" not in key and "_cer" not in key: + summary[key] /= len(values) - if output_file is not None: - with open(output_file, "w") as f: - yaml.dump(score_info, f) - return score_info + return summary diff --git a/versa/utterance_metrics/asr_matching.py b/versa/utterance_metrics/asr_matching.py index d77797f..53d95c6 100644 --- a/versa/utterance_metrics/asr_matching.py +++ b/versa/utterance_metrics/asr_matching.py @@ -27,6 +27,7 @@ WHISPER_AVAILABLE = False from espnet2.text.cleaner import TextCleaner +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType # Constants TARGET_FS = 16000 @@ -39,221 +40,185 @@ class WhisperNotAvailableError(RuntimeError): pass -def asr_match_setup( - model_tag: str = "default", - beam_size: int = 5, - text_cleaner: str = "whisper_basic", - use_gpu: bool = True, -) -> Dict[str, Any]: +def is_whisper_available(): """ - Set up ASR matching utilities. - - Args: - model_tag: Whisper model tag. Options include "tiny", "base", "small", - "medium", "large", or "large-v2". Defaults to "large". - beam_size: Beam size for decoding. - text_cleaner: Text cleaner type for post-processing. - use_gpu: Whether to use GPU for computation. + Check if the Whisper package is available. Returns: - Dictionary containing the model, text cleaner, and beam size. - - Raises: - WhisperNotAvailableError: If Whisper is not installed but is required. - RuntimeError: If model loading fails. - """ - if not WHISPER_AVAILABLE: - raise WhisperNotAvailableError( - "Whisper WER is used for evaluation while openai-whisper is not installed" - ) - - # Use the large model by default - if model_tag == "default": - model_tag = "large" - - # Set device based on availability and user preference - device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu" - - try: - # Load the Whisper model - logger.info(f"Loading Whisper model '{model_tag}' on {device}") - model = whisper.load_model(model_tag, device=device) - - # Initialize text cleaner - textcleaner = TextCleaner(text_cleaner) - - # Return utilities dictionary - return {"model": model, "cleaner": textcleaner, "beam_size": beam_size} - except Exception as e: - raise RuntimeError(f"Failed to initialize Whisper model: {str(e)}") from e - - -def asr_match_metric( - wer_utils: Dict[str, Any], - pred_x: np.ndarray, - gt_x: np.ndarray, - cache_pred_text: Optional[str] = None, - fs: int = 16000, -) -> Dict[str, Union[float, str]]: + bool: True if Whisper is available, False otherwise. """ - Calculate the ASR match error rate and related metrics. - - This function compares the ASR transcription of the predicted audio - with the transcription of the ground truth audio to compute character-level - edit distance metrics. - - Args: - wer_utils: A utility dict for WER calculation including: - - model: whisper model - - cleaner: text cleaner - - beam_size: beam size for decoding - pred_x: Predicted/test signal as a numpy array (time,) - gt_x: Ground truth signal as a numpy array (time,) - cache_pred_text: Optional pre-computed transcription for pred_x - fs: Sampling rate of the input audio in Hz - - Returns: - Dictionary containing: - - asr_match_error_rate: The character error rate - - whisper_hyp_text: The transcription of the predicted audio + return WHISPER_AVAILABLE - Raises: - ValueError: If input data is invalid - RuntimeError: If transcription fails - """ - # Validate inputs - if pred_x is None or gt_x is None: - raise ValueError("Both predicted and ground truth signals must be provided") - # Make sure inputs are numpy arrays - pred_x = np.asarray(pred_x) - gt_x = np.asarray(gt_x) +class ASRMatchMetric(BaseMetric): + """ASR-oriented Mismatch Error Rate (ASR-Match) metric using Whisper.""" - # Process the speech to be evaluated - if cache_pred_text is not None: - inf_text = cache_pred_text - else: + def _setup(self): + if not WHISPER_AVAILABLE: + raise ImportError( + "Whisper is not properly installed. Please install following https://github.com/openai/whisper" + ) + self.model_tag = self.config.get("model_tag", "default") + self.beam_size = self.config.get("beam_size", 5) + self.text_cleaner = self.config.get("text_cleaner", "whisper_basic") + self.use_gpu = self.config.get("use_gpu", True) + # Use the large model by default + if self.model_tag == "default": + self.model_tag = "large" + self.device = "cuda" if self.use_gpu and torch.cuda.is_available() else "cpu" + try: + self.model = whisper.load_model(self.model_tag, device=self.device) + self.cleaner = TextCleaner(self.text_cleaner) + except Exception as e: + raise RuntimeError(f"Failed to initialize Whisper model: {str(e)}") from e + + def compute( + self, predictions: Any, references: Any = None, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + pred_x = predictions + gt_x = references + fs = 16000 + cache_pred_text = None + if metadata is not None: + fs = metadata.get("sample_rate", 16000) + cache_pred_text = metadata.get("cache_pred_text", None) + # Validate inputs + if pred_x is None or gt_x is None: + raise ValueError("Both predicted and ground truth signals must be provided") + pred_x = np.asarray(pred_x) + gt_x = np.asarray(gt_x) + # Process the speech to be evaluated + if cache_pred_text is not None: + inf_text = cache_pred_text + else: + try: + if fs != TARGET_FS: + pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=TARGET_FS) + with torch.no_grad(): + transcription = self.model.transcribe( + torch.tensor(pred_x).float(), beam_size=self.beam_size + ) + inf_text = transcription["text"] + except Exception as e: + raise RuntimeError( + f"Failed to transcribe predicted signal: {str(e)}" + ) from e + # Process the ground truth speech try: - # Resample if necessary if fs != TARGET_FS: - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=TARGET_FS) - - # Convert to tensor and transcribe + gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=TARGET_FS) with torch.no_grad(): - transcription = wer_utils["model"].transcribe( - torch.tensor(pred_x).float(), beam_size=wer_utils["beam_size"] + transcription = self.model.transcribe( + torch.tensor(gt_x).float(), beam_size=self.beam_size ) - inf_text = transcription["text"] + gt_text = transcription["text"] except Exception as e: raise RuntimeError( - f"Failed to transcribe predicted signal: {str(e)}" + f"Failed to transcribe ground truth signal: {str(e)}" ) from e - - # Process the ground truth speech - try: - # Resample if necessary - if fs != TARGET_FS: - gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=TARGET_FS) - - # Convert to tensor and transcribe - with torch.no_grad(): - transcription = wer_utils["model"].transcribe( - torch.tensor(gt_x).float(), beam_size=wer_utils["beam_size"] + ref_text = self.cleaner(gt_text) + pred_text = self.cleaner(inf_text) + ref_chars = list(ref_text) + pred_chars = list(pred_text) + result = { + "asr_match_delete": 0, + "asr_match_insert": 0, + "asr_match_replace": 0, + "asr_match_equal": 0, + } + for op, ref_st, ref_et, inf_st, inf_et in opcodes(ref_chars, pred_chars): + if op == "insert": + result["asr_match_" + op] += inf_et - inf_st + else: + result["asr_match_" + op] += ref_et - ref_st + total_ref = ( + result["asr_match_delete"] + + result["asr_match_replace"] + + result["asr_match_equal"] + ) + if total_ref != len(ref_chars): + logger.warning( + f"Reference operation count mismatch: {total_ref} vs {len(ref_chars)}" ) - gt_text = transcription["text"] - except Exception as e: - raise RuntimeError(f"Failed to transcribe ground truth signal: {str(e)}") from e - - # Clean the text using the provided cleaner - ref_text = wer_utils["cleaner"](gt_text) - pred_text = wer_utils["cleaner"](inf_text) - - # Convert texts to character lists for edit distance calculation - ref_chars = list(ref_text) - pred_chars = list(pred_text) - - # Initialize result dictionary with operation counts - result = { - "asr_match_delete": 0, # Deletions: chars in reference but not in prediction - "asr_match_insert": 0, # Insertions: chars in prediction but not in reference - "asr_match_replace": 0, # Substitutions: chars that differ between ref and pred - "asr_match_equal": 0, # Matches: chars that are the same in ref and pred - } - - # Calculate edit operations using Levenshtein - for op, ref_st, ref_et, inf_st, inf_et in opcodes(ref_chars, pred_chars): - if op == "insert": - result["asr_match_" + op] += inf_et - inf_st - else: - result["asr_match_" + op] += ref_et - ref_st - - # Validate operation counts - total_ref = ( - result["asr_match_delete"] - + result["asr_match_replace"] - + result["asr_match_equal"] - ) - if total_ref != len(ref_chars): - logger.warning( - f"Reference operation count mismatch: {total_ref} vs {len(ref_chars)}" + total_pred = ( + result["asr_match_insert"] + + result["asr_match_replace"] + + result["asr_match_equal"] ) - - total_pred = ( - result["asr_match_insert"] - + result["asr_match_replace"] - + result["asr_match_equal"] - ) - if total_pred != len(pred_chars): - logger.warning( - f"Prediction operation count mismatch: {total_pred} vs {len(pred_chars)}" + if total_pred != len(pred_chars): + logger.warning( + f"Prediction operation count mismatch: {total_pred} vs {len(pred_chars)}" + ) + if len(ref_chars) == 0: + asr_match_error_rate = 1.0 + logger.warning("Reference text is empty, setting error rate to 1.0") + else: + asr_match_error_rate = ( + result["asr_match_delete"] + + result["asr_match_insert"] + + result["asr_match_replace"] + ) / len(ref_chars) + return { + "asr_match_error_rate": asr_match_error_rate, + "whisper_hyp_text": inf_text, + "ref_text_length": len(ref_chars), + "pred_text_length": len(pred_chars), + "match_details": result, + } + + def get_metadata(self) -> MetricMetadata: + return MetricMetadata( + name="asr_match", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["whisper", "espnet2", "Levenshtein", "librosa", "torch"], + description="ASR-oriented Mismatch Error Rate (ASR-Match) using Whisper for reference-based speech evaluation.", + paper_reference=None, + implementation_source="https://github.com/ftshijt/versa", ) - # Calculate error rate - if len(ref_chars) == 0: - # Handle empty reference case - asr_match_error_rate = 1.0 - logger.warning("Reference text is empty, setting error rate to 1.0") - else: - # Calculate character error rate - asr_match_error_rate = ( - result["asr_match_delete"] - + result["asr_match_insert"] - + result["asr_match_replace"] - ) / len(ref_chars) - - # Return results - return { - "asr_match_error_rate": asr_match_error_rate, - "whisper_hyp_text": inf_text, - # Additional metrics that might be useful - "ref_text_length": len(ref_chars), - "pred_text_length": len(pred_chars), - "match_details": result, - } - - -def is_whisper_available(): - """ - Check if the Whisper package is available. - Returns: - bool: True if Whisper is available, False otherwise. - """ - return WHISPER_AVAILABLE +def register_asr_match_metric(registry): + """Register ASR-Match metric with the registry.""" + metric_metadata = MetricMetadata( + name="asr_match", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["whisper", "espnet2", "Levenshtein", "librosa", "torch"], + description="ASR-oriented Mismatch Error Rate (ASR-Match) using Whisper for reference-based speech evaluation.", + paper_reference=None, + implementation_source="https://github.com/ftshijt/versa", + ) + registry.register( + ASRMatchMetric, metric_metadata, aliases=["ASRMatch", "asr_match_error_rate"] + ) if __name__ == "__main__": - # Example usage + # Example usage for the class-based metric try: # Generate random test audio (1 second at 16kHz) test_audio = np.random.random(TARGET_FS) - - # Set up ASR matching utilities - wer_utils = asr_match_setup(model_tag="tiny", use_gpu=torch.cuda.is_available()) - + # Set up ASR matching metric + config = { + "model_tag": "tiny", + "beam_size": 1, + "text_cleaner": "whisper_basic", + "use_gpu": torch.cuda.is_available(), + } + metric = ASRMatchMetric(config) # Calculate metrics - metrics = asr_match_metric(wer_utils, test_audio, test_audio, None, TARGET_FS) - + metrics = metric.compute( + test_audio, test_audio, metadata={"sample_rate": TARGET_FS} + ) # Print results print(f"ASR Match Error Rate: {metrics['asr_match_error_rate']:.4f}") print(f"Transcription: '{metrics['whisper_hyp_text']}'") diff --git a/versa/utterance_metrics/asvspoof_score.py b/versa/utterance_metrics/asvspoof_score.py index 473c184..e248187 100644 --- a/versa/utterance_metrics/asvspoof_score.py +++ b/versa/utterance_metrics/asvspoof_score.py @@ -13,79 +13,179 @@ """ import json +import logging import os import sys +from typing import Dict, Any, Optional, Union import librosa import numpy as np import torch -sys.path.append("./tools/checkpoints/aasist") -from models.AASIST import Model as AASIST # noqa: E402 +logger = logging.getLogger(__name__) +# Handle optional AASIST dependency +try: + sys.path.append("./tools/checkpoints/aasist") + from models.AASIST import Model as AASIST # noqa: E402 -def deepfake_detection_model_setup( - model_tag="default", model_path=None, model_config=None, use_gpu=False -): - """Setup deepfake detection model. + AASIST_AVAILABLE = True +except ImportError: + logger.warning( + "AASIST is not properly installed. " + "Please install following https://github.com/clovaai/aasist" + ) + AASIST = None + AASIST_AVAILABLE = False - Args: - model_tag (str): Model tag. Defaults to "default". - model_path (str, optional): Path to model weights. Defaults to None. - model_config (str, optional): Path to model config. Defaults to None. - use_gpu (bool, optional): Whether to use GPU. Defaults to False. +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType - Returns: - AASIST: The loaded model. - """ - device = "cuda" if use_gpu else "cpu" - - if model_path is not None and model_config is not None: - with open(model_config, "r") as f_json: - config = json.loads(f_json.read()) - model = AASIST(config["model_config"]).to(device) - model.load_state_dict(torch.load(model_path, map_location=device)) - else: - if model_tag == "default": - model_root = "./tools/checkpoints/aasist" - model_config = os.path.join(model_root, "config/AASIST.conf") - model_path = os.path.join(model_root, "models/weights/AASIST.pth") - - with open(model_config, "r") as f_json: - config = json.loads(f_json.read()) - model = AASIST(config["model_config"]).to(device) - model.load_state_dict(torch.load(model_path, map_location=device)) - else: - raise NotImplementedError - model.device = device - return model +class AASISTNotAvailableError(RuntimeError): + """Exception raised when AASIST is required but not available.""" -def asvspoof_metric(model, pred_x, fs): - """Calculate ASVspoof score for audio. + pass - Args: - model (AASIST): The loaded deepfake detection model. - pred_x (np.ndarray): Audio signal. - fs (int): Sampling rate. + +def is_aasist_available(): + """ + Check if the AASIST package is available. Returns: - dict: Dictionary containing the ASVspoof score. + bool: True if AASIST is available, False otherwise. """ - # NOTE(jiatong): only work for 16000 Hz - if fs != 16000: - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) + return AASIST_AVAILABLE + + +class ASVSpoofMetric(BaseMetric): + """ASVspoof deepfake detection metric using AASIST model.""" - pred_x = torch.from_numpy(pred_x).unsqueeze(0).float().to(model.device) - model.eval() - with torch.no_grad(): - _, output = model(pred_x) - output = torch.softmax(output, dim=1) - output = output.squeeze(0).cpu().numpy() - return {"asvspoof_score": output[1]} + def _setup(self): + """Initialize ASVspoof-specific components.""" + if not AASIST_AVAILABLE: + raise ImportError( + "AASIST is not properly installed. Please install following https://github.com/clovaai/aasist" + ) + + self.model_tag = self.config.get("model_tag", "default") + self.model_path = self.config.get("model_path", None) + self.model_config = self.config.get("model_config", None) + self.use_gpu = self.config.get("use_gpu", False) + + self.device = "cuda" if self.use_gpu and torch.cuda.is_available() else "cpu" + + try: + self.model = self._setup_model() + except Exception as e: + raise RuntimeError(f"Failed to initialize AASIST model: {str(e)}") from e + + def _setup_model(self): + """Setup the AASIST model.""" + if self.model_path is not None and self.model_config is not None: + with open(self.model_config, "r") as f_json: + config = json.loads(f_json.read()) + model = AASIST(config["model_config"]).to(self.device) + model.load_state_dict( + torch.load(self.model_path, map_location=self.device) + ) + else: + if self.model_tag == "default": + model_root = "./tools/checkpoints/aasist" + model_config = os.path.join(model_root, "config/AASIST.conf") + model_path = os.path.join(model_root, "models/weights/AASIST.pth") + + with open(model_config, "r") as f_json: + config = json.loads(f_json.read()) + model = AASIST(config["model_config"]).to(self.device) + model.load_state_dict( + torch.load(model_path, map_location=self.device) + ) + else: + raise NotImplementedError( + f"Model tag '{self.model_tag}' not implemented" + ) + + model.device = self.device + return model + + def compute( + self, predictions: Any, references: Any = None, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + """Calculate ASVspoof score for audio. + + Args: + predictions: Audio signal to evaluate. + references: Not used for this metric. + metadata: Optional metadata containing sample_rate. + + Returns: + dict: Dictionary containing the ASVspoof score. + """ + pred_x = predictions + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + + # Validate input + if pred_x is None: + raise ValueError("Predicted signal must be provided") + + pred_x = np.asarray(pred_x) + + # NOTE(jiatong): only work for 16000 Hz + if fs != 16000: + pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) + + pred_x = torch.from_numpy(pred_x).unsqueeze(0).float().to(self.device) + self.model.eval() + with torch.no_grad(): + _, output = self.model(pred_x) + output = torch.softmax(output, dim=1) + output = output.squeeze(0).cpu().numpy() + + return {"asvspoof_score": output[1]} + + def get_metadata(self) -> MetricMetadata: + """Return ASVspoof metric metadata.""" + return MetricMetadata( + name="asvspoof", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["torch", "librosa", "numpy"], + description="ASVspoof deepfake detection score using AASIST model for speech authenticity assessment", + paper_reference="https://github.com/clovaai/aasist", + implementation_source="https://github.com/clovaai/aasist", + ) + + +def register_asvspoof_metric(registry): + """Register ASVspoof metric with the registry.""" + metric_metadata = MetricMetadata( + name="asvspoof", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["torch", "librosa", "numpy"], + description="ASVspoof deepfake detection score using AASIST model for speech authenticity assessment", + paper_reference="https://github.com/clovaai/aasist", + implementation_source="https://github.com/clovaai/aasist", + ) + registry.register( + ASVSpoofMetric, metric_metadata, aliases=["ASVSpoof", "asvspoof_score"] + ) if __name__ == "__main__": a = np.random.random(16000) - model = deepfake_detection_model_setup(use_gpu=False) - print(f"metrics: {asvspoof_metric(model, a, 16000)}") + + # Test the new class-based metric + config = {"use_gpu": False} + metric = ASVSpoofMetric(config) + metadata = {"sample_rate": 16000} + score = metric.compute(a, metadata=metadata) + print(f"metrics: {score}") diff --git a/versa/utterance_metrics/audiobox_aesthetics_score.py b/versa/utterance_metrics/audiobox_aesthetics_score.py index 71a05cb..ee3f073 100644 --- a/versa/utterance_metrics/audiobox_aesthetics_score.py +++ b/versa/utterance_metrics/audiobox_aesthetics_score.py @@ -6,87 +6,170 @@ """Module for evaluating audio using AudioBox Aesthetics models.""" import json +import logging import os +from typing import Dict, Any, Optional, Union import numpy as np +logger = logging.getLogger(__name__) + +# Handle optional audiobox_aesthetics dependency try: import audiobox_aesthetics.infer import audiobox_aesthetics.utils + + AUDIOBOX_AESTHETICS_AVAILABLE = True except ImportError: + logger.warning( + "audiobox_aesthetics is not properly installed. " + "Please install with tools/install_audiobox-aesthetics.sh first." + ) audiobox_aesthetics = None + AUDIOBOX_AESTHETICS_AVAILABLE = False +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType -def audiobox_aesthetics_setup( - model_path=None, - batch_size=1, - precision="bf16", - cache_dir="versa_cache/audiobox", - use_huggingface=True, - use_gpu=False, -): - """Set up the AudioBox Aesthetics model for inference. - - Args: - model_path (str, optional): Path to model weights. Defaults to None. - batch_size (int, optional): Batch size for inference. Defaults to 1. - precision (str, optional): Precision for inference. Defaults to "bf16". - cache_dir (str, optional): Directory to cache model. Defaults to "versa_cache/audiobox". - use_huggingface (bool, optional): Whether to use Hugging Face. Defaults to True. - use_gpu (bool, optional): Whether to use GPU. Defaults to False. - Returns: - AesWavlmPredictorMultiOutput: The loaded model. +class AudioBoxAestheticsNotAvailableError(RuntimeError): + """Exception raised when AudioBox Aesthetics is required but not available.""" + + pass + - Raises: - ImportError: If audiobox_aesthetics is not installed. +def is_audiobox_aesthetics_available(): """ - if audiobox_aesthetics is None: - raise ImportError( - "Please install with tools/install_audiobox-aesthetics.sh first." - ) + Check if the AudioBox Aesthetics package is available. - device = "cuda" if use_gpu else "cpu" + Returns: + bool: True if AudioBox Aesthetics is available, False otherwise. + """ + return AUDIOBOX_AESTHETICS_AVAILABLE - if model_path is None: - if use_huggingface: - model_path = audiobox_aesthetics.utils.load_model(model_path) - else: - os.makedirs(cache_dir, exist_ok=True) - model_path = os.path.join( - cache_dir, audiobox_aesthetics.utils.DEFAULT_CKPT_FNAME + +class AudioBoxAestheticsMetric(BaseMetric): + """AudioBox Aesthetics metric for audio quality assessment.""" + + def _setup(self): + """Initialize AudioBox Aesthetics-specific components.""" + if not AUDIOBOX_AESTHETICS_AVAILABLE: + raise ImportError( + "audiobox_aesthetics is not properly installed. " + "Please install with tools/install_audiobox-aesthetics.sh first." ) - model_url = audiobox_aesthetics.utils.DEFAULT_S3_URL - if not os.path.exists(model_path): - print(f"Downloading model from {model_url} to {model_path}") - audiobox_aesthetics.utils.download_file(model_url, model_path) - - predictor = audiobox_aesthetics.infer.AesWavlmPredictorMultiOutput( - checkpoint_pth=model_path, - device=device, - batch_size=batch_size, - precision=precision, - ) - return predictor + self.model_path = self.config.get("model_path", None) + self.batch_size = self.config.get("batch_size", 1) + self.precision = self.config.get("precision", "bf16") + self.cache_dir = self.config.get("cache_dir", "versa_cache/audiobox") + self.use_huggingface = self.config.get("use_huggingface", True) + self.use_gpu = self.config.get("use_gpu", False) + + try: + self.model = self._setup_model() + except Exception as e: + raise RuntimeError( + f"Failed to initialize AudioBox Aesthetics model: {str(e)}" + ) from e + + def _setup_model(self): + """Setup the AudioBox Aesthetics model.""" + device = "cuda" if self.use_gpu else "cpu" + + if self.model_path is None: + if self.use_huggingface: + model_path = audiobox_aesthetics.utils.load_model(self.model_path) + else: + os.makedirs(self.cache_dir, exist_ok=True) + model_path = os.path.join( + self.cache_dir, audiobox_aesthetics.utils.DEFAULT_CKPT_FNAME + ) + model_url = audiobox_aesthetics.utils.DEFAULT_S3_URL + if not os.path.exists(model_path): + print(f"Downloading model from {model_url} to {model_path}") + audiobox_aesthetics.utils.download_file(model_url, model_path) + else: + model_path = self.model_path -def audiobox_aesthetics_score(model, pred_x, fs): - """Calculate AudioBox Aesthetics scores for audio. + predictor = audiobox_aesthetics.infer.AesWavlmPredictorMultiOutput( + checkpoint_pth=model_path, + device=device, + batch_size=self.batch_size, + precision=self.precision, + ) + return predictor + + def compute( + self, predictions: Any, references: Any = None, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + """Calculate AudioBox Aesthetics scores for audio. + + Args: + predictions: Audio signal to evaluate. + references: Not used for this metric. + metadata: Optional metadata containing sample_rate. + + Returns: + dict: Dictionary containing the AudioBox Aesthetics scores. + """ + pred_x = predictions + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + + # Validate input + if pred_x is None: + raise ValueError("Predicted signal must be provided") + + pred_x = np.asarray(pred_x) + + output = json.loads(self.model.forward_versa([(pred_x, fs)])[0]) + output = {"audiobox_aesthetics_" + k: v for k, v in output.items()} + return output + + def get_metadata(self) -> MetricMetadata: + """Return AudioBox Aesthetics metric metadata.""" + return MetricMetadata( + name="audiobox_aesthetics", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["audiobox_aesthetics", "numpy"], + description="AudioBox Aesthetics scores for audio quality assessment using WavLM-based models", + paper_reference="https://github.com/facebookresearch/audiobox-aesthetics", + implementation_source="https://github.com/facebookresearch/audiobox-aesthetics", + ) - Args: - model (AesWavlmPredictorMultiOutput): The loaded model. - pred_x (np.ndarray): Audio signal. - fs (int): Sampling rate. - Returns: - dict: Dictionary containing the AudioBox Aesthetics scores. - """ - output = json.loads(model.forward_versa([(pred_x, fs)])[0]) - output = {"audiobox_aesthetics_" + k: v for k, v in output.items()} - return output +def register_audiobox_aesthetics_metric(registry): + """Register AudioBox Aesthetics metric with the registry.""" + metric_metadata = MetricMetadata( + name="audiobox_aesthetics", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["audiobox_aesthetics", "numpy"], + description="AudioBox Aesthetics scores for audio quality assessment using WavLM-based models", + paper_reference="https://github.com/facebookresearch/audiobox-aesthetics", + implementation_source="https://github.com/facebookresearch/audiobox-aesthetics", + ) + registry.register( + AudioBoxAestheticsMetric, + metric_metadata, + aliases=["AudioBoxAesthetics", "audiobox_aesthetics"], + ) if __name__ == "__main__": a = np.random.random(16000) - model = audiobox_aesthetics_setup() - print(f"metrics: {audiobox_aesthetics_score(model, a, 16000)}") + + # Test the new class-based metric + config = {"use_gpu": False} + metric = AudioBoxAestheticsMetric(config) + metadata = {"sample_rate": 16000} + score = metric.compute(a, metadata=metadata) + print(f"metrics: {score}") diff --git a/versa/utterance_metrics/cdpam_distance.py b/versa/utterance_metrics/cdpam_distance.py index 772e23e..09874c4 100644 --- a/versa/utterance_metrics/cdpam_distance.py +++ b/versa/utterance_metrics/cdpam_distance.py @@ -1,33 +1,166 @@ -import torch +#!/usr/bin/env python3 + +# Copyright 2024 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Module for CDPAM distance metrics.""" + +import logging +from functools import partial +from typing import Dict, Any, Optional, Union + import librosa import numpy as np -from functools import partial -import cdpam +import torch + +logger = logging.getLogger(__name__) + +# Handle optional cdpam dependency +try: + import cdpam + + CDPAM_AVAILABLE = True +except ImportError: + logger.warning("cdpam is not properly installed. " "Please install cdpam and retry") + cdpam = None + CDPAM_AVAILABLE = False + +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType + + +class CdpamNotAvailableError(RuntimeError): + """Exception raised when cdpam is required but not available.""" + + pass + + +def is_cdpam_available(): + """ + Check if the cdpam package is available. -TARGET_FS = 22050 + Returns: + bool: True if cdpam is available, False otherwise. + """ + return CDPAM_AVAILABLE -def cdpam_model_setup(use_gpu=False): - device = "cpu" if not use_gpu else "cuda" - _original_torch_load = torch.load - torch.load = partial(torch.load, weights_only=False) - model = cdpam.CDPAM(dev=device) - torch.load = _original_torch_load - return model +class CdpamDistanceMetric(BaseMetric): + """CDPAM distance metric.""" + TARGET_FS = 22050 -def cdpam_metric(model, pred_x, gt_x, fs): - if fs != TARGET_FS: - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=TARGET_FS) - gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=TARGET_FS) - pred_x = (torch.from_numpy(pred_x).unsqueeze(0) * 32768).round() - gt_x = (torch.from_numpy(gt_x).unsqueeze(0) * 32768).round() - dist = model.forward(gt_x, pred_x) - return {"cdpam_distance": dist.detach().cpu().numpy().item()} + def _setup(self): + """Initialize CDPAM-specific components.""" + if not CDPAM_AVAILABLE: + raise ImportError( + "cdpam is not properly installed. " "Please install cdpam and retry" + ) + + self.use_gpu = self.config.get("use_gpu", False) + + try: + self.model = self._setup_model() + except Exception as e: + raise RuntimeError(f"Failed to initialize CDPAM model: {str(e)}") from e + + def _setup_model(self): + """Setup the CDPAM model.""" + device = "cpu" if not self.use_gpu else "cuda" + # Suppress PyTorch config registration warnings during model loading + import warnings + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="Skipping config registration for" + ) + _original_torch_load = torch.load + torch.load = partial(torch.load, weights_only=False) + model = cdpam.CDPAM(dev=device) + torch.load = _original_torch_load + return model + + def compute( + self, predictions: Any, references: Any, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + """Calculate CDPAM distance between two audio samples. + + Args: + predictions: Predicted audio signal. + references: Ground truth audio signal. + metadata: Optional metadata containing sample_rate. + + Returns: + dict: Dictionary containing the CDPAM distance score. + """ + pred_x = predictions + gt_x = references + fs = metadata.get("sample_rate", 22050) if metadata else 22050 + + # Validate inputs + if pred_x is None: + raise ValueError("Predicted signal must be provided") + if gt_x is None: + raise ValueError("Reference signal must be provided") + + pred_x = np.asarray(pred_x) + gt_x = np.asarray(gt_x) + + if fs != self.TARGET_FS: + pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=self.TARGET_FS) + gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=self.TARGET_FS) + + pred_x = (torch.from_numpy(pred_x).unsqueeze(0) * 32768).round() + gt_x = (torch.from_numpy(gt_x).unsqueeze(0) * 32768).round() + dist = self.model.forward(gt_x, pred_x) + + return {"cdpam_distance": dist.detach().cpu().numpy().item()} + + def get_metadata(self) -> MetricMetadata: + """Return CDPAM distance metric metadata.""" + return MetricMetadata( + name="cdpam_distance", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["cdpam", "torch", "librosa", "numpy"], + description="CDPAM distance between audio samples", + paper_reference="https://github.com/facebookresearch/audiocraft", + implementation_source="https://github.com/facebookresearch/audiocraft", + ) + + +def register_cdpam_distance_metric(registry): + """Register CDPAM distance metric with the registry.""" + metric_metadata = MetricMetadata( + name="cdpam_distance", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["cdpam", "torch", "librosa", "numpy"], + description="CDPAM distance between audio samples", + paper_reference="https://github.com/facebookresearch/audiocraft", + implementation_source="https://github.com/facebookresearch/audiocraft", + ) + registry.register( + CdpamDistanceMetric, + metric_metadata, + aliases=["CdpamDistance", "cdpam_distance", "cdpam"], + ) if __name__ == "__main__": a = np.random.random(22050) b = np.random.random(22050) - model = cdpam_model_setup() - print("metrics: {}".format(cdpam_metric(model, a, b, 22050))) + + # Test the new class-based metric + config = {"use_gpu": False} + metric = CdpamDistanceMetric(config) + metadata = {"sample_rate": 22050} + score = metric.compute(a, b, metadata=metadata) + print(f"metrics: {score}") diff --git a/versa/utterance_metrics/chroma_alignment.py b/versa/utterance_metrics/chroma_alignment.py index 54f030d..4fb8b3f 100644 --- a/versa/utterance_metrics/chroma_alignment.py +++ b/versa/utterance_metrics/chroma_alignment.py @@ -4,10 +4,16 @@ # Chroma-based distance estimation with dynamic programming alignment # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +import logging +from typing import Dict, Any, Optional, Union, Tuple, List + import librosa import numpy as np from scipy.spatial.distance import cosine, euclidean -from typing import Tuple, Dict, Optional + +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType + +logger = logging.getLogger(__name__) def calculate_chroma_features(audio, sr=22050, feature_type="stft", **kwargs): @@ -161,127 +167,166 @@ def calculate_chroma_distance( return dtw_dist, alignment_path -def chroma_metric(pred_x, gt_x, sr=22050, return_alignment=False, scale_factor=100.0): - """ - Calculate multiple chroma-based distance metrics. +class ChromaAlignmentMetric(BaseMetric): + """Chroma-based distance estimation with dynamic programming alignment.""" - Args: - pred_x: Predicted audio signal (1D numpy array) - gt_x: Ground truth audio signal (1D numpy array) - sr: Sample rate - return_alignment: Whether to return alignment paths - scale_factor: Multiplicative scaling factor for distances - - Returns: - Dictionary of chroma distance metrics - """ - # Ensure 1D arrays - if pred_x.ndim > 1: - pred_x = pred_x.flatten() - if gt_x.ndim > 1: - gt_x = gt_x.flatten() - - results = {} - alignments = {} if return_alignment else None - - # Different chroma feature types - feature_types = [ - "stft", - "cqt", - "cens", - ] # 'vqt' might not be available in all librosa versions - distance_metrics = ["cosine", "euclidean"] - - for feat_type in feature_types: - for dist_metric in distance_metrics: - try: - dtw_dist, alignment = calculate_chroma_distance( - pred_x, - gt_x, - sr=sr, - feature_type=feat_type, - distance_metric=dist_metric, - scale_factor=scale_factor, - ) - - metric_name = f"chroma_{feat_type}_{dist_metric}_dtw" - results[metric_name] = dtw_dist - - if return_alignment: - alignments[metric_name] = alignment - - except Exception as e: - print( - f"Warning: Could not calculate {feat_type} with {dist_metric}: {e}" - ) - continue - - # Add additional scaled variants - try: - # Raw DTW distance (no path normalization, higher scale) - dtw_dist_raw, _ = calculate_chroma_distance( - pred_x, - gt_x, - sr=sr, - feature_type="stft", - distance_metric="cosine", - scale_factor=1000.0, - normalize_by_path=True, + def _setup(self): + """Initialize Chroma Alignment-specific components.""" + self.sample_rate = self.config.get("sample_rate", 22050) + self.feature_types = self.config.get("feature_types", ["stft", "cqt", "cens"]) + self.distance_metrics = self.config.get( + "distance_metrics", ["cosine", "euclidean"] ) - results["chroma_stft_cosine_dtw_raw"] = dtw_dist_raw - - # Log-scaled distance - dtw_dist_base, _ = calculate_chroma_distance( - pred_x, - gt_x, - sr=sr, - feature_type="stft", - distance_metric="cosine", - scale_factor=1.0, - normalize_by_path=True, + self.scale_factor = self.config.get("scale_factor", 100.0) + self.normalize = self.config.get("normalize", True) + self.normalize_by_path = self.config.get("normalize_by_path", True) + self.return_alignment = self.config.get("return_alignment", False) + self.chroma_kwargs = self.config.get("chroma_kwargs", {}) + + def compute( + self, predictions: Any, references: Any = None, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + """Calculate chroma-based distance metrics. + + Args: + predictions: Predicted audio signal. + references: Ground truth audio signal. + metadata: Optional metadata containing sample_rate. + + Returns: + dict: Dictionary containing chroma distance metrics. + """ + pred_x = predictions + gt_x = references + sr = ( + metadata.get("sample_rate", self.sample_rate) + if metadata + else self.sample_rate ) - results["chroma_stft_cosine_dtw_log"] = -np.log10(dtw_dist_base + 1e-10) * 10 - except Exception as e: - print(f"Warning: Could not calculate additional scaled metrics: {e}") - - if return_alignment: - return results, alignments - return results + # Validate inputs + if pred_x is None or gt_x is None: + raise ValueError("Both predicted and ground truth signals must be provided") + + pred_x = np.asarray(pred_x) + gt_x = np.asarray(gt_x) + + # Ensure 1D arrays + if pred_x.ndim > 1: + pred_x = pred_x.flatten() + if gt_x.ndim > 1: + gt_x = gt_x.flatten() + + results = {} + alignments = {} if self.return_alignment else None + + # Calculate metrics for different feature types and distance metrics + for feat_type in self.feature_types: + for dist_metric in self.distance_metrics: + try: + dtw_dist, alignment = calculate_chroma_distance( + pred_x, + gt_x, + sr=sr, + feature_type=feat_type, + distance_metric=dist_metric, + scale_factor=self.scale_factor, + normalize=self.normalize, + normalize_by_path=self.normalize_by_path, + **self.chroma_kwargs, + ) + + metric_name = f"chroma_{feat_type}_{dist_metric}_dtw" + results[metric_name] = dtw_dist + + if self.return_alignment and alignments is not None: + alignments[metric_name] = alignment + + except Exception as e: + logger.warning( + f"Could not calculate {feat_type} with {dist_metric}: {e}" + ) + continue + + # Add additional scaled variants + try: + # Raw DTW distance (no path normalization, higher scale) + dtw_dist_raw, _ = calculate_chroma_distance( + pred_x, + gt_x, + sr=sr, + feature_type="stft", + distance_metric="cosine", + scale_factor=1000.0, + normalize_by_path=True, + normalize=self.normalize, + **self.chroma_kwargs, + ) + results["chroma_stft_cosine_dtw_raw"] = dtw_dist_raw + + # Log-scaled distance + dtw_dist_base, _ = calculate_chroma_distance( + pred_x, + gt_x, + sr=sr, + feature_type="stft", + distance_metric="cosine", + scale_factor=1.0, + normalize_by_path=True, + normalize=self.normalize, + **self.chroma_kwargs, + ) + results["chroma_stft_cosine_dtw_log"] = ( + -np.log10(dtw_dist_base + 1e-10) * 10 + ) + except Exception as e: + logger.warning(f"Could not calculate additional scaled metrics: {e}") + + if self.return_alignment and alignments is not None: + results["alignments"] = alignments + + return results + + def get_metadata(self) -> MetricMetadata: + """Return Chroma Alignment metric metadata.""" + return MetricMetadata( + name="chroma_alignment", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=False, + auto_install=False, + dependencies=["librosa", "numpy", "scipy"], + description="Chroma-based distance estimation with dynamic programming alignment for audio similarity assessment", + paper_reference="https://librosa.org/doc/latest/generated/librosa.feature.chroma_stft.html", + implementation_source="https://github.com/librosa/librosa", + ) -def simple_chroma_distance( - pred_x, - gt_x, - sr=22050, - feature_type="stft", - distance_metric="cosine", - scale_factor=100.0, -): - """ - Args: - pred_x: Predicted audio signal - gt_x: Ground truth audio signal - sr: Sample rate - feature_type: Chroma feature type - distance_metric: Distance metric - scale_factor: Multiplicative scaling factor - Returns: - DTW distance value - """ - dtw_dist, _ = calculate_chroma_distance( - pred_x, - gt_x, - sr=sr, - feature_type=feature_type, - distance_metric=distance_metric, - scale_factor=scale_factor, +def register_chroma_alignment_metric(registry): + """Register Chroma Alignment metric with the registry.""" + metric_metadata = MetricMetadata( + name="chroma_alignment", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=False, + auto_install=False, + dependencies=["librosa", "numpy", "scipy"], + description="Chroma-based distance estimation with dynamic programming alignment for audio similarity assessment", + paper_reference="https://librosa.org/doc/latest/generated/librosa.feature.chroma_stft.html", + implementation_source="https://github.com/librosa/librosa", + ) + registry.register( + ChromaAlignmentMetric, + metric_metadata, + aliases=["ChromaAlignment", "chroma_alignment"], ) - return dtw_dist -# Debug code if __name__ == "__main__": # Create test signals with different lengths sr = 22050 @@ -295,48 +340,9 @@ def simple_chroma_distance( pred_signal = np.sin(2 * np.pi * 440 * t1) # A4 note gt_signal = np.sin(2 * np.pi * 440 * t2) # Same note, different length - # Create a more different signal for testing - diff_signal = np.sin(2 * np.pi * 554.37 * t1) # C#5 note (different pitch) - - print(f"Predicted signal length: {len(pred_signal)} samples ({duration1}s)") - print(f"Ground truth signal length: {len(gt_signal)} samples ({duration2}s)") - - # Calculate chroma metrics for similar signals - print("\n=== SIMILAR SIGNALS (same pitch, different length) ===") - metrics_similar = chroma_metric(pred_signal, gt_signal, sr=sr, scale_factor=100.0) - for metric_name, value in metrics_similar.items(): - print(f"{metric_name}: {value:.4f}") - - # Calculate chroma metrics for different signals - print("\n=== DIFFERENT SIGNALS (different pitch) ===") - metrics_different = chroma_metric( - pred_signal, diff_signal, sr=sr, scale_factor=100.0 - ) - for metric_name, value in metrics_different.items(): - print(f"{metric_name}: {value:.4f}") - - # Simple interface examples with different scale factors - print("\n=== SIMPLE INTERFACE WITH DIFFERENT SCALES ===") - print( - f"Scale 1.0: {simple_chroma_distance(pred_signal, gt_signal, sr=sr, scale_factor=1.0):.4f}" - ) - print( - f"Scale 10.0: {simple_chroma_distance(pred_signal, gt_signal, sr=sr, scale_factor=10.0):.4f}" - ) - print( - f"Scale 100.0: {simple_chroma_distance(pred_signal, gt_signal, sr=sr, scale_factor=100.0):.4f}" - ) - print( - f"Scale 1000.0: {simple_chroma_distance(pred_signal, gt_signal, sr=sr, scale_factor=1000.0):.4f}" - ) - - # Test with random signals (should give larger distances) - print("\n=== RANDOM SIGNALS (should give larger distances) ===") - random_signal1 = np.random.randn(int(sr * 2.0)) - random_signal2 = np.random.randn(int(sr * 2.0)) - - metrics_random = chroma_metric( - random_signal1, random_signal2, sr=sr, scale_factor=100.0 - ) - for metric_name, value in list(metrics_random.items())[:3]: # Show first 3 metrics - print(f"{metric_name}: {value:.4f}") + # Test the new class-based metric + config = {"scale_factor": 100.0} + metric = ChromaAlignmentMetric(config) + metadata = {"sample_rate": sr} + score = metric.compute(pred_signal, gt_signal, metadata=metadata) + print(f"metrics: {score}") diff --git a/versa/utterance_metrics/discrete_speech.py b/versa/utterance_metrics/discrete_speech.py index f628906..eeec12f 100644 --- a/versa/utterance_metrics/discrete_speech.py +++ b/versa/utterance_metrics/discrete_speech.py @@ -6,93 +6,201 @@ """Module for discrete speech metrics evaluation.""" import logging +from typing import Dict, Any, Optional, Union import librosa import numpy as np -try: - from discrete_speech_metrics import SpeechBERTScore, SpeechBLEU, SpeechTokenDistance -except ImportError: - raise ImportError("Please install discrete_speech_metrics and retry") - logger = logging.getLogger(__name__) +# Handle optional discrete_speech_metrics dependency +try: + from discrete_speech_metrics import SpeechBERTScore, SpeechBLEU, SpeechTokenDistance -def discrete_speech_setup(use_gpu=False): - """Set up discrete speech metrics. + DISCRETE_SPEECH_AVAILABLE = True +except ImportError: + logger.warning( + "discrete_speech_metrics is not properly installed. " + "Please install discrete_speech_metrics and retry" + ) + SpeechBERTScore = None + SpeechBLEU = None + SpeechTokenDistance = None + DISCRETE_SPEECH_AVAILABLE = False - Args: - use_gpu (bool, optional): Whether to use GPU. Defaults to False. +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType - Returns: - dict: Dictionary containing the initialized metrics. - """ - # NOTE(jiatong) existing discrete speech metrics only works for 16khz - # We keep the paper best setting. To use other settings, please conduct the - # test on your own. - speech_bert = SpeechBERTScore( - sr=16000, model_type="wavlm-large", layer=14, use_gpu=use_gpu - ) - speech_bleu = SpeechBLEU( - sr=16000, - model_type="hubert-base", - vocab=200, - layer=11, - n_ngram=2, - remove_repetition=True, - use_gpu=use_gpu, - ) - speech_token_distance = SpeechTokenDistance( - sr=16000, - model_type="hubert-base", - vocab=200, - layer=6, - distance_type="jaro-winkler", - remove_repetition=False, - use_gpu=use_gpu, - ) - return { - "speech_bert": speech_bert, - "speech_bleu": speech_bleu, - "speech_token_distance": speech_token_distance, - } +class DiscreteSpeechNotAvailableError(RuntimeError): + """Exception raised when discrete_speech_metrics is required but not available.""" + pass -def discrete_speech_metric(discrete_speech_predictors, pred_x, gt_x, fs): - """Calculate discrete speech metrics. - Args: - discrete_speech_predictors (dict): Dictionary of speech metrics. - pred_x (np.ndarray): Predicted audio signal. - gt_x (np.ndarray): Ground truth audio signal. - fs (int): Sampling rate. +def is_discrete_speech_available(): + """ + Check if the discrete_speech_metrics package is available. Returns: - dict: Dictionary containing the metric scores. - - Raises: - NotImplementedError: If an unsupported metric is provided. + bool: True if discrete_speech_metrics is available, False otherwise. """ - scores = {} - - if fs != 16000: - gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=16000) - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) - - for key in discrete_speech_predictors.keys(): - if key == "speech_bert": - score, _, _ = discrete_speech_predictors[key].score(gt_x, pred_x) - elif key == "speech_bleu" or key == "speech_token_distance": - score = discrete_speech_predictors[key].score(gt_x, pred_x) - else: - raise NotImplementedError(f"Not supported {key}") - scores[key] = score - return scores + return DISCRETE_SPEECH_AVAILABLE + + +class DiscreteSpeechMetric(BaseMetric): + """Discrete speech metrics for audio evaluation.""" + + def _setup(self): + """Initialize Discrete Speech-specific components.""" + if not DISCRETE_SPEECH_AVAILABLE: + raise ImportError( + "discrete_speech_metrics is not properly installed. " + "Please install discrete_speech_metrics and retry" + ) + + self.use_gpu = self.config.get("use_gpu", False) + self.sample_rate = self.config.get("sample_rate", 16000) + + # NOTE(jiatong) existing discrete speech metrics only works for 16khz + # We keep the paper best setting. To use other settings, please conduct the + # test on your own. + + try: + self.speech_bert = SpeechBERTScore( + sr=self.sample_rate, + model_type="wavlm-large", + layer=14, + use_gpu=self.use_gpu, + ) + self.speech_bleu = SpeechBLEU( + sr=self.sample_rate, + model_type="hubert-base", + vocab=200, + layer=11, + n_ngram=2, + remove_repetition=True, + use_gpu=self.use_gpu, + ) + self.speech_token_distance = SpeechTokenDistance( + sr=self.sample_rate, + model_type="hubert-base", + vocab=200, + layer=6, + distance_type="jaro-winkler", + remove_repetition=False, + use_gpu=self.use_gpu, + ) + except Exception as e: + raise RuntimeError( + f"Failed to initialize discrete speech metrics: {str(e)}" + ) from e + + def compute( + self, predictions: Any, references: Any = None, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + """Calculate discrete speech metrics. + + Args: + predictions: Predicted audio signal. + references: Ground truth audio signal. + metadata: Optional metadata containing sample_rate. + + Returns: + dict: Dictionary containing the metric scores. + """ + pred_x = predictions + gt_x = references + fs = ( + metadata.get("sample_rate", self.sample_rate) + if metadata + else self.sample_rate + ) + + # Validate inputs + if pred_x is None or gt_x is None: + raise ValueError("Both predicted and ground truth signals must be provided") + + pred_x = np.asarray(pred_x) + gt_x = np.asarray(gt_x) + + scores = {} + + if fs != self.sample_rate: + gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=self.sample_rate) + pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=self.sample_rate) + + # Calculate SpeechBERT score + try: + score, _, _ = self.speech_bert.score(gt_x, pred_x) + scores["speech_bert"] = score + except Exception as e: + logger.warning(f"Could not calculate SpeechBERT score: {e}") + scores["speech_bert"] = 0.0 + + # Calculate SpeechBLEU score + try: + score = self.speech_bleu.score(gt_x, pred_x) + scores["speech_bleu"] = score + except Exception as e: + logger.warning(f"Could not calculate SpeechBLEU score: {e}") + scores["speech_bleu"] = 0.0 + + # Calculate SpeechTokenDistance score + try: + score = self.speech_token_distance.score(gt_x, pred_x) + scores["speech_token_distance"] = score + except Exception as e: + logger.warning(f"Could not calculate SpeechTokenDistance score: {e}") + scores["speech_token_distance"] = 0.0 + + return scores + + def get_metadata(self) -> MetricMetadata: + """Return Discrete Speech metric metadata.""" + return MetricMetadata( + name="discrete_speech", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["discrete_speech_metrics", "librosa", "numpy"], + description="Discrete speech metrics including SpeechBERT, SpeechBLEU, and SpeechTokenDistance for audio evaluation", + paper_reference="https://github.com/ftshijt/discrete_speech_metrics", + implementation_source="https://github.com/ftshijt/discrete_speech_metrics", + ) + + +def register_discrete_speech_metric(registry): + """Register Discrete Speech metric with the registry.""" + metric_metadata = MetricMetadata( + name="discrete_speech", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["discrete_speech_metrics", "librosa", "numpy"], + description="Discrete speech metrics including SpeechBERT, SpeechBLEU, and SpeechTokenDistance for audio evaluation", + paper_reference="https://github.com/ftshijt/discrete_speech_metrics", + implementation_source="https://github.com/ftshijt/discrete_speech_metrics", + ) + registry.register( + DiscreteSpeechMetric, + metric_metadata, + aliases=["DiscreteSpeech", "discrete_speech"], + ) if __name__ == "__main__": a = np.random.random(16000) b = np.random.random(16000) - predictor = discrete_speech_setup() - print(discrete_speech_metric(predictor, a, b, 16000)) + + # Test the new class-based metric + config = {"use_gpu": False} + metric = DiscreteSpeechMetric(config) + metadata = {"sample_rate": 16000} + score = metric.compute(a, b, metadata=metadata) + print(f"metrics: {score}") diff --git a/versa/utterance_metrics/dpam_distance.py b/versa/utterance_metrics/dpam_distance.py index 19e5ac5..8754c9f 100644 --- a/versa/utterance_metrics/dpam_distance.py +++ b/versa/utterance_metrics/dpam_distance.py @@ -1,14 +1,24 @@ -import torch -import torch.nn as nn -import librosa -import numpy as np -import urllib.request +#!/usr/bin/env python3 + +# Copyright 2024 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Module for DPAM distance metrics.""" + import logging +import urllib.request import filelock from pathlib import Path +from typing import Dict, Any, Optional, Union -TARGET_FS = 22050 -MODEL_URL = "https://raw.githubusercontent.com/adrienchaton/PerceptualAudio_Pytorch/refs/heads/master/pretrained/dataset_combined_linear_tshrink.pth" +import librosa +import numpy as np +import torch +import torch.nn as nn + +logger = logging.getLogger(__name__) + +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType class lossnet(nn.Module): @@ -77,37 +87,134 @@ def forward(self, xref, xper): return dist -def dpam_model_setup(cache_dir="versa_cache", use_gpu=False): - device = "cpu" if not use_gpu else "cuda" - model_path = Path(cache_dir) / "dpam" / "dataset_combined_linear.pth" - model_path.parent.mkdir(parents=True, exist_ok=True) - with filelock.FileLock(model_path.with_suffix(".lock")): - if not model_path.exists(): - logging.info(f"Downloading model to {model_path}...") - urllib.request.urlretrieve(MODEL_URL, model_path) - logging.info("Download complete.") - state = torch.load(model_path, map_location="cpu", weights_only=False)["state"] - prefix = "model_dist." - state = {k[len(prefix) :]: v for k, v in state.items() if k.startswith(prefix)} - model = lossnet(nconv=14, nchan=16, dp=0, dist_act="tshrink") - model.load_state_dict(state) - model.to(device) - model.eval() - return model - - -def dpam_metric(model, pred_x, gt_x, fs): - if fs != TARGET_FS: - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=TARGET_FS) - gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=TARGET_FS) - pred_x = torch.from_numpy(pred_x).unsqueeze(0).float() - gt_x = torch.from_numpy(gt_x).unsqueeze(0).float() - dist = model(gt_x, pred_x) - return {"dpam_distance": dist.detach().cpu().numpy().item()} +class DpamDistanceMetric(BaseMetric): + """DPAM distance metric.""" + + TARGET_FS = 22050 + MODEL_URL = "https://raw.githubusercontent.com/adrienchaton/PerceptualAudio_Pytorch/refs/heads/master/pretrained/dataset_combined_linear_tshrink.pth" + + def _setup(self): + """Initialize DPAM-specific components.""" + self.use_gpu = self.config.get("use_gpu", False) + self.cache_dir = self.config.get("cache_dir", "versa_cache") + + try: + self.model = self._setup_model() + except Exception as e: + raise RuntimeError(f"Failed to initialize DPAM model: {str(e)}") from e + + def _setup_model(self): + """Setup the DPAM model.""" + device = "cpu" if not self.use_gpu else "cuda" + model_path = Path(self.cache_dir) / "dpam" / "dataset_combined_linear.pth" + model_path.parent.mkdir(parents=True, exist_ok=True) + + with filelock.FileLock(model_path.with_suffix(".lock")): + if not model_path.exists(): + logger.info(f"Downloading model to {model_path}...") + urllib.request.urlretrieve(self.MODEL_URL, model_path) + logger.info("Download complete.") + + # Suppress PyTorch config registration warnings during model loading + import warnings + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="Skipping config registration for" + ) + checkpoint = torch.load(model_path, map_location="cpu", weights_only=False) + + state = checkpoint["state"] + prefix = "model_dist." + state = {k[len(prefix) :]: v for k, v in state.items() if k.startswith(prefix)} + model = lossnet(nconv=14, nchan=16, dp=0, dist_act="tshrink") + model.load_state_dict(state) + model.to(device) + model.eval() + return model + + def compute( + self, predictions: Any, references: Any, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + """Calculate DPAM distance between two audio samples. + + Args: + predictions: Predicted audio signal. + references: Ground truth audio signal. + metadata: Optional metadata containing sample_rate. + + Returns: + dict: Dictionary containing the DPAM distance score. + """ + pred_x = predictions + gt_x = references + fs = metadata.get("sample_rate", 22050) if metadata else 22050 + + # Validate inputs + if pred_x is None: + raise ValueError("Predicted signal must be provided") + if gt_x is None: + raise ValueError("Reference signal must be provided") + + pred_x = np.asarray(pred_x) + gt_x = np.asarray(gt_x) + + if fs != self.TARGET_FS: + pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=self.TARGET_FS) + gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=self.TARGET_FS) + + pred_x = torch.from_numpy(pred_x).unsqueeze(0).float() + gt_x = torch.from_numpy(gt_x).unsqueeze(0).float() + dist = self.model(gt_x, pred_x) + + return {"dpam_distance": dist.detach().cpu().numpy().item()} + + def get_metadata(self) -> MetricMetadata: + """Return DPAM distance metric metadata.""" + return MetricMetadata( + name="dpam_distance", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["torch", "librosa", "numpy", "filelock"], + description="DPAM distance between audio samples", + paper_reference="https://github.com/adrienchaton/PerceptualAudio_Pytorch", + implementation_source="https://github.com/adrienchaton/PerceptualAudio_Pytorch", + ) + + +def register_dpam_distance_metric(registry): + """Register DPAM distance metric with the registry.""" + metric_metadata = MetricMetadata( + name="dpam_distance", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["torch", "librosa", "numpy", "filelock"], + description="DPAM distance between audio samples", + paper_reference="https://github.com/adrienchaton/PerceptualAudio_Pytorch", + implementation_source="https://github.com/adrienchaton/PerceptualAudio_Pytorch", + ) + registry.register( + DpamDistanceMetric, + metric_metadata, + aliases=["DpamDistance", "dpam_distance", "dpam"], + ) if __name__ == "__main__": a = np.random.random(22050) b = np.random.random(22050) - model = dpam_model_setup() - print("metrics: {}".format(dpam_metric(model, a, b, 22050))) + + # Test the new class-based metric + config = {"use_gpu": False} + metric = DpamDistanceMetric(config) + metadata = {"sample_rate": 22050} + score = metric.compute(a, b, metadata=metadata) + print(f"metrics: {score}") diff --git a/versa/utterance_metrics/emo_similarity.py b/versa/utterance_metrics/emo_similarity.py new file mode 100644 index 0000000..c1f0235 --- /dev/null +++ b/versa/utterance_metrics/emo_similarity.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 + +# Copyright 2024 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Module for emotion similarity metrics using EMO2VEC.""" + +import logging +import os +from pathlib import Path +from typing import Dict, Any, Optional, Union + +import librosa +import numpy as np + +logger = logging.getLogger(__name__) + +# Handle optional emo2vec dependency +try: + import emo2vec_versa + from emo2vec_versa.emo2vec_class import EMO2VEC + + EMO2VEC_AVAILABLE = True +except ImportError: + logger.info( + "emo2vec is not installed. Please install the package via " + "`tools/install_emo2vec.sh`" + ) + EMO2VEC = None + EMO2VEC_AVAILABLE = False + +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType + + +class Emo2vecNotAvailableError(RuntimeError): + """Exception raised when emo2vec is required but not available.""" + + pass + + +def is_emo2vec_available(): + """ + Check if the emo2vec package is available. + + Returns: + bool: True if emo2vec is available, False otherwise. + """ + return EMO2VEC_AVAILABLE + + +class Emo2vecMetric(BaseMetric): + """Emotion similarity metric using EMO2VEC.""" + + def _setup(self): + """Initialize Emotion-specific components.""" + if not EMO2VEC_AVAILABLE: + raise ImportError( + "emo2vec_versa not found. Please install from tools/installers" + ) + + self.model_tag = self.config.get("model_tag", "default") + self.model_path = self.config.get("model_path", None) + self.use_gpu = self.config.get("use_gpu", False) + + try: + self.model = self._setup_model() + except Exception as e: + raise RuntimeError(f"Failed to initialize Emotion model: {str(e)}") from e + + def _setup_model(self): + """Setup the Emotion model.""" + if self.model_path is not None: + model = EMO2VEC(self.model_path, use_gpu=self.use_gpu) + else: + if self.model_tag == "default" or self.model_tag == "base": + model_path = ( + Path(os.path.abspath(emo2vec_versa.__file__)).parent + / "emotion2vec_base.pt" + ) + else: + raise ValueError(f"Unknown model_tag for emo2vec: {self.model_tag}") + + # check if model exists + if not model_path.exists(): + raise FileNotFoundError(f"Model file not found: {model_path}") + + model = EMO2VEC(checkpoint_dir=str(model_path), use_gpu=self.use_gpu) + + return model + + def compute( + self, predictions: Any, references: Any, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + """Calculate emotion similarity between two audio samples. + + Args: + predictions: Predicted audio signal. + references: Ground truth audio signal. + metadata: Optional metadata containing sample_rate. + + Returns: + dict: Dictionary containing the emotion similarity score. + """ + pred_x = predictions + gt_x = references + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + + # Validate inputs + if pred_x is None: + raise ValueError("Predicted signal must be provided") + if gt_x is None: + raise ValueError("Reference signal must be provided") + + pred_x = np.asarray(pred_x) + gt_x = np.asarray(gt_x) + + # NOTE(jiatong): only work for 16000 Hz + if fs != 16000: + gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=16000) + pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) + + embedding_gen = self.model.extract_feature(pred_x, fs=16000) + embedding_gt = self.model.extract_feature(gt_x, fs=16000) + similarity = np.dot(embedding_gen, embedding_gt) / ( + np.linalg.norm(embedding_gen) * np.linalg.norm(embedding_gt) + ) + + return {"emotion_similarity": float(similarity)} + + def get_metadata(self) -> MetricMetadata: + """Return Emotion metric metadata.""" + return MetricMetadata( + name="emotion", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["emo2vec_versa", "librosa", "numpy"], + description="Emotion similarity between audio samples using EMO2VEC", + paper_reference="https://github.com/ddlBoJack/emotion2vec", + implementation_source="https://github.com/ddlBoJack/emotion2vec", + ) + + +def register_emo2vec_metric(registry): + """Register Emotion metric with the registry.""" + metric_metadata = MetricMetadata( + name="emotion", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["emo2vec_versa", "librosa", "numpy"], + description="Emotion similarity between audio samples using EMO2VEC", + paper_reference="https://github.com/ddlBoJack/emotion2vec", + implementation_source="https://github.com/ddlBoJack/emotion2vec", + ) + registry.register( + Emo2vecMetric, + metric_metadata, + aliases=["Emotion", "emotion", "emo2vec_similarity"], + ) + + +if __name__ == "__main__": + a = np.random.random(16000) + b = np.random.random(16000) + + # Test the new class-based metric + config = {"use_gpu": False} + metric = Emo2vecMetric(config) + metadata = {"sample_rate": 16000} + score = metric.compute(a, b, metadata=metadata) + print(f"metrics: {score}") diff --git a/versa/utterance_metrics/emo_vad.py b/versa/utterance_metrics/emo_vad.py index 48def8f..7bc8caa 100644 --- a/versa/utterance_metrics/emo_vad.py +++ b/versa/utterance_metrics/emo_vad.py @@ -8,19 +8,51 @@ import logging import os from pathlib import Path +from typing import Dict, Any, Optional, Union import librosa import numpy as np +import torch +import torch.nn as nn logger = logging.getLogger(__name__) -import torch -import torch.nn as nn -from transformers import Wav2Vec2Processor -from transformers.models.wav2vec2.modeling_wav2vec2 import ( - Wav2Vec2Model, - Wav2Vec2PreTrainedModel, -) +# Handle optional transformers dependency +try: + from transformers import Wav2Vec2Processor + from transformers.models.wav2vec2.modeling_wav2vec2 import ( + Wav2Vec2Model, + Wav2Vec2PreTrainedModel, + ) + + TRANSFORMERS_AVAILABLE = True +except ImportError: + logger.warning( + "transformers is not properly installed. " + "Please install transformers and retry" + ) + Wav2Vec2Processor = None + Wav2Vec2Model = None + Wav2Vec2PreTrainedModel = None + TRANSFORMERS_AVAILABLE = False + +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType + + +class TransformersNotAvailableError(RuntimeError): + """Exception raised when transformers is required but not available.""" + + pass + + +def is_transformers_available(): + """ + Check if the transformers package is available. + + Returns: + bool: True if transformers is available, False otherwise. + """ + return TRANSFORMERS_AVAILABLE class RegressionHead(nn.Module): @@ -71,53 +103,134 @@ def forward( return hidden_states, logits -def w2v2_emo_dim_setup( - model_tag="default", model_path=None, model_config=None, use_gpu=False -): - if use_gpu: - device = "cuda" - else: - device = "cpu" - if model_path is not None and model_config is not None: - model = EmotionModel.from_pretrained( - pretrained_model_name_or_path=model_path, config=model_config - ).to(device) - else: - if model_tag == "default": - model_tag = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim" - model = EmotionModel.from_pretrained(model_tag).to(device) - processor = Wav2Vec2Processor.from_pretrained( - "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim" +class EmoVadMetric(BaseMetric): + """Dimensional emotion prediction metric using w2v2-how-to.""" + + def _setup(self): + """Initialize EmoVad-specific components.""" + if not TRANSFORMERS_AVAILABLE: + raise ImportError( + "transformers is not properly installed. " + "Please install transformers and retry" + ) + + self.model_tag = self.config.get("model_tag", "default") + self.model_path = self.config.get("model_path", None) + self.model_config = self.config.get("model_config", None) + self.use_gpu = self.config.get("use_gpu", False) + + self.device = "cuda" if self.use_gpu and torch.cuda.is_available() else "cpu" + + try: + self.model, self.processor = self._setup_model() + except Exception as e: + raise RuntimeError(f"Failed to initialize EmoVad model: {str(e)}") from e + + def _setup_model(self): + """Setup the EmoVad model.""" + if self.model_path is not None and self.model_config is not None: + model = EmotionModel.from_pretrained( + pretrained_model_name_or_path=self.model_path, config=self.model_config + ).to(self.device) + else: + if self.model_tag == "default": + model_tag = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim" + else: + model_tag = self.model_tag + model = EmotionModel.from_pretrained(model_tag).to(self.device) + + processor = Wav2Vec2Processor.from_pretrained( + "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim" + ) + + return model, processor + + def compute( + self, predictions: Any, references: Any = None, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + """Calculate dimensional emotion (arousal, dominance, valence) of input audio samples. + + Args: + predictions: Audio signal to evaluate. + references: Not used for this metric. + metadata: Optional metadata containing sample_rate. + + Returns: + dict: Dictionary containing the dimensional emotion predictions. + """ + pred_x = predictions + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + + # Validate input + if pred_x is None: + raise ValueError("Predicted signal must be provided") + + pred_x = np.asarray(pred_x) + + # NOTE(jiatong): only work for 16000 Hz + if fs != 16000: + pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) + + pred_x = self.processor(pred_x, sampling_rate=16000) + pred_x = pred_x["input_values"][0] + pred_x = pred_x.reshape(1, -1) + pred_x = torch.from_numpy(pred_x).to(self.device) + + with torch.no_grad(): + avd_emo = self.model(pred_x)[1].squeeze(0).cpu().numpy() + + arousal, dominance, valence = avd_emo + arousal = arousal.item() + dominance = dominance.item() + valence = valence.item() + + return { + "arousal_emo_vad": arousal, + "valence_emo_vad": valence, + "dominance_emo_vad": dominance, + } + + def get_metadata(self) -> MetricMetadata: + """Return EmoVad metric metadata.""" + return MetricMetadata( + name="emo_vad", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["transformers", "torch", "librosa", "numpy"], + description="Dimensional emotion prediction (arousal, valence, dominance) using w2v2-how-to", + paper_reference="https://github.com/audeering/w2v2-how-to", + implementation_source="https://github.com/audeering/w2v2-how-to", + ) + + +def register_emo_vad_metric(registry): + """Register EmoVad metric with the registry.""" + metric_metadata = MetricMetadata( + name="emo_vad", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["transformers", "torch", "librosa", "numpy"], + description="Dimensional emotion prediction (arousal, valence, dominance) using w2v2-how-to", + paper_reference="https://github.com/audeering/w2v2-how-to", + implementation_source="https://github.com/audeering/w2v2-how-to", ) - emo_utils = {"model": model, "processor": processor, "device": device} - return emo_utils - - -def dim_emo_pred(emo_utils, pred_x, fs): - """Calculate dimensional emotion (arousal, dominance, valence) of input audio samples. - - Args: - model (w2v2-how-to): The loaded EMO2VEC model. - pred_x (np.ndarray): Predicted audio signal. - fs (int): Sampling rate. - - Returns: - dict: Dictionary containing the dimensional emotion predictions. - """ - # NOTE(jiatong): only work for 16000 Hz - if fs != 16000: - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) - pred_x = emo_utils["processor"](pred_x, sampling_rate=16000) - pred_x = pred_x["input_values"][0] - pred_x = pred_x.reshape(1, -1) - pred_x = torch.from_numpy(pred_x).to(emo_utils["device"]) - with torch.no_grad(): - avd_emo = emo_utils["model"](pred_x)[1].squeeze(0).cpu().numpy() - - return {"aro_val_dom_emo": avd_emo} + registry.register(EmoVadMetric, metric_metadata, aliases=["EmoVad", "emo_vad"]) if __name__ == "__main__": a = np.random.random(16000) - emo_utils = w2v2_emo_dim_setup() - print(f"metrics: {dim_emo_pred(emo_utils, a, 16000)}") + + # Test the new class-based metric + config = {"use_gpu": False} + metric = EmoVadMetric(config) + metadata = {"sample_rate": 16000} + score = metric.compute(a, metadata=metadata) + print(f"metrics: {score}") diff --git a/versa/utterance_metrics/emotion.py b/versa/utterance_metrics/emotion.py deleted file mode 100644 index a4945d3..0000000 --- a/versa/utterance_metrics/emotion.py +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2024 Jiatong Shi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -"""Module for emotion similarity metrics using EMO2VEC.""" - -import logging -import os -from pathlib import Path - -import librosa -import numpy as np - -logger = logging.getLogger(__name__) - -try: - import emo2vec_versa - from emo2vec_versa.emo2vec_class import EMO2VEC -except ImportError: - logger.info( - "emo2vec is not installed. Please install the package via " - "`tools/install_emo2vec.sh`" - ) - EMO2VEC = None - - -def emo2vec_setup(model_tag="default", model_path=None, use_gpu=False): - """Set up EMO2VEC model for emotion embedding extraction. - - Args: - model_tag (str, optional): Model tag. Defaults to "default". - model_path (str, optional): Path to model weights. Defaults to None. - use_gpu (bool, optional): Whether to use GPU. Defaults to False. - - Returns: - EMO2VEC: The loaded model. - - Raises: - ImportError: If emo2vec_versa is not installed. - ValueError: If model_tag is unknown. - FileNotFoundError: If model file is not found. - """ - if EMO2VEC is None: - raise ImportError( - "emo2vec_versa not found. Please install from tools/installers" - ) - - if model_path is not None: - model = EMO2VEC(model_path, use_gpu=use_gpu) - else: - if model_tag == "default" or model_tag == "base": - model_path = ( - Path(os.path.abspath(emo2vec_versa.__file__)).parent - / "emotion2vec_base.pt" - ) - else: - raise ValueError(f"Unknown model_tag for emo2vec: {model_tag}") - - # check if model exists - if not model_path.exists(): - raise FileNotFoundError(f"Model file not found: {model_path}") - - model = EMO2VEC(checkpoint_dir=str(model_path), use_gpu=use_gpu) - return model - - -def emo_sim(model, pred_x, gt_x, fs): - """Calculate emotion similarity between two audio samples. - - Args: - model (EMO2VEC): The loaded EMO2VEC model. - pred_x (np.ndarray): Predicted audio signal. - gt_x (np.ndarray): Ground truth audio signal. - fs (int): Sampling rate. - - Returns: - dict: Dictionary containing the emotion similarity score. - """ - # NOTE(jiatong): only work for 16000 Hz - if fs != 16000: - gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=16000) - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) - - embedding_gen = model.extract_feature(pred_x, fs=16000) - embedding_gt = model.extract_feature(gt_x, fs=16000) - similarity = np.dot(embedding_gen, embedding_gt) / ( - np.linalg.norm(embedding_gen) * np.linalg.norm(embedding_gt) - ) - return {"emotion_similarity": similarity} - - -if __name__ == "__main__": - a = np.random.random(16000) - b = np.random.random(16000) - model = emo2vec_setup() - print(f"metrics: {emo_sim(model, a, b, 16000)}") diff --git a/versa/utterance_metrics/nisqa.py b/versa/utterance_metrics/nisqa.py index 84fee55..78bb12a 100644 --- a/versa/utterance_metrics/nisqa.py +++ b/versa/utterance_metrics/nisqa.py @@ -3,165 +3,226 @@ # Copyright 2025 Jiatong Shi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +"""Module for NISQA speech quality assessment metrics.""" + +import logging +import warnings +from typing import Dict, Any, Optional, Union + import librosa import numpy as np import torch import versa.utterance_metrics.nisqa_utils.nisqa_lib as NL - -def nisqa_model_setup(nisqa_model_path=None, use_gpu=False): - """ - Setup the NISQA model for evaluation. - Args: - nisqa_model_path (str): Path to the NISQA model checkpoint. - use_gpu (bool): If True, use GPU for computation. Default is False. - - Returns: - model: The loaded NISQA model. - - Raises: - ValueError: If the model path is not provided or the checkpoint is invalid. - """ - - # Check if GPU is available - if use_gpu and not torch.cuda.is_available(): - raise RuntimeError("GPU is not available. Please set use_gpu=False.") - - # Set device - if use_gpu: - device = "cuda" - else: - device = "cpu" - # Check if the model path is provided - if nisqa_model_path is None: - raise ValueError("NISQA model path must be provided.") - - checkpoint = torch.load(nisqa_model_path, map_location="cpu") - args = checkpoint.get("args", None) - if args is None: - raise ValueError( - "Model checkpoint does not contain the required arguments. Might due to a wrong checkpoint." +logger = logging.getLogger(__name__) + +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType + + +class NisqaMetric(BaseMetric): + """NISQA speech quality assessment metric.""" + + TARGET_FS = 48000 # NISQA model's expected sampling rate + + def _setup(self): + """Initialize NISQA-specific components.""" + self.nisqa_model_path = self.config.get("nisqa_model_path") + self.use_gpu = self.config.get("use_gpu", False) + + if not self.nisqa_model_path: + raise ValueError("NISQA model path must be provided in config") + + try: + self.model = self._setup_model() + except Exception as e: + raise RuntimeError(f"Failed to initialize NISQA model: {str(e)}") from e + + def _setup_model(self): + """Setup the NISQA model.""" + # Check if GPU is available + if self.use_gpu and not torch.cuda.is_available(): + raise RuntimeError("GPU is not available. Please set use_gpu=False.") + + # Set device + device = "cuda" if self.use_gpu else "cpu" + + # Suppress PyTorch config registration warnings during model loading + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="Skipping config registration for" + ) + checkpoint = torch.load(self.nisqa_model_path, map_location="cpu") + + args = checkpoint.get("args", None) + if args is None: + raise ValueError( + "Model checkpoint does not contain the required arguments. Might due to a wrong checkpoint." + ) + + if args["model"] == "NISQA_DIM": + args["dim"] = True + args["csv_mos_train"] = None # column names hardcoded for dim models + args["csv_mos_val"] = None + else: + args["dim"] = False + + if args["model"] == "NISQA_DE": + args["double_ended"] = True + else: + args["double_ended"] = False + args["csv_ref"] = None + + # Load Model + model_args = { + "ms_seg_length": args["ms_seg_length"], + "ms_n_mels": args["ms_n_mels"], + "cnn_model": args["cnn_model"], + "cnn_c_out_1": args["cnn_c_out_1"], + "cnn_c_out_2": args["cnn_c_out_2"], + "cnn_c_out_3": args["cnn_c_out_3"], + "cnn_kernel_size": args["cnn_kernel_size"], + "cnn_dropout": args["cnn_dropout"], + "cnn_pool_1": args["cnn_pool_1"], + "cnn_pool_2": args["cnn_pool_2"], + "cnn_pool_3": args["cnn_pool_3"], + "cnn_fc_out_h": args["cnn_fc_out_h"], + "td": args["td"], + "td_sa_d_model": args["td_sa_d_model"], + "td_sa_nhead": args["td_sa_nhead"], + "td_sa_pos_enc": args["td_sa_pos_enc"], + "td_sa_num_layers": args["td_sa_num_layers"], + "td_sa_h": args["td_sa_h"], + "td_sa_dropout": args["td_sa_dropout"], + "td_lstm_h": args["td_lstm_h"], + "td_lstm_num_layers": args["td_lstm_num_layers"], + "td_lstm_dropout": args["td_lstm_dropout"], + "td_lstm_bidirectional": args["td_lstm_bidirectional"], + "td_2": args["td_2"], + "td_2_sa_d_model": args["td_2_sa_d_model"], + "td_2_sa_nhead": args["td_2_sa_nhead"], + "td_2_sa_pos_enc": args["td_2_sa_pos_enc"], + "td_2_sa_num_layers": args["td_2_sa_num_layers"], + "td_2_sa_h": args["td_2_sa_h"], + "td_2_sa_dropout": args["td_2_sa_dropout"], + "td_2_lstm_h": args["td_2_lstm_h"], + "td_2_lstm_num_layers": args["td_2_lstm_num_layers"], + "td_2_lstm_dropout": args["td_2_lstm_dropout"], + "td_2_lstm_bidirectional": args["td_2_lstm_bidirectional"], + "pool": args["pool"], + "pool_att_h": args["pool_att_h"], + "pool_att_dropout": args["pool_att_dropout"], + } + + if args["double_ended"]: + model_args.update( + { + "de_align": args["de_align"], + "de_align_apply": args["de_align_apply"], + "de_fuse_dim": args["de_fuse_dim"], + "de_fuse": args["de_fuse"], + } + ) + + if args["model"] == "NISQA": + model = NL.NISQA(**model_args) + elif args["model"] == "NISQA_DIM": + model = NL.NISQA_DIM(**model_args) + elif args["model"] == "NISQA_DE": + model = NL.NISQA_DE(**model_args) + else: + raise NotImplementedError("Model not available") + + # Load weights + missing_keys, unexpected_keys = model.load_state_dict( + checkpoint["model_state_dict"], strict=True ) - - if args["model"] == "NISQA_DIM": - args["dim"] = True - args["csv_mos_train"] = None # column names hardcoded for dim models - args["csv_mos_val"] = None - else: - args["dim"] = False - - if args["model"] == "NISQA_DE": - args["double_ended"] = True - else: - args["double_ended"] = False - args["csv_ref"] = None - - # Load Model - model_args = { - "ms_seg_length": args["ms_seg_length"], - "ms_n_mels": args["ms_n_mels"], - "cnn_model": args["cnn_model"], - "cnn_c_out_1": args["cnn_c_out_1"], - "cnn_c_out_2": args["cnn_c_out_2"], - "cnn_c_out_3": args["cnn_c_out_3"], - "cnn_kernel_size": args["cnn_kernel_size"], - "cnn_dropout": args["cnn_dropout"], - "cnn_pool_1": args["cnn_pool_1"], - "cnn_pool_2": args["cnn_pool_2"], - "cnn_pool_3": args["cnn_pool_3"], - "cnn_fc_out_h": args["cnn_fc_out_h"], - "td": args["td"], - "td_sa_d_model": args["td_sa_d_model"], - "td_sa_nhead": args["td_sa_nhead"], - "td_sa_pos_enc": args["td_sa_pos_enc"], - "td_sa_num_layers": args["td_sa_num_layers"], - "td_sa_h": args["td_sa_h"], - "td_sa_dropout": args["td_sa_dropout"], - "td_lstm_h": args["td_lstm_h"], - "td_lstm_num_layers": args["td_lstm_num_layers"], - "td_lstm_dropout": args["td_lstm_dropout"], - "td_lstm_bidirectional": args["td_lstm_bidirectional"], - "td_2": args["td_2"], - "td_2_sa_d_model": args["td_2_sa_d_model"], - "td_2_sa_nhead": args["td_2_sa_nhead"], - "td_2_sa_pos_enc": args["td_2_sa_pos_enc"], - "td_2_sa_num_layers": args["td_2_sa_num_layers"], - "td_2_sa_h": args["td_2_sa_h"], - "td_2_sa_dropout": args["td_2_sa_dropout"], - "td_2_lstm_h": args["td_2_lstm_h"], - "td_2_lstm_num_layers": args["td_2_lstm_num_layers"], - "td_2_lstm_dropout": args["td_2_lstm_dropout"], - "td_2_lstm_bidirectional": args["td_2_lstm_bidirectional"], - "pool": args["pool"], - "pool_att_h": args["pool_att_h"], - "pool_att_dropout": args["pool_att_dropout"], - } - - if args["double_ended"]: - model_args.update( - { - "de_align": args["de_align"], - "de_align_apply": args["de_align_apply"], - "de_fuse_dim": args["de_fuse_dim"], - "de_fuse": args["de_fuse"], - } + if missing_keys: + logger.warning("[NISQA] missing_keys: %s", missing_keys) + if unexpected_keys: + logger.warning("[NISQA] unexpected_keys: %s", unexpected_keys) + + model.args = args + model.device = device + return model + + def compute( + self, predictions: Any, references: Any = None, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + """Calculate NISQA scores for speech quality assessment. + + Args: + predictions: Audio signal to be evaluated. + references: Not used for NISQA (single-ended metric). + metadata: Optional metadata containing sample_rate. + + Returns: + dict: Dictionary containing NISQA scores. + """ + pred_x = predictions + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + + # Validate inputs + if pred_x is None: + raise ValueError("Predicted signal must be provided") + + pred_x = np.asarray(pred_x) + + # Resample if necessary + if fs != self.TARGET_FS: + pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=self.TARGET_FS) + fs = self.TARGET_FS + + # Evaluate the NISQA score + with torch.no_grad(): + metrics = NL.versa_eval_mos( + [pred_x], self.model, 1, self.model.device, num_workers=0 + ) + + final_result = {} + for metrics_key in metrics.keys(): + # Check if the metric is a list and take the first element for batch=1 + final_result["nisqa_" + metrics_key] = metrics[metrics_key][0][0] + + return final_result + + def get_metadata(self) -> MetricMetadata: + """Return NISQA metric metadata.""" + return MetricMetadata( + name="nisqa", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["torch", "librosa", "numpy"], + description="NISQA speech quality assessment metric", + paper_reference="https://github.com/gabrielmittag/NISQA", + implementation_source="https://github.com/gabrielmittag/NISQA", ) - if args["model"] == "NISQA": - model = NL.NISQA(**model_args) - elif args["model"] == "NISQA_DIM": - model = NL.NISQA_DIM(**model_args) - elif args["model"] == "NISQA_DE": - model = NL.NISQA_DE(**model_args) - else: - raise NotImplementedError("Model not available") - - # Load weights - missing_keys, unexpected_keys = model.load_state_dict( - checkpoint["model_state_dict"], strict=True - ) - if missing_keys: - print("[NISQA] missing_keys:") - print(missing_keys) - if unexpected_keys: - print("[NISQA] unexpected_keys:") - print(unexpected_keys) - model.args = args - model.device = device - return model - - -def nisqa_metric(nisqa_model, pred_x, fs): - """ - Calculate the NISQA score for a given audio signal. - - Args: - nisqa_model: The NISQA model to use for evaluation. - pred_x (np.ndarray): The audio signal to be evaluated (1D array). - fs (int): The sampling rate of the audio signal in Hz. - - Returns: - dict: A dictionary containing the NISQA score and other metrics. - """ - model_sr = 48e3 # NISQA model's expected sampling rate - if fs != model_sr: - # Resample the audio signal to the model's expected sampling rate - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=model_sr) - fs = model_sr - - # Evaluate the NISQA score - with torch.no_grad(): - metrics = NL.versa_eval_mos( - [pred_x], nisqa_model, 1, nisqa_model.device, num_workers=0 - ) - - final_result = {} - for metrics_key in metrics.keys(): - # Check if the metric is a list and take the first element for batch=1 - final_result["nisqa_" + metrics_key] = metrics[metrics_key][0][0] - return final_result +def register_nisqa_metric(registry): + """Register NISQA metric with the registry.""" + metric_metadata = MetricMetadata( + name="nisqa", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["torch", "librosa", "numpy"], + description="NISQA speech quality assessment metric", + paper_reference="https://github.com/gabrielmittag/NISQA", + implementation_source="https://github.com/gabrielmittag/NISQA", + ) + registry.register( + NisqaMetric, + metric_metadata, + aliases=["Nisqa", "nisqa"], + ) if __name__ == "__main__": diff --git a/versa/utterance_metrics/nomad.py b/versa/utterance_metrics/nomad.py index 3033c6b..1b6761c 100644 --- a/versa/utterance_metrics/nomad.py +++ b/versa/utterance_metrics/nomad.py @@ -1,63 +1,150 @@ #!/usr/bin/env python3 +# Copyright 2024 Jiatong Shi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) -import logging +"""Module for NOMAD speech quality assessment metrics.""" -logger = logging.getLogger(__name__) +import logging +from typing import Dict, Any, Optional, Union import librosa import numpy as np import torch +logger = logging.getLogger(__name__) + +# Handle optional nomad dependency try: from nomad_versa import Nomad + + NOMAD_AVAILABLE = True except ImportError: - logger.info( + logger.warning( "nomad is not installed. Please use `tools/install_nomad.sh` to install" ) Nomad = None + NOMAD_AVAILABLE = False +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType -def nomad_setup(use_gpu=False, cache_dir="versa_cache/nomad_pt-models"): - if use_gpu: - device = "cuda" - else: - device = "cpu" - if Nomad is None: - raise ModuleNotFoundError( - "nomad is not installed. Please use `tools/install_nomad.sh` to install" - ) +class NomadNotAvailableError(RuntimeError): + """Exception raised when nomad is required but not available.""" - return Nomad(device=device, cache_dir=cache_dir) + pass -def nomad(model, pred_x, gt_x, fs): +def is_nomad_available(): """ - Reference: - A. Ragano, J. Skoglund and A. Hines, - "NOMAD: Unsupervised Learning of Perceptual Embeddings For Speech Enhancement and Non-Matching Reference Audio Quality Assessment," - ICASSP 2024 - 2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), Seoul, Korea, Republic of, 2024, pp. 1011-1015 - Codebase: - https://github.com/alessandroragano/nomad + Check if the nomad package is available. + Returns: + bool: True if nomad is available, False otherwise. """ - - # NOTE(hyejin): current model only have 16k options - if fs != 16000: - gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=16000) - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) - - return { - "nomad": model.predict(nmr=gt_x, deg=pred_x), - } + return NOMAD_AVAILABLE + + +class NomadMetric(BaseMetric): + """NOMAD speech quality assessment metric.""" + + TARGET_FS = 16000 # NOMAD model's expected sampling rate + + def _setup(self): + """Initialize NOMAD-specific components.""" + if not NOMAD_AVAILABLE: + raise ImportError( + "nomad is not installed. Please use `tools/install_nomad.sh` to install" + ) + + self.use_gpu = self.config.get("use_gpu", False) + self.cache_dir = self.config.get("model_cache", "versa_cache/nomad_pt-models") + + try: + self.model = self._setup_model() + except Exception as e: + raise RuntimeError(f"Failed to initialize NOMAD model: {str(e)}") from e + + def _setup_model(self): + """Setup the NOMAD model.""" + device = "cuda" if self.use_gpu else "cpu" + + if Nomad is None: + raise ModuleNotFoundError( + "nomad is not installed. Please use `tools/install_nomad.sh` to install" + ) + + return Nomad(device=device, cache_dir=self.cache_dir) + + def compute( + self, predictions: Any, references: Any, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + """Calculate NOMAD score for speech quality assessment. + + Args: + predictions: Predicted audio signal. + references: Ground truth audio signal. + metadata: Optional metadata containing sample_rate. + + Returns: + dict: Dictionary containing NOMAD score. + """ + pred_x = predictions + gt_x = references + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + + # Validate inputs + if pred_x is None: + raise ValueError("Predicted signal must be provided") + if gt_x is None: + raise ValueError("Reference signal must be provided") + + pred_x = np.asarray(pred_x) + gt_x = np.asarray(gt_x) + + # Resample if necessary (NOMAD only supports 16kHz) + if fs != self.TARGET_FS: + gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=self.TARGET_FS) + pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=self.TARGET_FS) + + return { + "nomad": self.model.predict(nmr=gt_x, deg=pred_x), + } + + def get_metadata(self) -> MetricMetadata: + """Return NOMAD metric metadata.""" + return MetricMetadata( + name="nomad", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["nomad_versa", "torch", "librosa", "numpy"], + description="NOMAD: Unsupervised Learning of Perceptual Embeddings For Speech Enhancement and Non-Matching Reference Audio Quality Assessment", + paper_reference="https://ieeexplore.ieee.org/document/10447047", + implementation_source="https://github.com/alessandroragano/nomad", + ) -if __name__ == "__main__": - a = np.random.random(16000) - b = np.random.random(16000) - nomad_model = nomad_setup(use_gpu=True) - fs = 16000 - nomad_score = nomad(nomad_model, a, b, fs) - print(nomad_score) +def register_nomad_metric(registry): + """Register NOMAD metric with the registry.""" + metric_metadata = MetricMetadata( + name="nomad", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["nomad_versa", "torch", "librosa", "numpy"], + description="NOMAD: Unsupervised Learning of Perceptual Embeddings For Speech Enhancement and Non-Matching Reference Audio Quality Assessment", + paper_reference="https://ieeexplore.ieee.org/document/10447047", + implementation_source="https://github.com/alessandroragano/nomad", + ) + registry.register( + NomadMetric, + metric_metadata, + aliases=["Nomad", "nomad"], + ) diff --git a/versa/utterance_metrics/noresqa.py b/versa/utterance_metrics/noresqa.py index b9391c0..998e39e 100644 --- a/versa/utterance_metrics/noresqa.py +++ b/versa/utterance_metrics/noresqa.py @@ -3,129 +3,277 @@ # Copyright 2024 Jiatong Shi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +"""Module for NORESQA speech quality assessment metrics.""" import logging import os import sys +import warnings +from typing import Dict, Any, Optional, Union import librosa import numpy as np import torch +import torch.nn as nn +from urllib.request import urlretrieve logger = logging.getLogger(__name__) -from urllib.request import urlretrieve +# Handle optional dependencies +try: + import fairseq -import torch.nn as nn + FAIRSEQ_AVAILABLE = True +except ImportError: + logger.warning( + "fairseq is not installed. Please use `tools/install_fairseq.sh` to install" + ) + fairseq = None + FAIRSEQ_AVAILABLE = False +# Setup NORESQA path base_path = os.path.abspath( os.path.join(os.path.dirname(__file__), "../../tools/Noresqa") ) sys.path.insert(0, base_path) +from noresqa_model import NORESQA +from noresqa_utils import ( + feats_loading, + model_prediction_noresqa, + model_prediction_noresqa_mos, +) -try: - import fairseq -except ImportError: - logger.info( - "fairseq is not installed. Please use `tools/install_fairseq.sh` to install" - ) +NORESQA_AVAILABLE = True try: - from model import NORESQA - from utils import ( + from noresqa_model import NORESQA + from noresqa_utils import ( feats_loading, model_prediction_noresqa, model_prediction_noresqa_mos, ) + NORESQA_AVAILABLE = True except ImportError: - logger.info( + logger.warning( "noresqa is not installed. Please use `tools/install_noresqa.sh` to install" ) - Noresqa = None - - -def noresqa_model_setup( - model_tag="default", - metric_type=0, - cache_dir="versa_cache/noresqa_model", - use_gpu=False, -): - if use_gpu: - device = "cuda" - else: - device = "cpu" - - if model_tag == "default": - - if not os.path.isdir(cache_dir): - print("Creating checkpoints directory") - os.makedirs(cache_dir) - - url_w2v = "https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt" - w2v_path = os.path.join(cache_dir, "wav2vec_small.pt") - if not os.path.isfile(w2v_path): - print("Downloading wav2vec 2.0 started") - urlretrieve(url_w2v, w2v_path) - print("wav2vec 2.0 download completed") - - model = NORESQA( - output=40, output2=40, metric_type=metric_type, config_path=w2v_path - ) + NORESQA = None + feats_loading = None + model_prediction_noresqa = None + model_prediction_noresqa_mos = None + NORESQA_AVAILABLE = False - if metric_type == 0: - model_checkpoint_path = "{}/models/model_noresqa.pth".format(base_path) - state = torch.load(model_checkpoint_path, map_location="cpu")["state_base"] +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType - elif metric_type == 1: - model_checkpoint_path = "{}/models/model_noresqa_mos.pth".format(base_path) - state = torch.load(model_checkpoint_path, map_location="cpu")["state_dict"] - pretrained_dict = {} - for k, v in state.items(): - if "module" in k: - pretrained_dict[k.replace("module.", "")] = v - else: - pretrained_dict[k] = v - model_dict = model.state_dict() - model_dict.update(pretrained_dict) - model.load_state_dict(pretrained_dict) +class NoresqaNotAvailableError(RuntimeError): + """Exception raised when noresqa is required but not available.""" + + pass - # change device as needed - model.to(device) - model.device = device - model.eval() - sfmax = nn.Softmax(dim=1) +def is_noresqa_available(): + """ + Check if the noresqa package is available. - else: - raise NotImplementedError + Returns: + bool: True if noresqa is available, False otherwise. + """ + return NORESQA_AVAILABLE and FAIRSEQ_AVAILABLE - return model +class NoresqaMetric(BaseMetric): + """NORESQA speech quality assessment metric.""" -def noresqa_metric(model, gt_x, pred_x, fs, metric_type=1): - # NOTE(hyejin): only work for 16000 Hz - gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=16000) - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) - nmr_feat, test_feat = feats_loading(pred_x, gt_x, noresqa_or_noresqaMOS=metric_type) - test_feat = torch.from_numpy(test_feat).float().to(model.device).unsqueeze(0) - nmr_feat = torch.from_numpy(nmr_feat).float().to(model.device).unsqueeze(0) + TARGET_FS = 16000 # NORESQA model's expected sampling rate - with torch.no_grad(): - if metric_type == 0: - noresqa_pout, noresqa_qout = model_prediction_noresqa( - test_feat, nmr_feat, model + def _setup(self): + """Initialize NORESQA-specific components.""" + if not NORESQA_AVAILABLE: + raise ImportError( + "noresqa is not installed. Please use `tools/install_noresqa.sh` to install" + ) + if not FAIRSEQ_AVAILABLE: + raise ImportError( + "fairseq is not installed. Please use `tools/install_fairseq.sh` to install" ) - return {"noresqa_score": noresqa_pout} - elif metric_type == 1: - mos_score = model_prediction_noresqa_mos(test_feat, nmr_feat, model) - return {"noresqa_score": mos_score} + self.model_tag = self.config.get("model_tag", "default") + self.metric_type = self.config.get( + "metric_type", 1 + ) # 0: NORESQA-score, 1: NORESQA-MOS + self.cache_dir = self.config.get("cache_dir", "versa_cache/noresqa_model") + self.use_gpu = self.config.get("use_gpu", False) + + try: + self.model = self._setup_model() + except Exception as e: + raise RuntimeError(f"Failed to initialize NORESQA model: {str(e)}") from e + + def _setup_model(self): + """Setup the NORESQA model.""" + device = "cuda" if self.use_gpu else "cpu" + + if self.model_tag == "default": + if not os.path.isdir(self.cache_dir): + logger.info("Creating checkpoints directory") + os.makedirs(self.cache_dir) + + url_w2v = "https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt" + w2v_path = os.path.join(self.cache_dir, "wav2vec_small.pt") + if not os.path.isfile(w2v_path): + logger.info("Downloading wav2vec 2.0 started") + urlretrieve(url_w2v, w2v_path) + logger.info("wav2vec 2.0 download completed") + + model = NORESQA( + output=40, + output2=40, + metric_type=self.metric_type, + config_path=w2v_path, + ) -if __name__ == "__main__": - a = np.random.random(16000) - b = np.random.random(16000) - model = noresqa_model_setup(use_gpu=True) - print("metrics: {}".format(noresqa_metric(model, a, b, 16000))) + if self.metric_type == 0: + model_checkpoint_path = "{}/models/model_noresqa.pth".format(base_path) + # Suppress PyTorch config registration warnings during model loading + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="Skipping config registration for" + ) + state = torch.load(model_checkpoint_path, map_location="cpu")[ + "state_base" + ] + elif self.metric_type == 1: + model_checkpoint_path = "{}/models/model_noresqa_mos.pth".format( + base_path + ) + # Suppress PyTorch config registration warnings during model loading + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="Skipping config registration for" + ) + state = torch.load(model_checkpoint_path, map_location="cpu")[ + "state_dict" + ] + else: + raise ValueError(f"Invalid metric_type: {self.metric_type}") + + pretrained_dict = {} + for k, v in state.items(): + if "module" in k: + pretrained_dict[k.replace("module.", "")] = v + else: + pretrained_dict[k] = v + model_dict = model.state_dict() + model_dict.update(pretrained_dict) + model.load_state_dict(pretrained_dict) + + # change device as needed + model.to(device) + model.device = device + model.eval() + + else: + raise NotImplementedError(f"Model tag '{self.model_tag}' not implemented") + + return model + + def compute( + self, predictions: Any, references: Any, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + """Calculate NORESQA score for speech quality assessment. + + Args: + predictions: Predicted audio signal. + references: Ground truth audio signal. + metadata: Optional metadata containing sample_rate. + + Returns: + dict: Dictionary containing NORESQA score. + """ + pred_x = predictions + gt_x = references + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + + # Validate inputs + if pred_x is None: + raise ValueError("Predicted signal must be provided") + if gt_x is None: + raise ValueError("Reference signal must be provided") + + pred_x = np.asarray(pred_x) + gt_x = np.asarray(gt_x) + + # Resample to 16kHz (NORESQA only works with 16kHz) + if fs != self.TARGET_FS: + gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=self.TARGET_FS) + pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=self.TARGET_FS) + + nmr_feat, test_feat = feats_loading( + pred_x, gt_x, noresqa_or_noresqaMOS=self.metric_type + ) + test_feat = ( + torch.from_numpy(test_feat).float().to(self.model.device).unsqueeze(0) + ) + nmr_feat = torch.from_numpy(nmr_feat).float().to(self.model.device).unsqueeze(0) + + with torch.no_grad(): + if self.metric_type == 0: + noresqa_pout, noresqa_qout = model_prediction_noresqa( + test_feat, nmr_feat, self.model + ) + return {"noresqa_score": noresqa_pout} + elif self.metric_type == 1: + mos_score = model_prediction_noresqa_mos( + test_feat, nmr_feat, self.model + ) + return {"noresqa_mos": mos_score} + else: + raise ValueError(f"Invalid metric_type: {self.metric_type}") + + def get_metadata(self) -> MetricMetadata: + """Return NORESQA metric metadata.""" + metric_name = "noresqa_mos" if self.metric_type == 1 else "noresqa_score" + description = "NORESQA-MOS" if self.metric_type == 1 else "NORESQA-score" + + return MetricMetadata( + name=metric_name, + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["fairseq", "torch", "librosa", "numpy"], + description=f"{description}: Non-matching reference based speech quality assessment", + paper_reference="https://arxiv.org/abs/2104.09411", + implementation_source="https://github.com/facebookresearch/NORESQA", + ) + + +def register_noresqa_metric(registry): + """Register NORESQA metric with the registry.""" + # Register both metric types + for metric_type, metric_name in [(0, "noresqa_score"), (1, "noresqa_mos")]: + description = "NORESQA-MOS" if metric_type == 1 else "NORESQA-score" + + metric_metadata = MetricMetadata( + name=metric_name, + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["fairseq", "torch", "librosa", "numpy"], + description=f"{description}: Non-matching reference based speech quality assessment", + paper_reference="https://arxiv.org/abs/2104.09411", + implementation_source="https://github.com/facebookresearch/NORESQA", + ) + registry.register( + NoresqaMetric, + metric_metadata, + aliases=[f"Noresqa{metric_type}", metric_name], + ) diff --git a/versa/utterance_metrics/owsm_lid.py b/versa/utterance_metrics/owsm_lid.py index b9286b6..b9eece4 100644 --- a/versa/utterance_metrics/owsm_lid.py +++ b/versa/utterance_metrics/owsm_lid.py @@ -3,39 +3,154 @@ # Copyright 2024 Jiatong Shi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) -import os +"""Module for OWSM Language Identification (LID) metrics.""" + +import logging +from typing import Dict, Any, Optional, Union import librosa import numpy as np -from espnet2.bin.s2t_inference_language import Speech2Language - - -def owsm_lid_model_setup(model_tag="default", nbest=3, use_gpu=False): - if use_gpu: - device = "cuda" - else: - device = "cpu" - if model_tag == "default": - model_tag = "espnet/owsm_v3.1_ebf" - model = Speech2Language.from_pretrained( - model_tag=model_tag, - device=device, - nbest=nbest, + +logger = logging.getLogger(__name__) + +# Handle optional espnet2 dependency +try: + from espnet2.bin.s2t_inference_language import Speech2Language + + ESPNET2_AVAILABLE = True +except ImportError: + logger.warning( + "espnet2 is not properly installed. " "Please install espnet2 and retry" ) + Speech2Language = None + ESPNET2_AVAILABLE = False + +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType + + +class Espnet2NotAvailableError(RuntimeError): + """Exception raised when espnet2 is required but not available.""" + + pass + + +def is_espnet2_available(): + """ + Check if the espnet2 package is available. + + Returns: + bool: True if espnet2 is available, False otherwise. + """ + return ESPNET2_AVAILABLE + + +class OwsmLidMetric(BaseMetric): + """OWSM Language Identification (LID) metric.""" + + TARGET_FS = 16000 # OWSM model's expected sampling rate + + def _setup(self): + """Initialize OWSM LID-specific components.""" + if not ESPNET2_AVAILABLE: + raise ImportError( + "espnet2 is not properly installed. Please install espnet2 and retry" + ) - return model + self.model_tag = self.config.get("model_tag", "default") + self.nbest = self.config.get("nbest", 3) + self.use_gpu = self.config.get("use_gpu", False) + try: + self.model = self._setup_model() + except Exception as e: + raise RuntimeError(f"Failed to initialize OWSM LID model: {str(e)}") from e -def language_id(model, pred_x, fs): - # NOTE(jiatong): only work for 16000 Hz - if fs != 16000: - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) + def _setup_model(self): + """Setup the OWSM LID model.""" + device = "cuda" if self.use_gpu else "cpu" - result = model(pred_x) - return {"language": result} + if self.model_tag == "default": + model_tag = "espnet/owsm_v3.1_ebf" + else: + model_tag = self.model_tag + + model = Speech2Language.from_pretrained( + model_tag=model_tag, + device=device, + nbest=self.nbest, + ) + + return model + + def compute( + self, predictions: Any, references: Any = None, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + """Calculate language identification for speech. + + Args: + predictions: Audio signal to be evaluated. + references: Not used for LID (single-ended metric). + metadata: Optional metadata containing sample_rate. + + Returns: + dict: Dictionary containing language identification result. + """ + pred_x = predictions + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + + # Validate inputs + if pred_x is None: + raise ValueError("Predicted signal must be provided") + + pred_x = np.asarray(pred_x) + + # Resample if necessary (OWSM only works with 16kHz) + if fs != self.TARGET_FS: + pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=self.TARGET_FS) + + result = self.model(pred_x) + return {"language": result} + + def get_metadata(self) -> MetricMetadata: + """Return OWSM LID metric metadata.""" + return MetricMetadata( + name="lid", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.LIST, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["espnet2", "librosa", "numpy"], + description="OWSM Language Identification (LID) for speech", + paper_reference="https://arxiv.org/abs/2309.16588", + implementation_source="https://github.com/espnet/espnet", + ) + + +def register_owsm_lid_metric(registry): + """Register OWSM LID metric with the registry.""" + metric_metadata = MetricMetadata( + name="lid", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.LIST, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["espnet2", "librosa", "numpy"], + description="OWSM Language Identification (LID) for speech", + paper_reference="https://arxiv.org/abs/2309.16588", + implementation_source="https://github.com/espnet/espnet", + ) + registry.register( + OwsmLidMetric, + metric_metadata, + aliases=["OwsmLid", "lid", "language_id"], + ) if __name__ == "__main__": a = np.random.random(16000) - model = owsm_lid_model_setup() - print("metrics: {}".format(language_id(model, a, 16000))) + model = OwsmLidMetric() + print("metrics: {}".format(model.compute(a, None, {"sample_rate": 16000}))) diff --git a/versa/utterance_metrics/pam.py b/versa/utterance_metrics/pam.py index c02eb0c..9ef5271 100644 --- a/versa/utterance_metrics/pam.py +++ b/versa/utterance_metrics/pam.py @@ -28,9 +28,35 @@ import numpy as np import torch.nn.functional as F +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType + +# Handle optional dependencies +try: + from versa.utterance_metrics.pam_utils.clap import CLAP + + PAM_AVAILABLE = True +except ImportError: + logger.warning( + "PAM dependencies are not installed. Please install required dependencies" + ) + CLAP = None + PAM_AVAILABLE = False + # Constants HF_REPO = "microsoft/msclap" CLAP_VERSION = "CLAP_weights_2023.pth" + + +def is_pam_available(): + """ + Check if the PAM dependencies are available. + + Returns: + bool: True if PAM dependencies are available, False otherwise. + """ + return PAM_AVAILABLE + + PAM_PROMPTS = [ "the sound is clear and clean.", "the sound is noisy and with artifacts.", @@ -240,140 +266,197 @@ def evaluate(self, audio_tensor: torch.Tensor) -> float: return pam_score -def load_audio( - audio_file: Union[str, torch.Tensor], sample_rate: int, repro: bool = True -) -> torch.Tensor: - """ - Load and preprocess audio file. +class PamMetric(BaseMetric): + """PAM (Perceptual Audio Metric) for audio quality assessment.""" - Args: - audio_file: Path to audio file or audio tensor - sample_rate: Sample rate of the input audio - repro: If True, use reproducible processing (taking first 7 seconds) + TARGET_FS = 44100 # PAM model's expected sampling rate - Returns: - Processed audio tensor - """ - # Load audio file if path is provided - if isinstance(audio_file, str): - audio, sample_rate = torchaudio.load(audio_file) - else: - audio = audio_file.clone() # Create a copy to avoid modifying the original - - # Ensure audio is a FloatTensor - audio = torch.FloatTensor(audio) - - # Resample audio if needed - if sample_rate != RESAMPLE_RATE: - resampler = T.Resample(sample_rate, RESAMPLE_RATE) - audio = resampler(audio) - - # Convert to mono if stereo - if audio.shape[0] > 1: - audio = torch.mean(audio, dim=0, keepdim=True) - - # Reshape to 1D - audio = audio.reshape(-1) - - # Process audio to be exactly AUDIO_DURATION seconds - if SAMPLES >= audio.shape[0]: - # Audio is shorter than required duration, repeat to match - repeat_factor = int(np.ceil(SAMPLES / audio.shape[0])) - audio = audio.repeat(repeat_factor) - # Trim to exact length - audio = audio[:SAMPLES] - else: - # Audio is longer than required duration - if repro: - # Take first AUDIO_DURATION seconds - audio = audio[:SAMPLES] - else: - # Take chunks of AUDIO_DURATION seconds plus remaining portion - cutoff = int(np.floor(audio.shape[0] / SAMPLES)) - initial_audio = audio[: cutoff * SAMPLES] - - remaining = audio[cutoff * SAMPLES :] - if remaining.shape[0] > 0: - # If remaining is non-empty, take the last AUDIO_DURATION seconds - remaining = ( - audio[-SAMPLES:] - if remaining.shape[0] <= SAMPLES - else remaining[:SAMPLES] - ) - audio = torch.cat([initial_audio, remaining]) - else: - audio = initial_audio + def _setup(self): + """Initialize PAM-specific components.""" + if not PAM_AVAILABLE: + raise ImportError( + "PAM dependencies are not installed. Please install required dependencies" + ) - return audio + self.repro = self.config.get("repro", True) + self.cache_dir = self.config.get("cache_dir", "versa_cache/pam") + self.use_gpu = self.config.get("use_gpu", False) + + # Extract model configuration from config + model_config = { + "text_model": self.config.get("text_model", "gpt2"), + "text_len": self.config.get("text_len", 77), + "transformer_embed_dim": self.config.get("transformer_embed_dim", 768), + "audioenc_name": self.config.get("audioenc_name", "HTSAT"), + "out_emb": self.config.get("out_emb", 768), + "sampling_rate": self.config.get("sampling_rate", 44100), + "duration": self.config.get("duration", 7), + "fmin": self.config.get("fmin", 50), + "fmax": self.config.get("fmax", 8000), + "n_fft": self.config.get("n_fft", 1024), + "hop_size": self.config.get("hop_size", 320), + "mel_bins": self.config.get("mel_bins", 64), + "window_size": self.config.get("window_size", 1024), + "d_proj": self.config.get("d_proj", 1024), + "temperature": self.config.get("temperature", 0.003), + "num_classes": self.config.get("num_classes", 527), + "batch_size": self.config.get("batch_size", 1024), + "demo": self.config.get("demo", False), + } + try: + self.model = PAM(model_config=model_config, use_cuda=self.use_gpu) + except Exception as e: + raise RuntimeError(f"Failed to initialize PAM model: {str(e)}") from e -def pam_model_setup(model_config: Dict[str, Any], use_gpu: bool = False) -> PAM: - """ - Initialize PAM model with given configuration. + def compute( + self, predictions: Any, references: Any = None, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + """Calculate PAM score for audio quality assessment. - Args: - model_config: Model configuration dictionary - use_gpu: Whether to use GPU for computation + Args: + predictions: Audio signal to be evaluated. + references: Not used for PAM (single-ended metric). + metadata: Optional metadata containing sample_rate. - Returns: - Initialized PAM model - """ - model = PAM(model_config=model_config, use_cuda=use_gpu) - return model + Returns: + dict: Dictionary containing PAM score. + """ + pred_x = predictions + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + # Validate inputs + if pred_x is None: + raise ValueError("Predicted signal must be provided") -def pam_metric( - model: PAM, - pred_x: Union[str, torch.Tensor, np.ndarray], - gt_x: Optional[Union[str, torch.Tensor, np.ndarray]] = None, - fs: int = 16000, -) -> Dict[str, float]: - """ - Compute PAM metric for given audio. + # Convert numpy array to tensor if needed + if isinstance(pred_x, np.ndarray): + pred_x = torch.FloatTensor(pred_x) - Args: - model: PAM model - pred_x: Predicted audio (file path or tensor) - gt_x: Ground truth audio (unused, kept for API compatibility) - fs: Sample rate of the input audio + # Load and preprocess audio + audio = self._load_audio(pred_x, fs) - Returns: - Dictionary containing PAM score - """ - # Convert numpy array to tensor if needed - if isinstance(pred_x, np.ndarray): - pred_x = torch.FloatTensor(pred_x) - - # Load and preprocess audio - audio = load_audio(pred_x, fs, repro=True) - - # Ensure audio has batch dimension - if len(audio.shape) < 2: - audio = audio.unsqueeze(0) - - # Compute PAM score - pam_score = model.evaluate(audio) - - return {"pam_score": pam_score} - - -if __name__ == "__main__": - # Example usage - a = np.random.random(16000) - - # Load configuration from YAML file - try: - with open("egs/separate_metrics/pam.yaml", "r", encoding="utf-8") as f: - config = yaml.safe_load(f)[0] - except (FileNotFoundError, yaml.YAMLError) as e: - print(f"Error loading configuration: {e}") - sys.exit(1) - - # Initialize model and compute metric - try: - model = pam_model_setup(config, use_gpu=torch.cuda.is_available()) - result = pam_metric(model, a, fs=16000) - print(f"PAM score: {result['pam_score']:.4f}") - except Exception as e: - print(f"Error computing PAM metric: {e}") - sys.exit(1) + # Ensure audio has batch dimension + if len(audio.shape) < 2: + audio = audio.unsqueeze(0) + + # Compute PAM score + pam_score = self.model.evaluate(audio) + + return {"pam_score": pam_score} + + def _load_audio( + self, audio_file: Union[str, torch.Tensor], sample_rate: int + ) -> torch.Tensor: + """ + Load and preprocess audio file. + + Args: + audio_file: Path to audio file or audio tensor + sample_rate: Sample rate of the input audio + + Returns: + Processed audio tensor + """ + # Load audio file if path is provided + if isinstance(audio_file, str): + audio, sample_rate = torchaudio.load(audio_file) + else: + audio = audio_file.clone() # Create a copy to avoid modifying the original + + # Ensure audio is a FloatTensor + audio = torch.FloatTensor(audio) + + # Resample audio if needed + if sample_rate != self.TARGET_FS: + resampler = T.Resample(sample_rate, self.TARGET_FS) + audio = resampler(audio) + + # Convert to mono if stereo + if audio.shape[0] > 1: + audio = torch.mean(audio, dim=0, keepdim=True) + + # Reshape to 1D + audio = audio.reshape(-1) + + # Process audio to be exactly AUDIO_DURATION seconds + samples = self.TARGET_FS * AUDIO_DURATION + if samples >= audio.shape[0]: + # Audio is shorter than required duration, repeat to match + repeat_factor = int(np.ceil(samples / audio.shape[0])) + audio = audio.repeat(repeat_factor) + # Trim to exact length + audio = audio[:samples] + else: + # Audio is longer than required duration + if self.repro: + # Take first AUDIO_DURATION seconds + audio = audio[:samples] + else: + # Take chunks of AUDIO_DURATION seconds plus remaining portion + cutoff = int(np.floor(audio.shape[0] / samples)) + initial_audio = audio[: cutoff * samples] + + remaining = audio[cutoff * samples :] + if remaining.shape[0] > 0: + # If remaining is non-empty, take the last AUDIO_DURATION seconds + remaining = ( + audio[-samples:] + if remaining.shape[0] <= samples + else remaining[:samples] + ) + audio = torch.cat([initial_audio, remaining]) + else: + audio = initial_audio + + return audio + + def get_metadata(self) -> MetricMetadata: + """Return PAM metric metadata.""" + return MetricMetadata( + name="pam", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=[ + "torch", + "torchaudio", + "transformers", + "huggingface_hub", + "numpy", + ], + description="PAM: Prompting Audio-Language Models for Audio Quality Assessment", + paper_reference="https://arxiv.org/abs/2309.07317", + implementation_source="https://github.com/soham97/PAM", + ) + + +def register_pam_metric(registry): + """Register PAM metric with the registry.""" + metric_metadata = MetricMetadata( + name="pam", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=[ + "torch", + "torchaudio", + "transformers", + "huggingface_hub", + "numpy", + ], + description="PAM: Prompting Audio-Language Models for Audio Quality Assessment", + paper_reference="https://arxiv.org/abs/2309.07317", + implementation_source="https://github.com/soham97/PAM", + ) + registry.register( + PamMetric, + metric_metadata, + aliases=["Pam", "pam", "perceptual_audio_metric"], + ) diff --git a/versa/utterance_metrics/pesq_score.py b/versa/utterance_metrics/pesq_score.py index b6062b5..3dbb94f 100644 --- a/versa/utterance_metrics/pesq_score.py +++ b/versa/utterance_metrics/pesq_score.py @@ -3,46 +3,148 @@ # Copyright 2024 Jiatong Shi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) -import logging +"""Module for PESQ (Perceptual Evaluation of Speech Quality) metrics.""" -logger = logging.getLogger(__name__) +import logging +from typing import Dict, Any, Optional, Union import librosa import numpy as np -from pesq import pesq +logger = logging.getLogger(__name__) + +# Handle optional pesq dependency try: from pesq import pesq + + PESQ_AVAILABLE = True except ImportError: - raise ImportError("Please install pesq and retry: pip install pesq") - - -def pesq_metric(pred_x, gt_x, fs): - try: - if fs == 8000: - pesq_value = pesq(8000, gt_x, pred_x, "nb") - elif fs < 16000: - logging.info("not support fs {}, resample to 8khz".format(fs)) - new_gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=8000) - new_pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=8000) - pesq_value = pesq(16000, new_gt_x, new_pred_x, "nb") - elif fs == 16000: - pesq_value = pesq(16000, gt_x, pred_x, "wb") - else: - logging.info("not support fs {}, resample to 16khz".format(fs)) - new_gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=16000) - new_pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) - pesq_value = pesq(16000, new_gt_x, new_pred_x, "wb") - except BaseException: - logging.warning( - "Error from pesq calculation. Please check the audio (likely due to silence)" + logger.warning( + "pesq is not properly installed. Please install pesq and retry: pip install pesq" + ) + pesq = None + PESQ_AVAILABLE = False + +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType + + +class PesqNotAvailableError(RuntimeError): + """Exception raised when pesq is required but not available.""" + + pass + + +def is_pesq_available(): + """ + Check if the pesq package is available. + + Returns: + bool: True if pesq is available, False otherwise. + """ + return PESQ_AVAILABLE + + +class PesqMetric(BaseMetric): + """PESQ (Perceptual Evaluation of Speech Quality) metric.""" + + def _setup(self): + """Initialize PESQ-specific components.""" + if not PESQ_AVAILABLE: + raise ImportError( + "pesq is not properly installed. Please install pesq and retry: pip install pesq" + ) + + def compute( + self, predictions: Any, references: Any, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + """Calculate PESQ score for speech quality assessment. + + Args: + predictions: Predicted audio signal. + references: Ground truth audio signal. + metadata: Optional metadata containing sample_rate. + + Returns: + dict: Dictionary containing PESQ score. + """ + pred_x = predictions + gt_x = references + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + + # Validate inputs + if pred_x is None: + raise ValueError("Predicted signal must be provided") + if gt_x is None: + raise ValueError("Reference signal must be provided") + + pred_x = np.asarray(pred_x) + gt_x = np.asarray(gt_x) + + try: + if fs == 8000: + pesq_value = pesq(8000, gt_x, pred_x, "nb") + elif fs < 16000: + logger.info("not support fs {}, resample to 8khz".format(fs)) + new_gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=8000) + new_pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=8000) + pesq_value = pesq(8000, new_gt_x, new_pred_x, "nb") + elif fs == 16000: + pesq_value = pesq(16000, gt_x, pred_x, "wb") + else: + logger.info("not support fs {}, resample to 16khz".format(fs)) + new_gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=16000) + new_pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) + pesq_value = pesq(16000, new_gt_x, new_pred_x, "wb") + except BaseException: + logger.warning( + "Error from pesq calculation. Please check the audio (likely due to silence)" + ) + pesq_value = 0.0 + + return {"pesq": pesq_value} + + def get_metadata(self) -> MetricMetadata: + """Return PESQ metric metadata.""" + return MetricMetadata( + name="pesq", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=False, + auto_install=False, + dependencies=["pesq", "librosa", "numpy"], + description="PESQ: Perceptual Evaluation of Speech Quality", + paper_reference="https://www.itu.int/rec/T-REC-P.862", + implementation_source="https://github.com/ludlows/python-pesq", ) - pesq_value = 0.0 - return {"pesq": pesq_value} + + +def register_pesq_metric(registry): + """Register PESQ metric with the registry.""" + metric_metadata = MetricMetadata( + name="pesq", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=False, + auto_install=False, + dependencies=["pesq", "librosa", "numpy"], + description="PESQ: Perceptual Evaluation of Speech Quality", + paper_reference="https://www.itu.int/rec/T-REC-P.862", + implementation_source="https://github.com/ludlows/python-pesq", + ) + registry.register( + PesqMetric, + metric_metadata, + aliases=["Pesq", "pesq", "perceptual_evaluation_speech_quality"], + ) if __name__ == "__main__": a = np.random.random(16000) b = np.random.random(16000) - scores = pesq_metric(a, b, 16000) + metric = PesqMetric() + scores = metric.compute(a, b, metadata={"sample_rate": 16000}) print(scores) diff --git a/versa/utterance_metrics/srmr.py b/versa/utterance_metrics/srmr.py index 7543d7a..c48392d 100644 --- a/versa/utterance_metrics/srmr.py +++ b/versa/utterance_metrics/srmr.py @@ -2,6 +2,7 @@ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) import logging +from typing import Dict, Any, Optional logger = logging.getLogger(__name__) @@ -13,39 +14,98 @@ logger.info("srmr is not installed. Please use `tools/install_srmr.sh` to install") srmr = None +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType -def srmr_metric( - pred_x, - fs, - n_cochlear_filters=23, - low_freq=125, - min_cf=4, - max_cf=128, - fast=True, - norm=False, -): - if srmr is None: - raise ImportError( - # Error message if SRMRpy is not installed + +class SRMRMetric(BaseMetric): + """Speech-to-Reverberation Modulation energy Ratio (SRMR) metric.""" + + def _setup(self): + """Initialize SRMR-specific components.""" + if srmr is None: + raise ImportError( + "srmr is not installed. Please use `tools/install_srmr.sh` to install" + ) + + # Set default parameters from config + self.n_cochlear_filters = self.config.get("n_cochlear_filters", 23) + self.low_freq = self.config.get("low_freq", 125) + self.min_cf = self.config.get("min_cf", 4) + self.max_cf = self.config.get("max_cf", 128) + self.fast = self.config.get("fast", True) + self.norm = self.config.get("norm", False) + + def compute( + self, predictions: Any, references: Any = None, metadata: Dict[str, Any] = None + ) -> Dict[str, float]: + """Compute the SRMR score.""" + pred_x = predictions + sample_rate = metadata.get("sample_rate", 16000) if metadata else 16000 + + srmr_score = srmr( + pred_x, + sample_rate, + n_cochlear_filters=self.n_cochlear_filters, + low_freq=self.low_freq, + min_cf=self.min_cf, + max_cf=self.max_cf, + fast=self.fast, + norm=self.norm, ) - srmr_score = srmr( - pred_x, - fs, - n_cochlear_filters=n_cochlear_filters, - low_freq=low_freq, - min_cf=min_cf, - max_cf=max_cf, - fast=fast, - norm=norm, - ) - return { - "srmr": srmr_score, - } + return { + "srmr": srmr_score, + } + def get_metadata(self) -> MetricMetadata: + """Return SRMR metric metadata.""" + return MetricMetadata( + name="srmr", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=False, + auto_install=False, + dependencies=["srmrpy"], + description="Speech-to-Reverberation Modulation energy Ratio (SRMR) for speech quality assessment", + paper_reference="http://www.individual.utoronto.ca/falkt/falk/pdf/FalkChan_TASLP2010.pdf", + implementation_source="https://github.com/shimhz/SRMRpy.git", + ) + + +# Auto-registration function +def register_srmr_metric(registry): + """Register SRMR metric with the registry.""" + metric_metadata = MetricMetadata( + name="srmr", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=False, + auto_install=False, + dependencies=["srmrpy"], + description="Speech-to-Reverberation Modulation energy Ratio (SRMR) for speech quality assessment", + paper_reference="http://www.individual.utoronto.ca/falkt/falk/pdf/FalkChan_TASLP2010.pdf", + implementation_source="https://github.com/shimhz/SRMRpy.git", + ) + registry.register(SRMRMetric, metric_metadata, aliases=["SRMR"]) -if __name__ == "__main__": +if __name__ == "__main__": a = np.random.random(16000) - score = srmr_metric(a, 16000) - print(score) + + # Test the new class-based metric + config = { + "n_cochlear_filters": 23, + "low_freq": 125, + "min_cf": 4, + "max_cf": 128, + "fast": True, + "norm": False, + } + metric = SRMRMetric(config) + metadata = {"sample_rate": 16000} + score = metric.compute(a, metadata=metadata) + print("SRMR", score)