diff --git a/README.md b/README.md
index fbeaa46..9a63268 100644
--- a/README.md
+++ b/README.md
@@ -59,6 +59,9 @@ If the speech XAI functionalities are needed, then follow these steps:
 2. install whisperX with `pip install git+https://github.com/m-bain/whisperx.git`
 3. install system-wide [ffmpeg](https://ffmpeg.org/download.html). If you have no sudo rights, you can try with `conda install conda-forge::ffmpeg`
 
+### Testing
+For detailed instructions on setting up your environment and running tests, please see our [Testing Guidelines](TESTING.md).
+
 ### Explain & Benchmark
 
diff --git a/TESTING.md b/TESTING.md
new file mode 100644
index 0000000..24cdbde
--- /dev/null
+++ b/TESTING.md
@@ -0,0 +1,47 @@
+# Testing
+
+To ensure the quality and functionality of our code, we use automated tests.
+
+## Installation
+Install pytest from PyPI:
+```bash
+pip install pytest
+```
+
+## Running Tests
+We use pytest for our tests. Below are the commands for running tests at different scopes:
+
+### All Tests
+Run all tests, for both speech and text, with:
+```bash
+pytest
+```
+
+### Specific Test Files
+Run tests for text processing only:
+```bash
+pytest tests/test_text.py
+```
+
+Run tests for speech processing only:
+```bash
+pytest tests/test_speech.py
+```
+
+### Specific Test Methods
+Run a specific test method by specifying the test file and method name
+(replace `test_text.py` with the desired test file and `test_method_name` with the desired test method):
+```bash
+pytest tests/test_text.py::test_method_name
+```
+
+### Clear Cache
+We use some caching in our tests. If you encounter issues that might be related to cached test results or configurations, you can clear the pytest cache with:
+```bash
+pytest --cache-clear
+```
+This command removes all items from the cache, ensuring that your next test run starts completely clean.
+
+## Troubleshooting Common Issues
+If tests behave unexpectedly or fail after changes, consider clearing the pytest cache or re-running the tests to check whether the issue persists. Always ensure that your environment matches the required configuration as specified in our setup guidelines.
\ No newline at end of file
diff --git a/ferret/benchmark_speech.py b/ferret/benchmark_speech.py
index fc4ce5c..2fcee1d 100644
--- a/ferret/benchmark_speech.py
+++ b/ferret/benchmark_speech.py
@@ -157,6 +157,9 @@ def explain(
         """
         Explain the prediction of the model.
         Returns the importance of each segment in the audio.
+
+        Note: the `target_class` argument specifies the ID of the target
+        class.
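+
+        For illustration only (a sketch inferred from this diff; exact
+        names and defaults may differ), a minimal call looks like:
+
+            >>> benchmark = SpeechBenchmark(model, feature_extractor)
+            >>> explanation = benchmark.explain(
+            ...     audio_path_or_array="sample_audio.wav",
+            ...     current_sr=16000,
+            ...     methodology="LOO",
+            ...     target_class=None,  # None: explain the predicted class
+            ... )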
""" explainer_args = dict() # TODO UNIFY THE INPUT FORMAT diff --git a/ferret/explainers/explanation_speech/gradient_speech_explainer.py b/ferret/explainers/explanation_speech/gradient_speech_explainer.py index e8ff2b7..20e9528 100644 --- a/ferret/explainers/explanation_speech/gradient_speech_explainer.py +++ b/ferret/explainers/explanation_speech/gradient_speech_explainer.py @@ -106,7 +106,7 @@ def compute_explanation( # if word_timestamps is None: # # Transcribe audio - word_timestamps = audio.transcription + # word_timestamps = audio.transcription # Compute gradient importance for each target label # This also handles the multilabel scenario as for FSC diff --git a/ferret/explainers/explanation_speech/loo_speech_explainer.py b/ferret/explainers/explanation_speech/loo_speech_explainer.py index 5588d8e..41aa0c6 100644 --- a/ferret/explainers/explanation_speech/loo_speech_explainer.py +++ b/ferret/explainers/explanation_speech/loo_speech_explainer.py @@ -31,6 +31,9 @@ def remove_words( - silence - white noise - pink noise + + Note: in all the manipulations, the sample rate remains that of + the input `audio`! """ ## Load audio as pydub.AudioSegment @@ -61,6 +64,8 @@ def compute_explanation( ) -> ExplanationSpeech: """ Computes the importance of each word in the audio. + + `target` class should be an integer identifying the class ID. """ ## Get modified audio by leaving a single word out and the words @@ -86,6 +91,8 @@ def compute_explanation( targets = target_class else: + # If no target class is passed, the explanation is computed for + # the predicted class. if n_labels > 1: # Multilabel scenario as for FSC targets = [ diff --git a/ferret/explainers/explanation_speech/utils_removal.py b/ferret/explainers/explanation_speech/utils_removal.py index 20ed538..6541fdf 100644 --- a/ferret/explainers/explanation_speech/utils_removal.py +++ b/ferret/explainers/explanation_speech/utils_removal.py @@ -57,6 +57,11 @@ def remove_word(audio, word, removal_type: str = "nothing"): - white noise - pink noise + WARNING: if `word["start"] * 1000 - a` is negative, the audio is actually + traversed FROM SOME POINT UNTIL ITS END (like `l[-10:]` + actually takes the last 10 entries of the list `l`). Therefore if + the difference is negative, we effectively use `a=0`. + Args: audio (pydub.AudioSegment): audio word: word to remove with its start and end times @@ -65,22 +70,33 @@ def remove_word(audio, word, removal_type: str = "nothing"): a, b = 100, 40 - before_word_audio = audio[: word["start"] * 1000 - a] - after_word_audio = audio[word["end"] * 1000 + b :] - word_duration = (word["end"] * 1000 - word["start"] * 1000) + a + b + # Convert from seconds (as returned by WhisperX) to milliseconds (as + # required to index PyDub `AudioSegment` objects). + word_start_ms = word["start"] * 1000 + word_end_ms = word["end"] * 1000 + + # If we risk reading the audio segment from the end (difference is + # negative) set the offset `a` to zero to avoid that. 
+    if word_start_ms - a < 0:
+        a = 0
+
+    before_word_audio = audio[: word_start_ms - a]
+    after_word_audio = audio[word_end_ms + b :]
+    word_duration = (word_end_ms - word_start_ms) + a + b
 
     if removal_type == "nothing":
         replace_word_audio = AudioSegment.empty()
+
     elif removal_type == "silence":
         replace_word_audio = AudioSegment.silent(duration=word_duration)
     elif removal_type == "white noise":
-        sound_path = (os.path.join(os.path.dirname(__file__), "white_noise.mp3"),)
+        sound_path = os.path.join(os.path.dirname(__file__), "white_noise.mp3")
+
         replace_word_audio = AudioSegment.from_mp3(sound_path)[:word_duration]
-        # display(audio_removed)
     elif removal_type == "pink noise":
-        sound_path = (os.path.join(os.path.dirname(__file__), "pink_noise.mp3"),)
+        sound_path = os.path.join(os.path.dirname(__file__), "pink_noise.mp3")
         replace_word_audio = AudioSegment.from_mp3(sound_path)[:word_duration]
 
     audio_removed = before_word_audio + replace_word_audio + after_word_audio
diff --git a/ferret/modeling/speech_model_helpers/model_helper_er.py b/ferret/modeling/speech_model_helpers/model_helper_er.py
index a5f4e4b..35bf214 100644
--- a/ferret/modeling/speech_model_helpers/model_helper_er.py
+++ b/ferret/modeling/speech_model_helpers/model_helper_er.py
@@ -54,8 +54,21 @@ def _predict(
 
         ## Predict logits
         with torch.no_grad():
+            # Some feature extractors return the input tensor(s) under the
+            # `input_values` key, others under the `input_features` one.
+            if 'input_values' in inputs.keys():
+                input_features = inputs['input_values'].to(self.device)
+            elif 'input_features' in inputs.keys():
+                input_features = inputs['input_features'].to(self.device)
+            else:
+                raise Exception(
+                    'Input features not found in the inputs dict under either'
+                    ' the `input_values` key or the `input_features` one'
+                )
+
             logits = (
-                self.model(inputs.input_values.to(self.device))
+                self.model(input_features)
                 .logits.detach()
                 .cpu()
                 # .numpy()
diff --git a/ferret/speechxai_utils.py b/ferret/speechxai_utils.py
index 827dc30..40fdcb2 100644
--- a/ferret/speechxai_utils.py
+++ b/ferret/speechxai_utils.py
@@ -28,7 +28,7 @@ def __init__(
 
         if isinstance(audio_path_or_array, str):
             self.array, self.current_sr = librosa.load(
-                audio_path_or_array, sr=None, dtype=np.float32
+                audio_path_or_array, sr=None, dtype=np.float32, mono=True
             )
         elif isinstance(audio_path_or_array, np.ndarray):
             if current_sr is None:
@@ -65,8 +65,8 @@ def resample(self, target_sr: int):
         Resample the audio to the target sampling rate. In place operation.
         """
         self.array = librosa.resample(
-            self.array, orig_sr=self.current_sr, target_sr=target_sr
-        )
+            self.array.ravel(), orig_sr=self.current_sr, target_sr=target_sr
+        ).reshape(-1, 1)
         self.current_sr = target_sr
 
     @staticmethod
@@ -130,7 +130,7 @@ def transcribe_audio(
 
         ## Load whisperx model. TODO: we should definitely avoid loading the model for *every* sample to transcribe
         device_type = device.type
-        device_index = device.index
+        device_index = device.index if device.index is not None else 0
 
         model_whisperx = whisperx.load_model(
             model_name_whisper,
diff --git a/pyproject.toml b/pyproject.toml
index 8709b80..2c08f79 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,6 +44,7 @@ joblib = "^1.3.2"
 pytreebank = "^0.2.7"
 thermostat-datasets = "^1.1.0"
 ipython = "^8.22.2"
+pytest = "^7.4.4"
 
 # Speech-XAI additional requirements to allow for `pip install ferret[speech]`.
 pydub = { version = "0.25.1", optional = true }
 audiomentations = { version = "0.34.1", optional = true }
diff --git a/tests/data/sample_audio.wav b/tests/data/sample_audio.wav
new file mode 100644
index 0000000..5d253ae
Binary files /dev/null and b/tests/data/sample_audio.wav differ
diff --git a/tests/test_explainers.py b/tests/test_explainers.py
index aceb0f2..9d19226 100644
--- a/tests/test_explainers.py
+++ b/tests/test_explainers.py
@@ -197,3 +197,6 @@ def test_gradient_ner(self):
         explanation = exp(text, target="I-LOC", target_token="York")
         self.assertTrue("york" in [token.lower() for token in explanation.tokens])
         self.assertEqual(explanation.target_pos_idx, 6)
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
diff --git a/tests/test_speech.py b/tests/test_speech.py
new file mode 100644
index 0000000..37cec21
--- /dev/null
+++ b/tests/test_speech.py
@@ -0,0 +1,145 @@
+import pytest
+import torch
+import os
+import numpy as np
+import pandas as pd
+from pydub import AudioSegment
+from ferret import SpeechBenchmark
+from ferret.explainers.explanation_speech.loo_speech_explainer import LOOSpeechExplainer
+from ferret.explainers.explanation_speech.gradient_speech_explainer import (
+    GradientSpeechExplainer,
+)
+from ferret.explainers.explanation_speech.lime_speech_explainer import (
+    LIMESpeechExplainer,
+)
+from ferret.explainers.explanation_speech.paraling_speech_explainer import (
+    ParalinguisticSpeechExplainer,
+)
+from scipy.io.wavfile import write
+from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
+
+
+# =================================================================
+# = Fixtures creating the audio sample used throughout the tests  =
+# =================================================================
+@pytest.fixture(scope="module")
+def sample_audio_file():
+    return os.path.join(os.path.dirname(__file__), 'data', 'sample_audio.wav')
+
+
+@pytest.fixture(scope="module")
+def benchmark():
+    model = Wav2Vec2ForSequenceClassification.from_pretrained(
+        "superb/wav2vec2-base-superb-ic"
+    )
+    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+        "superb/wav2vec2-base-superb-ic"
+    )
+    return SpeechBenchmark(model, feature_extractor)
+
+
+# =========
+# = Tests =
+# =========
+
+def test_initialization_benchmark(benchmark):
+    assert benchmark.model is not None
+    assert benchmark.feature_extractor is not None
+    assert isinstance(benchmark, SpeechBenchmark)
+
+
+def test_explainer_types(benchmark):
+    for explainer_name, explainer in benchmark.explainers.items():
+        assert explainer is not None
+        assert explainer_name in ['LOO', 'Gradient', 'GradientXInput', 'LIME', 'perturb_paraling']
+        assert isinstance(
+            explainer,
+            (LOOSpeechExplainer, GradientSpeechExplainer, LIMESpeechExplainer, ParalinguisticSpeechExplainer),
+        )
+
+
+def test_audio_transcription(benchmark, sample_audio_file):
+    audio = AudioSegment.from_wav(sample_audio_file)
+    sr = audio.frame_rate
+    transcription = benchmark.transcribe(sample_audio_file, current_sr=sr)
+
+    assert transcription[0] is not None
+    assert transcription[0] == ' Turn up the bedroom heat.'
+
+
+def test_prediction(benchmark, sample_audio_file):
+    audio = AudioSegment.from_wav(sample_audio_file)
+    audio_array = np.array(audio.get_array_of_samples()).astype(np.float32)
+    audio_array /= np.max(np.abs(audio_array))  # normalize to [-1, 1]
+    predictions = benchmark.predict([audio_array])
+
+    assert predictions is not None
+    assert len(predictions) == 3
+    action_probs, object_probs, location_probs = benchmark.predict([audio_array])
+
+    assert len(action_probs) == 1
+    assert len(object_probs) == 1
+    assert len(location_probs) == 1
+    assert action_probs[0].shape == (6,)
+    assert object_probs[0].shape == (14,)
+    assert location_probs[0].shape == (4,)
+
+
+@pytest.mark.parametrize("methodology", ["LOO", "Gradient", "LIME", "perturb_paraling"])
+def test_explain_method(benchmark, sample_audio_file, methodology):
+    explanations = benchmark.explain(
+        audio_path_or_array=sample_audio_file,
+        current_sr=16000,
+        methodology=methodology,
+    )
+
+    assert explanations is not None
+
+    if methodology != "perturb_paraling":
+        assert hasattr(explanations, 'scores')
+        assert hasattr(explanations, 'features')
+        assert len(explanations.scores) > 0
+        assert len(explanations.features) > 0
+    else:
+        assert isinstance(explanations, list)
+        assert len(explanations) > 0
+        for explanation in explanations:
+            assert hasattr(explanation, 'scores')
+            assert hasattr(explanation, 'features')
+
+
+def test_explain_features(benchmark, sample_audio_file):
+    explanations = benchmark.explain(
+        audio_path_or_array=sample_audio_file,
+        current_sr=16000,
+        methodology='LOO',
+    )
+
+    expected_features = ['Turn', 'up', 'the', 'bedroom', 'heat.']
+    assert explanations.features == expected_features
+
+
+def test_invalid_audio_file(benchmark):
+    with pytest.raises(Exception):
+        benchmark.explain(
+            audio_path_or_array='non_existent_file.wav',
+            current_sr=16000,
+            methodology='LOO',
+        )
+
+
+def test_silence_audio(benchmark):
+    silent_audio = np.zeros(int(16000 * 1))  # 1 second of silence at 16 kHz
+    explanations = benchmark.explain(
+        audio_path_or_array=silent_audio,
+        current_sr=16000,
+        methodology='LOO',
+    )
+    assert explanations is not None
+    assert explanations.scores.shape == (3, 0)
+    assert len(explanations.features) == 0
+
+
+def test_explain_variations(benchmark, sample_audio_file):
+    perturbation_types = ['time stretching', 'pitch shifting', 'noise']
+    variations_table = benchmark.explain_variations(
+        audio_path_or_array=sample_audio_file,
+        current_sr=16000,
+        perturbation_types=perturbation_types,
+    )
+    assert isinstance(variations_table, dict)
+    assert all(pt in variations_table for pt in perturbation_types)
+    for pt, df in variations_table.items():
+        assert isinstance(df, pd.DataFrame)
+        assert not df.empty
\ No newline at end of file
diff --git a/tests/test_text.py b/tests/test_text.py
new file mode 100644
index 0000000..34f2d0d
--- /dev/null
+++ b/tests/test_text.py
@@ -0,0 +1,259 @@
+import pytest
+import math
+from transformers import (
+    AutoModelForSequenceClassification,
+    AutoModelForTokenClassification,
+    AutoTokenizer,
+)
+from ferret import (
+    Benchmark,
+    LIMEExplainer,
+    SHAPExplainer,
+    GradientExplainer,
+    IntegratedGradientExplainer,
+)
+from ferret.evaluators.faithfulness_measures import (
+    AOPC_Comprehensiveness_Evaluation,
+    AOPC_Sufficiency_Evaluation,
+    TauLOO_Evaluation,
+)
+from ferret.evaluators.plausibility_measures import (
+    AUPRC_PlausibilityEvaluation,
+    Tokenf1_PlausibilityEvaluation,
+    TokenIOU_PlausibilityEvaluation,
+)
+from ferret.evaluators.class_measures import (
+    AOPC_Comprehensiveness_Evaluation_by_class,
+)
+from ferret.modeling.text_helpers import SequenceClassificationHelper
+
+DEFAULT_EXPLAINERS_NUM = 6
+DEFAULT_EVALUATORS_NUM = 6
+DEFAULT_EVALUATORS_BY_CLASS_NUM = 1
+
+TASK_NAME_MAP = {
+    "lvwerra/distilbert-imdb": "text-classification",
+    "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli": "nli",
+    "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli": "zero-shot-text-classification",
+    "Babelscape/wikineural-multilingual-ner": "ner",
+}
+explainer_init_extra_args = {
+    GradientExplainer: {"multiply_by_inputs": True},
+    IntegratedGradientExplainer: {"multiply_by_inputs": True},
+}
+
+
+# ============================================================
+# = Fixtures to initialize each model and tokenizer          =
+# ============================================================
+@pytest.fixture(
+    scope="module",
+    params=[
+        ("lvwerra/distilbert-imdb", AutoModelForSequenceClassification),
+        (
+            "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
+            AutoModelForSequenceClassification,
+        ),
+        ("MoritzLaurer/mDeBERTa-v3-base-mnli-xnli", AutoModelForSequenceClassification),
+        ("Babelscape/wikineural-multilingual-ner", AutoModelForTokenClassification),
+    ],
+    ids=["textclass", "nli", "zeroshot", "ner"],
+)
+def model_and_tokenizer(request):
+    model_cls = request.param[1]
+    model = model_cls.from_pretrained(request.param[0])
+    tokenizer = AutoTokenizer.from_pretrained(request.param[0])
+    task_name = TASK_NAME_MAP[request.param[0]]
+
+    return model, tokenizer, task_name
+
+
+@pytest.fixture(scope="module")
+def explainer(request, model_and_tokenizer):
+    model, tokenizer, task_name = model_and_tokenizer
+    kwargs = explainer_init_extra_args.get(request.param, {})
+    return request.param(model, tokenizer, task_name=task_name, **kwargs)
+
+
+@pytest.fixture
+def model_tokenizer_ner():
+    model = AutoModelForTokenClassification.from_pretrained(
+        "Babelscape/wikineural-multilingual-ner"
+    )
+    tokenizer = AutoTokenizer.from_pretrained("Babelscape/wikineural-multilingual-ner")
+    return model, tokenizer
+
+
+# ================================================================
+# = Benchmark fixture (built on the model/tokenizer fixture)     =
+# ================================================================
+@pytest.fixture
+def all_benchmarks(model_and_tokenizer):
+    model, tokenizer, task_name = model_and_tokenizer
+    return Benchmark(model, tokenizer, task_name=task_name)
+
+
+# =========
+# = Tests =
+# =========
+
+
+# Setup and Initialization Checks
+def test_initialization_benchmarks(all_benchmarks):
+    assert all_benchmarks.model is not None
+    assert all_benchmarks.tokenizer is not None
+    assert isinstance(all_benchmarks, Benchmark)
+    assert len(all_benchmarks.explainers) == DEFAULT_EXPLAINERS_NUM
+    assert len(all_benchmarks.evaluators) == DEFAULT_EVALUATORS_NUM
+    assert len(all_benchmarks.class_based_evaluators) == DEFAULT_EVALUATORS_BY_CLASS_NUM
+
+
+def test_explainer_types(all_benchmarks):
+    # `all(...)` is needed here: a bare generator expression inside `assert`
+    # is always truthy, so the check would never fail.
+    assert all(
+        isinstance(
+            e,
+            (SHAPExplainer, LIMEExplainer, GradientExplainer, IntegratedGradientExplainer),
+        )
+        for e in all_benchmarks.explainers
+    )
+
+
+def test_evaluator_types(all_benchmarks):
+    expected_evaluator_types = [
+        AOPC_Comprehensiveness_Evaluation,
+        AOPC_Sufficiency_Evaluation,
+        TauLOO_Evaluation,
+        AUPRC_PlausibilityEvaluation,
+        Tokenf1_PlausibilityEvaluation,
+        TokenIOU_PlausibilityEvaluation,
+    ]
+    assert all(
+        any(isinstance(ev, t) for t in expected_evaluator_types)
+        for ev in all_benchmarks.evaluators
+    )
+    assert all(
+        isinstance(ev_class, AOPC_Comprehensiveness_Evaluation_by_class)
+        for ev_class in all_benchmarks.class_based_evaluators
+    )
+
+
+def test_helper_assignment(all_benchmarks):
+    for explainer in all_benchmarks.explainers:
+        assert explainer.helper == all_benchmarks.helper
+
+
+def test_helper_override_warning(model_tokenizer_ner):
+    model_ner, tokenizer_ner = model_tokenizer_ner
+    explainer_with_helper = SHAPExplainer(
+        model_ner, tokenizer_ner, helper=SequenceClassificationHelper
+    )
+    with pytest.warns(UserWarning, match="Overriding helper for explainer"):
+        Benchmark(
+            model_ner,
+            tokenizer_ner,
+            task_name="ner",
+            explainers=[explainer_with_helper],
+        )
+
+
+# Scoring Checks
+def test_scoring_len_and_output(all_benchmarks, cache):
+    text = "The weather in London sucks"
+    labels = ["weather complaint", "traffic"]
+    if all_benchmarks.task_name == "zero-shot-text-classification":
+        score = all_benchmarks.score(
+            text, options=labels, return_probs=True, return_dict=True
+        )
+        # Cache the zero-shot score, since it is reused later in `test_explainers`.
+        cache.set("zero-shot-score", score)
+        expected_labels = labels
+    else:
+        score = all_benchmarks.score(text, return_dict=True)
+        expected_labels = list(all_benchmarks.targets.values())
+
+    if all_benchmarks.task_name == "ner":
+        assert all(
+            all(label in token_scores[1].keys() for label in expected_labels)
+            for token_scores in score.values()
+        )
+    else:
+        assert all(label in score for label in expected_labels)
+        assert len(score) == len(expected_labels)
+        assert math.isclose(sum(score.values()), 1, abs_tol=0.01)
+
+
+# Explainer Checks
+@pytest.mark.parametrize(
+    "explainer",
+    [SHAPExplainer, LIMEExplainer, GradientExplainer, IntegratedGradientExplainer],
+    indirect=True,
+)
+@pytest.mark.parametrize(
+    "text, target, expected_tokens, expected_target_pos_idx, target_token, task",
+    [
+        (
+            "You look stunning!",
+            1,
+            ["[CLS]", "you", "look", "stunning", "!", "[SEP]"],
+            1,
+            None,
+            "text-classification",
+        ),
+        (
+            "A tennis game with two females playing.",
+            "contradiction",
+            ["[CLS]", "▁A", "▁tennis", "▁game", "▁with"],
+            2,
+            None,
+            "nli",
+        ),
+        (
+            "I am John and I live in New York",
+            "I-LOC",
+            ["[CLS]", "I", "am", "John", "and", "I", "live", "in", "New", "York", "[SEP]"],
+            6,
+            "York",
+            "ner",
+        ),
+        (
+            "The weather in London sucks",
+            "entailment",
+            ["[CLS]", "▁The", "▁weather", "▁in", "▁London", "▁", "suck", "s", "[SEP]", "▁This", "▁is", "▁weather", "▁", "complaint", "[SEP]"],
+            0,
+            None,
+            "zero-shot-text-classification",
+        ),
+    ],
+    ids=["textclass_example", "nli_example", "ner_example", "zero_shot_example"],
+)
+def test_explainers(
+    explainer,
+    model_and_tokenizer,
+    text,
+    target,
+    expected_tokens,
+    expected_target_pos_idx,
+    target_token,
+    task,
+    cache,
+):
+    _, _, model_task = model_and_tokenizer
+    if model_task != task:
+        pytest.skip(f"Skipping {model_task} as it does not match the task {task}")
+
+    if task == "zero-shot-text-classification":
+        scores = cache.get("zero-shot-score", {})
+        target_option = max(scores, key=scores.get) if scores else None
+        sep_token = '[SEP]'
+        text = [f"{text} {sep_token} This is {target_option}"]
+    else:
+        target_option = None
+
+    # `target_token` is for NER and `target_option` is for zero-shot
+    # (remember that in SHAP it is ignored).
+    explanation = (
+        explainer(
+            text, target=target, target_token=target_token, target_option=target_option
+        )
+        if isinstance(explainer, (SHAPExplainer, LIMEExplainer))
+        else explainer(text, target=target, target_token=target_token)
+    )
+    assert explanation.tokens[: len(expected_tokens)] == expected_tokens
+    assert explanation.target_pos_idx == expected_target_pos_idx