From 9dc7be2b80cdb8dfb29d729d8ef7fc5ab32a6baa Mon Sep 17 00:00:00 2001 From: elronbandel Date: Tue, 1 Apr 2025 14:38:43 +0300 Subject: [PATCH 01/36] Add audio support Signed-off-by: elronbandel --- prepare/cards/minds_14.py | 75 +++++++++++ prepare/tasks/classification.py | 2 +- src/unitxt/audio_operators.py | 38 ++++++ src/unitxt/catalog/cards/minds_14.json | 126 ++++++++++++++++++ .../tasks/classification/multi_class.json | 2 +- src/unitxt/formats.py | 73 ++++++---- src/unitxt/image_operators.py | 2 + src/unitxt/schema.py | 3 + src/unitxt/serializers.py | 15 +++ src/unitxt/settings_utils.py | 1 + src/unitxt/templates.py | 2 + 11 files changed, 308 insertions(+), 31 deletions(-) create mode 100644 prepare/cards/minds_14.py create mode 100644 src/unitxt/audio_operators.py create mode 100644 src/unitxt/catalog/cards/minds_14.json diff --git a/prepare/cards/minds_14.py b/prepare/cards/minds_14.py new file mode 100644 index 0000000000..0b7b8c7794 --- /dev/null +++ b/prepare/cards/minds_14.py @@ -0,0 +1,75 @@ +from unitxt.audio_operators import ToAudio +from unitxt.blocks import LoadHF, Set, TaskCard +from unitxt.catalog import add_to_catalog +from unitxt.operators import MapInstanceValues, Rename +from unitxt.splitters import SplitRandomMix +from unitxt.test_utils.card import test_card + +classes = [ + "abroad", + "address", + "app_error", + "atm_limit", + "balance", + "business_loan", + "card_issues", + "cash_deposit", + "direct_debit", + "freeze", + "high_value_payment", + "joint_account", + "latest_transactions", + "pay_bill", +] + +card = TaskCard( + loader=LoadHF(path="PolyAI/minds14", name="en-US"), + preprocess_steps=[ + SplitRandomMix( + {"train": "train[90%]", "validation": "train[5%]", "test": "train[5%]"} + ), + MapInstanceValues(mappers={"intent_class": {str(i): label for i, label in enumerate(classes)}}), + Rename(field="intent_class", to_field="label"), + Set( + fields={ + "text_type": "sentence", + "type_of_class": "intent", + "classes": classes + } + ), + ToAudio(field="audio", to_field="text"), + + ], + task="tasks.classification.multi_class", + templates="templates.classification.multi_class.all", + __tags__={ + "annotations_creators": [ + "expert-generated", + "crowdsourced", + "machine-generated" + ], + "language_creators": [ + "crowdsourced", + "expert-generated" + ], + "language": [ + "en", "fr", "it", "es", "pt", "de", "nl", "ru", "pl", "cs", "ko", "zh" + ], + "license": "cc-by-4.0", + "multilinguality": "multilingual", + "size_categories": "10K Audio: + return { + "audio": value, + } + + +def audio_to_base64(audio_data): + """Convert a HuggingFace Audio instance to a base64-encoded WAV string. 
+ + Args: + audio_data (dict): The Audio instance from HuggingFace datasets + Contains 'array', 'sampling_rate', and 'path' keys + + Returns: + str: Base64-encoded WAV audio + """ + import base64 + from io import BytesIO + + import soundfile as sf + # Create a BytesIO buffer to hold the WAV data + buffer = BytesIO() + + # Write the audio array to the buffer in WAV format + sf.write(buffer, audio_data["array"], audio_data["sampling_rate"], format="wav") + + # Get the bytes from the buffer + wav_bytes = buffer.getvalue() + + # Encode to base64 + return base64.b64encode(wav_bytes).decode("utf-8") diff --git a/src/unitxt/catalog/cards/minds_14.json b/src/unitxt/catalog/cards/minds_14.json new file mode 100644 index 0000000000..73a6a871d0 --- /dev/null +++ b/src/unitxt/catalog/cards/minds_14.json @@ -0,0 +1,126 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_hf", + "path": "PolyAI/minds14", + "name": "en-US" + }, + "preprocess_steps": [ + { + "__type__": "split_random_mix", + "mix": { + "train": "train[90%]", + "validation": "train[5%]", + "test": "train[5%]" + } + }, + { + "__type__": "map_instance_values", + "mappers": { + "intent_class": { + "0": "abroad", + "1": "address", + "2": "app_error", + "3": "atm_limit", + "4": "balance", + "5": "business_loan", + "6": "card_issues", + "7": "cash_deposit", + "8": "direct_debit", + "9": "freeze", + "10": "high_value_payment", + "11": "joint_account", + "12": "latest_transactions", + "13": "pay_bill" + } + } + }, + { + "__type__": "rename", + "field": "intent_class", + "to_field": "label" + }, + { + "__type__": "set", + "fields": { + "text_type": "sentence", + "type_of_class": "intent", + "classes": [ + "abroad", + "address", + "app_error", + "atm_limit", + "balance", + "business_loan", + "card_issues", + "cash_deposit", + "direct_debit", + "freeze", + "high_value_payment", + "joint_account", + "latest_transactions", + "pay_bill" + ] + } + }, + { + "__type__": "to_audio", + "field": "audio", + "to_field": "text" + } + ], + "task": "tasks.classification.multi_class", + "templates": "templates.classification.multi_class.all", + "__tags__": { + "annotations_creators": [ + "expert-generated", + "crowdsourced", + "machine-generated" + ], + "language_creators": [ + "crowdsourced", + "expert-generated" + ], + "language": [ + "en", + "fr", + "it", + "es", + "pt", + "de", + "nl", + "ru", + "pl", + "cs", + "ko", + "zh" + ], + "license": "cc-by-4.0", + "multilinguality": "multilingual", + "size_categories": "10K Union[str, List[Content]]: - # Regular expression to find tags with src attribute - img_tag_pattern = re.compile( - r"<" + f"{constants.image_tag}" + r'\s+[^>]*src=["\']([^"\']+)["\'][^>]*>', + image_tag = constants.image_tag + audio_tag = constants.audio_tag + + # Unified regex for both tags + tag_pattern = re.compile( + rf"<(?P{re.escape(image_tag)}|{re.escape(audio_tag)})\s+[^>]*src=[\"'](?P[^\"']+)[\"'][^>]*>", re.IGNORECASE, ) - # Find all matches of tags and their positions - matches = list(img_tag_pattern.finditer(text)) + matches = list(tag_pattern.finditer(text)) - # If no images are found, return the text as a plain string if not matches: return text contents: List[dict] = [] last_pos = 0 - # Process each match for match in matches: start, end = match.span() - img_url = match.group(1) + tag = match.group("tag").lower() + src = match.group("src") - # Add preceding text, if any + # Add preceding text if last_pos < start: contents.append({"type": "text", "text": text[last_pos:start]}) - # Add image content with a default detail 
level - if img_url.startswith("media/"): - image = dict_get(media, img_url[6:]) - data_url = image_to_data_url(image) - contents.append( - { - "type": "image_url", - "image_url": {"url": data_url, "detail": "low"}, - } - ) - else: - contents.append( - { - "type": "image_url", - "image_url": {"url": img_url, "detail": "low"}, - } - ) + is_local = src.startswith("media/") + media_key = src[6:] if is_local else None + + if tag == image_tag: + if is_local: + image = dict_get(media, media_key) + data_url = image_to_data_url(image) + else: + data_url = src + contents.append({ + "type": "image_url", + "image_url": {"url": data_url, "detail": "low"}, + }) + + elif tag == audio_tag: + if is_local: + audio = dict_get(media, media_key) + data_url = audio_to_base64(audio) + else: + data_url = src + contents.append({ + "type": "input_audio", + "input_audio": {"data": data_url, "format": "wav"}, + }) - # Update the last processed position last_pos = end - # Add any remaining text after the last image + # Add any trailing text if last_pos < len(text): contents.append({"type": "text", "text": text[last_pos:]}) @@ -466,6 +480,7 @@ def _format_instance_to_source( media, ) media["images"] = [] + media["audios"] = [] return chat diff --git a/src/unitxt/image_operators.py b/src/unitxt/image_operators.py index 53635a637c..2e8411a1f8 100644 --- a/src/unitxt/image_operators.py +++ b/src/unitxt/image_operators.py @@ -128,6 +128,8 @@ def process_instance_value(self, value: Any, instance: Dict[str, Any]) -> Image: } + + class ImageFieldOperator(FieldOperator, PillowMixin): @abstractmethod def process_image(self, image: Any): diff --git a/src/unitxt/schema.py b/src/unitxt/schema.py index 57a3e9b296..438a711051 100644 --- a/src/unitxt/schema.py +++ b/src/unitxt/schema.py @@ -128,6 +128,9 @@ def _prepare_media(self, instance): if isoftype(instance["media"]["images"][i], Image): instance["media"]["images"][i] = instance["media"]["images"][i]["image"] + for i in range(len(instance["media"]["audios"])): + instance["media"]["audios"][i] = instance["media"]["audios"][i]["audio"] + return instance def _get_instance_task_data( diff --git a/src/unitxt/serializers.py b/src/unitxt/serializers.py index 6d43ac9de0..be4e733b85 100644 --- a/src/unitxt/serializers.py +++ b/src/unitxt/serializers.py @@ -9,6 +9,7 @@ from .settings_utils import get_constants from .type_utils import isoftype, to_type_string from .types import ( + Audio, Dialog, Document, Image, @@ -132,6 +133,20 @@ def serialize(self, value: Image, instance: Dict[str, Any]) -> str: value["image"] = f"media/images/{idx}" return f'<{constants.image_tag} src="media/images/{idx}">' +class AudioSerializer(SingleTypeSerializer): + serialized_type = Audio + + def serialize(self, value: Audio, instance: Dict[str, Any]) -> str: + if "media" not in instance: + instance["media"] = {} + if "audios" not in instance["media"]: + instance["media"]["audios"] = [] + idx = len(instance["media"]["audios"]) + instance["media"]["audios"].append( + {"audio": value["audio"]} + ) + value["audio"] = f"media/audios/{idx}" + return f'<{constants.audio_tag} src="media/audios/{idx}">' class VideoSerializer(ImageSerializer): serialized_type = Video diff --git a/src/unitxt/settings_utils.py b/src/unitxt/settings_utils.py index 2bdc98011b..e0ded5ad27 100644 --- a/src/unitxt/settings_utils.py +++ b/src/unitxt/settings_utils.py @@ -192,6 +192,7 @@ def __getattr__(self, key): constants.inference_stream = "__INFERENCE_STREAM__" constants.instance_stream = "__INSTANCE_STREAM__" constants.image_tag = 
"unitxt-img" + constants.audio_tag = "unitxt-audio" constants.demos_pool_field = "_demos_pool_" constants.demos_field = "demos" constants.instruction_field = "instruction" diff --git a/src/unitxt/templates.py b/src/unitxt/templates.py index f124e0f386..4d5a3f0def 100644 --- a/src/unitxt/templates.py +++ b/src/unitxt/templates.py @@ -11,6 +11,7 @@ from .operator import InstanceOperator, Operator from .random_utils import new_random_generator from .serializers import ( + AudioSerializer, DialogSerializer, ImageSerializer, ListSerializer, @@ -60,6 +61,7 @@ class Template(InstanceOperator): serializer: Serializer = NonPositionalField( default_factory=lambda: MultiTypeSerializer( serializers=[ + AudioSerializer(), ImageSerializer(), VideoSerializer(), TableSerializer(), From b5c64c34e3c68bd9f3f6c7cb5f234f3d9c87baa8 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Mon, 21 Jul 2025 22:20:11 +0300 Subject: [PATCH 02/36] Update speech benchmark Signed-off-by: elronbandel --- examples/evaluate_speech_recognition.py | 26 +++ .../evaluate_speech_recognition_benchmark.py | 28 ++++ prepare/benchmarks/speech_recognition.py | 50 ++++++ prepare/cards/ami.py | 41 +++++ prepare/cards/gigaspeech.py | 41 +++++ prepare/cards/librispeech.py | 76 +++++++++ prepare/cards/spgispeech.py | 42 +++++ prepare/cards/tedlium.py | 41 +++++ prepare/cards/voxpopuli.py | 43 +++++ prepare/metrics/wer.py | 57 ++++++- prepare/processors/processors.py | 7 + prepare/tasks/speech_recognition.py | 22 +++ .../templates/speech_recognition/templates.py | 12 ++ src/unitxt/audio_operators.py | 22 ++- .../benchmarks/speech_recognition.json | 55 ++++++ src/unitxt/catalog/cards/ami/ihm.json | 45 +++++ src/unitxt/catalog/cards/ami/sdm.json | 45 +++++ src/unitxt/catalog/cards/gigaspeech/xs.json | 45 +++++ .../catalog/cards/librispeech/test.json | 43 +++++ .../catalog/cards/librispeech/test_clean.json | 43 +++++ src/unitxt/catalog/cards/spgispeech/s.json | 50 ++++++ .../catalog/cards/tedlium/release1.json | 45 +++++ .../catalog/cards/tedlium/release2.json | 45 +++++ .../catalog/cards/tedlium/release3.json | 45 +++++ src/unitxt/catalog/cards/voxpopuli/en.json | 51 ++++++ src/unitxt/catalog/metrics/wer.json | 2 +- .../normalize_text_with_whisper.json | 3 + .../normalize_text_with_whisper.json | 8 + .../catalog/tasks/speech_recognition.json | 21 +++ .../templates/speech_recognition/default.json | 8 + src/unitxt/formats.py | 4 +- src/unitxt/inference.py | 156 +++++++++++++++++- src/unitxt/metrics.py | 68 ++++++++ src/unitxt/processors.py | 16 ++ src/unitxt/serializers.py | 1 + tests/inference/test_inference_engine.py | 65 ++++++++ 36 files changed, 1355 insertions(+), 17 deletions(-) create mode 100644 examples/evaluate_speech_recognition.py create mode 100644 examples/evaluate_speech_recognition_benchmark.py create mode 100644 prepare/benchmarks/speech_recognition.py create mode 100644 prepare/cards/ami.py create mode 100644 prepare/cards/gigaspeech.py create mode 100644 prepare/cards/librispeech.py create mode 100644 prepare/cards/spgispeech.py create mode 100644 prepare/cards/tedlium.py create mode 100644 prepare/cards/voxpopuli.py create mode 100644 prepare/tasks/speech_recognition.py create mode 100644 prepare/templates/speech_recognition/templates.py create mode 100644 src/unitxt/catalog/benchmarks/speech_recognition.json create mode 100644 src/unitxt/catalog/cards/ami/ihm.json create mode 100644 src/unitxt/catalog/cards/ami/sdm.json create mode 100644 src/unitxt/catalog/cards/gigaspeech/xs.json create mode 100644 
src/unitxt/catalog/cards/librispeech/test.json create mode 100644 src/unitxt/catalog/cards/librispeech/test_clean.json create mode 100644 src/unitxt/catalog/cards/spgispeech/s.json create mode 100644 src/unitxt/catalog/cards/tedlium/release1.json create mode 100644 src/unitxt/catalog/cards/tedlium/release2.json create mode 100644 src/unitxt/catalog/cards/tedlium/release3.json create mode 100644 src/unitxt/catalog/cards/voxpopuli/en.json create mode 100644 src/unitxt/catalog/operators/normalize_text_with_whisper.json create mode 100644 src/unitxt/catalog/processors/normalize_text_with_whisper.json create mode 100644 src/unitxt/catalog/tasks/speech_recognition.json create mode 100644 src/unitxt/catalog/templates/speech_recognition/default.json diff --git a/examples/evaluate_speech_recognition.py b/examples/evaluate_speech_recognition.py new file mode 100644 index 0000000000..9c648214c2 --- /dev/null +++ b/examples/evaluate_speech_recognition.py @@ -0,0 +1,26 @@ +from unitxt import evaluate, load_dataset +from unitxt.inference import ( + HFGraniteSpeechInferenceEngine, +) +from unitxt.system_prompts import TextualSystemPrompt + +test_dataset = load_dataset( + card="cards.ami", + split="test", + format="formats.chat_api", + max_test_instances=10, + system_prompt=TextualSystemPrompt( + text="Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant" + ), +) + +model = HFGraniteSpeechInferenceEngine( + model_name="ibm-granite/granite-speech-3.3-2b", + max_new_tokens=200, +) + +predictions = model(test_dataset) +results = evaluate(predictions=predictions, data=test_dataset) + +print("Global scores:") +print(results.global_scores.summary) diff --git a/examples/evaluate_speech_recognition_benchmark.py b/examples/evaluate_speech_recognition_benchmark.py new file mode 100644 index 0000000000..6bc6b28507 --- /dev/null +++ b/examples/evaluate_speech_recognition_benchmark.py @@ -0,0 +1,28 @@ +from unitxt import evaluate, load_dataset +from unitxt.inference import ( + HFGraniteSpeechInferenceEngine, +) +from unitxt.system_prompts import TextualSystemPrompt + +dataset = load_dataset( + "benchmarks.speech_recognition", + max_samples_per_subset=5, + system_prompt=TextualSystemPrompt( + text="Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. 
You are a helpful AI assistant" + ), + split="test", +) + +model = HFGraniteSpeechInferenceEngine( + model_name="ibm-granite/granite-speech-3.3-2b", + max_new_tokens=200, +) + +predictions = model(dataset) +results = evaluate(predictions=predictions, data=dataset) + +print("Global scores:") +print(results.global_scores.summary) + +print("Subsets scores:") +print(results.subsets_scores.summary) diff --git a/prepare/benchmarks/speech_recognition.py b/prepare/benchmarks/speech_recognition.py new file mode 100644 index 0000000000..5c4f746098 --- /dev/null +++ b/prepare/benchmarks/speech_recognition.py @@ -0,0 +1,50 @@ +from unitxt.benchmark import Benchmark +from unitxt.catalog import add_to_catalog +from unitxt.standard import DatasetRecipe + +benchmark = Benchmark( + subsets={ + "voxpopuli_en": DatasetRecipe( + card="cards.voxpopuli.en", + format="formats.chat_api", + ), + "ami_ihm": DatasetRecipe( + card="cards.ami.ihm", + format="formats.chat_api", + ), + "ami_sdm": DatasetRecipe( + card="cards.ami.sdm", + format="formats.chat_api", + ), + "gigaspeech_xs": DatasetRecipe( + card="cards.gigaspeech.xs", + format="formats.chat_api", + ), + "librispeech_test_clean": DatasetRecipe( + card="cards.librispeech.test_clean", + format="formats.chat_api", + ), + "librispeech_test": DatasetRecipe( + card="cards.librispeech.test", + format="formats.chat_api", + ), + "spgispeech_s": DatasetRecipe( + card="cards.spgispeech.s", + format="formats.chat_api", + ), + "tedlium_release1": DatasetRecipe( + card="cards.tedlium.release1", + format="formats.chat_api", + ), + "tedlium_release2": DatasetRecipe( + card="cards.tedlium.release2", + format="formats.chat_api", + ), + "tedlium_release3": DatasetRecipe( + card="cards.tedlium.release3", + format="formats.chat_api", + ), + }, +) + +add_to_catalog(benchmark, "benchmarks.speech_recognition", overwrite=True) diff --git a/prepare/cards/ami.py b/prepare/cards/ami.py new file mode 100644 index 0000000000..0f6102dbca --- /dev/null +++ b/prepare/cards/ami.py @@ -0,0 +1,41 @@ +from unitxt.audio_operators import ToAudio +from unitxt.blocks import LoadHF, TaskCard +from unitxt.catalog import add_to_catalog +from unitxt.test_utils.card import test_card + +for subset in ["ihm", "sdm"]: + card = TaskCard( + loader=LoadHF( + path="edinburghcstr/ami", + data_dir=subset, + revision="refs/convert/parquet", + splits=["train", "validation", "test"], + data_classification_policy=["public"], + streaming=True, + ), + preprocess_steps=[ + ToAudio(field="audio"), + ], + task="tasks.speech_recognition", + templates=[ + "templates.speech_recognition.default", + ], + __tags__={ + "license": "cc-by-4.0", + "language": "en", + "task_categories": ["automatic-speech-recognition"], + "size_categories": ["10K Audio: return { "audio": value, } -def audio_to_base64(audio_data): +def audio_to_base64(audio_data: Audio): """Convert a HuggingFace Audio instance to a base64-encoded WAV string. 
Args: @@ -25,14 +28,25 @@ def audio_to_base64(audio_data): from io import BytesIO import soundfile as sf + # Create a BytesIO buffer to hold the WAV data buffer = BytesIO() - + audio = audio_data["audio"] # Write the audio array to the buffer in WAV format - sf.write(buffer, audio_data["array"], audio_data["sampling_rate"], format="wav") + sf.write(buffer, audio["array"], audio["sampling_rate"], format="wav") # Get the bytes from the buffer wav_bytes = buffer.getvalue() # Encode to base64 return base64.b64encode(wav_bytes).decode("utf-8") + + +def base64_to_audio(base64_string, sampling_rate: Optional[int] = None): + import base64 + + from datasets import Audio + + audio_bytes = base64.b64decode(base64_string) + audio_feature = Audio(sampling_rate=sampling_rate) + return audio_feature.decode_example({"bytes": audio_bytes, "path": None}) diff --git a/src/unitxt/catalog/benchmarks/speech_recognition.json b/src/unitxt/catalog/benchmarks/speech_recognition.json new file mode 100644 index 0000000000..c7e8ae5050 --- /dev/null +++ b/src/unitxt/catalog/benchmarks/speech_recognition.json @@ -0,0 +1,55 @@ +{ + "__type__": "benchmark", + "subsets": { + "voxpopuli_en": { + "__type__": "dataset_recipe", + "card": "cards.voxpopuli.en", + "format": "formats.chat_api" + }, + "ami_ihm": { + "__type__": "dataset_recipe", + "card": "cards.ami.ihm", + "format": "formats.chat_api" + }, + "ami_sdm": { + "__type__": "dataset_recipe", + "card": "cards.ami.sdm", + "format": "formats.chat_api" + }, + "gigaspeech_xs": { + "__type__": "dataset_recipe", + "card": "cards.gigaspeech.xs", + "format": "formats.chat_api" + }, + "librispeech_test_clean": { + "__type__": "dataset_recipe", + "card": "cards.librispeech.test_clean", + "format": "formats.chat_api" + }, + "librispeech_test": { + "__type__": "dataset_recipe", + "card": "cards.librispeech.test", + "format": "formats.chat_api" + }, + "spgispeech_s": { + "__type__": "dataset_recipe", + "card": "cards.spgispeech.s", + "format": "formats.chat_api" + }, + "tedlium_release1": { + "__type__": "dataset_recipe", + "card": "cards.tedlium.release1", + "format": "formats.chat_api" + }, + "tedlium_release2": { + "__type__": "dataset_recipe", + "card": "cards.tedlium.release2", + "format": "formats.chat_api" + }, + "tedlium_release3": { + "__type__": "dataset_recipe", + "card": "cards.tedlium.release3", + "format": "formats.chat_api" + } + } +} diff --git a/src/unitxt/catalog/cards/ami/ihm.json b/src/unitxt/catalog/cards/ami/ihm.json new file mode 100644 index 0000000000..6afd19f0a7 --- /dev/null +++ b/src/unitxt/catalog/cards/ami/ihm.json @@ -0,0 +1,45 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_hf", + "path": "edinburghcstr/ami", + "data_dir": "ihm", + "revision": "refs/convert/parquet", + "splits": [ + "train", + "validation", + "test" + ], + "data_classification_policy": [ + "public" + ], + "streaming": true + }, + "preprocess_steps": [ + { + "__type__": "to_audio", + "field": "audio" + } + ], + "task": "tasks.speech_recognition", + "templates": [ + "templates.speech_recognition.default" + ], + "__tags__": { + "license": "cc-by-4.0", + "language": "en", + "task_categories": [ + "automatic-speech-recognition" + ], + "size_categories": [ + "10K Union[str, List[Conten data_url = src contents.append( { - "type": "input_audio", - "input_audio": {"data": data_url, "format": "wav"}, + "type": "audio", + "audio": {"data": data_url, "mime_type": "audio/wav"}, } ) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index dfc8b1c15d..27d88d676b 100644 
--- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -36,6 +36,7 @@ from tqdm.asyncio import tqdm_asyncio from .artifact import Artifact +from .audio_operators import base64_to_audio from .base_metric import Metric from .dataclass import InternalField, NonPositionalField from .deprecation_utils import deprecation @@ -620,7 +621,7 @@ def get_logprobs( "logprob": float(score.cpu()), "top_tokens": [ { - "text": self.processor.decode(idx), + "text": self.processor.tokenizer.decode(idx), "logprob": float( predictions.scores[n][sample_no][idx].cpu() ), @@ -1042,6 +1043,159 @@ def _infer_log_probs( return self._infer_fn(dataset, return_meta_data, True) +class HFGraniteSpeechInferenceEngine(HFInferenceEngineBase): + lazy_load: bool = True + label: str = "hf_granite_speech" + audio_token: str = "<|audio|>" + sampling_rate: int = 16000 + + _requirements_list = ["torchaudio"] + + def compute_transition_scores( + self, sequences: Sequence, scores: Sequence, beam_indices: Optional[int] + ) -> Sequence: + if not hasattr(self.model.config, "vocab_size"): + try: + self.model.config.vocab_size = self.model.vocab_size + except: + self.model.config.vocab_size = self.model.config.text_config.vocab_size + + return super().compute_transition_scores(sequences, scores, beam_indices) + + def _init_processor(self): + from transformers import AutoProcessor + + self.processor = AutoProcessor.from_pretrained(self.model_name) + + if not self.pad_token_id and hasattr(self.processor, "eos_token_id"): + self.pad_token_id = self.processor.eos_token_id + + def _init_model(self): + from transformers import AutoModelForSpeechSeq2Seq + + self.model = AutoModelForSpeechSeq2Seq.from_pretrained( + self.model_name, + torch_dtype=self._get_torch_dtype(), + low_cpu_mem_usage=self.low_cpu_mem_usage, + device_map=self.device_map, + ) + if self.device_map is None: + self.model.to(self.device) + + def _get_input(self, instance): + if isinstance(instance["source"], list): + # Chat API format - extract audio from content + audios = [] + chat = [] + for turn in instance["source"]: + if isinstance(turn["content"], list): + turn_content = "" + for content in turn["content"]: + if content["type"] == "audio": + audios.append( + base64_to_audio( + content["audio"]["data"], + sampling_rate=self.sampling_rate, + ) + ) + turn_content += self.audio_token + elif content["type"] == "text": + turn_content += content["text"] + else: + raise ValueError( + f"Unsupported content type:{content['type']}" + ) + turn = {"role": turn["role"], "content": turn_content} + + chat.append(turn) + + if len(audios) > 1: + raise ValueError(f"Unsupported number of audio contents:{len(audios)}") + + audio = audios[0] + + return chat, audio + raise ValueError("Supports only chat api.") + + def prepare_inputs(self, instance) -> Mapping: + chat, audio = self._get_input(instance) + + text = self.processor.tokenizer.apply_chat_template( + chat, tokenize=False, add_generation_prompt=True + ) + + inputs: Mapping = self.processor( + text=[text], audio=audio["array"], return_tensors="pt" + ).to(self.device or self.device_map, self._get_torch_dtype()) + + return inputs + + def _infer_fn( + self, + dataset: Union[List[Dict[str, Any]], Dataset], + return_meta_data: bool, + return_logprobs: bool, + ) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]: + results = [] + + for instance in tqdm(dataset): + processed_inputs = self.prepare_inputs(instance) + input_len = len(processed_inputs["input_ids"][0]) + + predictions = 
self.make_predictions(processed_inputs)
+
+            sequences = predictions.sequences
+
+            output_tokens = sequences[:, input_len:]
+
+            output_tokens_strings = []
+            for tokens in output_tokens:
+                output_tokens_strings.append(
+                    [
+                        self.processor.tokenizer.decode(token, skip_special_tokens=True)
+                        for token in tokens
+                    ]
+                )
+
+            output_strings = []
+            for tokens in output_tokens:
+                output_strings.append(
+                    self.processor.tokenizer.decode(tokens, skip_special_tokens=True)
+                )
+
+            if return_logprobs:
+                final_outputs = self.get_logprobs(predictions, output_tokens_strings)
+            else:
+                final_outputs = output_strings
+
+            results.append(
+                self.get_return_object(
+                    output=final_outputs[0],
+                    generated_text=output_strings[0],
+                    output_tokens=len(output_tokens_strings[0]),
+                    inp=instance["source"],
+                    inp_tokens=None,
+                    return_meta_data=return_meta_data,
+                )
+            )
+
+        return results
+
+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], Dataset],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
+        return self._infer_fn(dataset, return_meta_data, False)
+
+    def _infer_log_probs(
+        self,
+        dataset: Union[List[Dict[str, Any]], Dataset],
+        return_meta_data: bool = False,
+    ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
+        return self._infer_fn(dataset, return_meta_data, True)
+
+
 class HFPeftInferenceEngine(HFAutoModelInferenceEngine):
     label: str = "hf_peft_auto_model"
 
diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py
index 9460d96ca8..4302836941 100644
--- a/src/unitxt/metrics.py
+++ b/src/unitxt/metrics.py
@@ -3387,6 +3387,74 @@ def compute(
         return {self.main_score: result}
 
 
+class WerFast(MapReduceMetric[Tuple[float, float], float]):
+    """Computes Word Error Rate (WER) between predictions and references.
+
+    Range: [0, ∞) (lower is better)
+    Counts word substitutions, deletions, and insertions relative to the reference transcript.
+    """
+
+    main_score = "wer"
+    prediction_type = str
+    single_reference_per_prediction = True
+    _requirements_list = ["jiwer>=3.0.0"]  # added process_words function
+
+    def prepare(self):
+        super().prepare()
+        import jiwer
+
+        self._metric = jiwer.process_words
+
+    def map(
+        self, prediction: str, references: List[str], task_data: Dict[str, Any]
+    ) -> Tuple[float, float]:
+        measures = self._metric(references[0], prediction)
+        incorrect = measures.substitutions + measures.deletions + measures.insertions
+        total = measures.substitutions + measures.deletions + measures.hits
+        return incorrect, total
+
+    def reduce(self, intermediates: List[float]) -> Dict[str, Any]:
+        incorrect, total = map(sum, zip(*intermediates))
+        return {self.main_score: incorrect / total if total > 0 else np.nan}
+
+
+class NormalizedWer(MapReduceMetric[Tuple[float, float], float]):
+    """Computes Word Error Rate (WER) after Whisper English text normalization.
+
+    Range: [0, ∞) (lower is better)
+    Normalizes both prediction and reference with the Whisper normalizer before computing WER.
+ """ + + main_score = "normalized_wer" + prediction_type = str + single_reference_per_prediction = True + _requirements_list = ["jiwer>=3.0.0"] # added process_words function + + def prepare(self): + super().prepare() + import jiwer + from transformers import WhisperTokenizer + + self.tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-base") + + self._metric = jiwer.process_words + self._normalize = self.tokenizer.normalize + + def map( + self, prediction: str, references: List[str], task_data: Dict[str, Any] + ) -> Tuple[float, float]: + normalized_reference = self._normalize(references[0]) + normalized_prediction = self._normalize(prediction) + measures = self._metric(normalized_reference, normalized_prediction) + incorrect = measures.substitutions + measures.deletions + measures.insertions + total = measures.substitutions + measures.deletions + measures.hits + return incorrect, total + + def reduce(self, intermediates: List[float]) -> Dict[str, Any]: + incorrect, total = map(sum, zip(*intermediates)) + return {self.main_score: incorrect / total} + + class MeanSquaredError(MapReduceMetric[float, float]): """Computes mean squared error between predictions and references. diff --git a/src/unitxt/processors.py b/src/unitxt/processors.py index 6f13e10a33..ef92a9d778 100644 --- a/src/unitxt/processors.py +++ b/src/unitxt/processors.py @@ -566,3 +566,19 @@ def process_value(self, text: Any) -> Any: class ExtractVerbalJudgementBadGood(ExtractVerbalJudgment): classes = ["very bad", "bad", "mediocre", "good", "very good"] + + +class NormalizeTextWithWhisper(FieldOperator): + """A processor that uses uses whisper english normalizer.""" + + _requirements_list = ["transformers"] + + def prepare(self): + super().prepare() + from transformers import WhisperTokenizer + + self.tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-base") + self._normalize = self.tokenizer.normalize + + def process_value(self, value: str) -> str: + return self._normalize(value) diff --git a/src/unitxt/serializers.py b/src/unitxt/serializers.py index 56edbdf531..0f5e0c3f05 100644 --- a/src/unitxt/serializers.py +++ b/src/unitxt/serializers.py @@ -220,6 +220,7 @@ class MultiTypeSerializer(Serializer): ToolCallSerializer(), DialogSerializer(), MultiDocumentSerializer(), + AudioSerializer(), ImageSerializer(), VideoSerializer(), TableSerializer(), diff --git a/tests/inference/test_inference_engine.py b/tests/inference/test_inference_engine.py index 0c33769f17..f4512b5a88 100644 --- a/tests/inference/test_inference_engine.py +++ b/tests/inference/test_inference_engine.py @@ -11,6 +11,7 @@ from unitxt.error_utils import UnitxtError from unitxt.inference import ( HFAutoModelInferenceEngine, + HFGraniteSpeechInferenceEngine, HFLlavaInferenceEngine, HFOptionSelectingInferenceEngine, HFPipelineBasedInferenceEngine, @@ -67,6 +68,44 @@ def get_image_dataset(format=None): ) +@lru_cache +def get_audio_dataset(format=None): + import numpy as np + + # Generate synthetic audio data (1 second of 16kHz audio) + sample_rate = 16000 + duration = 1.0 + num_samples = int(sample_rate * duration) + + # Generate synthetic audio (simple sine wave) + frequency = 440 # A4 note + time_values = np.linspace(0, duration, num_samples) + audio_data = np.sin(2 * np.pi * frequency * time_values).astype(np.float32) + + data = [ + { + "context": {"audio": {"array": audio_data, "sampling_rate": sample_rate}}, + "context_type": "audio", + "question": "What is the main topic of this audio?", + "answers": ["Music"], + }, + { + "context": 
{"audio": {"array": audio_data, "sampling_rate": sample_rate}}, + "context_type": "audio", + "question": "Describe the audio content", + "answers": ["Tone"], + }, + ] + + return create_dataset( + task="tasks.qa.with_context", + format=format, + test_set=data, + split="test", + data_classification_policy=["public"], + ) + + @lru_cache def get_text_dataset(format=None): instances = [ @@ -157,6 +196,32 @@ def test_llava_inference_engine(self): ["text", "logprob", "top_tokens"], ) + def test_granite_speech_inference_engine(self): + model = HFGraniteSpeechInferenceEngine( + model_name="ibm-granite/granite-speech-3.3-2b", + max_new_tokens=10, + temperature=0.0, + ) + + # Test with chat API format + dataset = get_audio_dataset(format="formats.chat_api") + + predictions = model.infer(dataset) + + # Check that we get predictions for both instances + self.assertEqual(len(predictions), 2) + self.assertIsInstance(predictions[0], str) + self.assertIsInstance(predictions[1], str) + + # Test log probabilities inference + prediction = model.infer_log_probs(dataset) + + assert isoftype(prediction, List[List[Dict[str, Any]]]) + self.assertListEqual( + list(prediction[0][0].keys()), + ["text", "logprob", "top_tokens"], + ) + def test_watsonx_inference(self): model = WMLInferenceEngineGeneration( model_name="google/flan-t5-xl", From 628d8ff65a6c5a32cddd36af5b8670b759675b13 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Tue, 22 Jul 2025 14:48:07 +0300 Subject: [PATCH 03/36] Change benchmark dataset to esb and fix imports Signed-off-by: elronbandel --- prepare/benchmarks/speech_recognition.py | 40 ++++++--------- prepare/cards/ami.py | 3 +- prepare/cards/esb/ami.py | 40 +++++++++++++++ prepare/cards/esb/earnings22.py | 40 +++++++++++++++ prepare/cards/esb/gigaspeech.py | 40 +++++++++++++++ prepare/cards/esb/librispeech.py | 42 ++++++++++++++++ prepare/cards/esb/spgispeech.py | 40 +++++++++++++++ prepare/cards/esb/tedlium.py | 40 +++++++++++++++ prepare/cards/esb/voxpopuli.py | 40 +++++++++++++++ prepare/cards/gigaspeech.py | 3 +- prepare/cards/librispeech.py | 3 +- prepare/cards/spgispeech.py | 3 +- prepare/cards/tedlium.py | 3 +- prepare/cards/voxpopuli.py | 3 +- .../catalog/cards/debug/librispeech.json | 32 ++++++++++++ src/unitxt/catalog/cards/esb/ami.json | 43 ++++++++++++++++ src/unitxt/catalog/cards/esb/earnings22.json | 43 ++++++++++++++++ src/unitxt/catalog/cards/esb/gigaspeech.json | 43 ++++++++++++++++ src/unitxt/catalog/cards/esb/librispeech.json | 49 +++++++++++++++++++ src/unitxt/catalog/cards/esb/spgispeech.json | 43 ++++++++++++++++ src/unitxt/catalog/cards/esb/tedlium.json | 43 ++++++++++++++++ src/unitxt/catalog/cards/esb/voxpopuli.json | 43 ++++++++++++++++ 22 files changed, 647 insertions(+), 32 deletions(-) create mode 100644 prepare/cards/esb/ami.py create mode 100644 prepare/cards/esb/earnings22.py create mode 100644 prepare/cards/esb/gigaspeech.py create mode 100644 prepare/cards/esb/librispeech.py create mode 100644 prepare/cards/esb/spgispeech.py create mode 100644 prepare/cards/esb/tedlium.py create mode 100644 prepare/cards/esb/voxpopuli.py create mode 100644 src/unitxt/catalog/cards/debug/librispeech.json create mode 100644 src/unitxt/catalog/cards/esb/ami.json create mode 100644 src/unitxt/catalog/cards/esb/earnings22.json create mode 100644 src/unitxt/catalog/cards/esb/gigaspeech.json create mode 100644 src/unitxt/catalog/cards/esb/librispeech.json create mode 100644 src/unitxt/catalog/cards/esb/spgispeech.json create mode 100644 src/unitxt/catalog/cards/esb/tedlium.json create 
mode 100644 src/unitxt/catalog/cards/esb/voxpopuli.json diff --git a/prepare/benchmarks/speech_recognition.py b/prepare/benchmarks/speech_recognition.py index 5c4f746098..fc485b17bf 100644 --- a/prepare/benchmarks/speech_recognition.py +++ b/prepare/benchmarks/speech_recognition.py @@ -4,44 +4,32 @@ benchmark = Benchmark( subsets={ - "voxpopuli_en": DatasetRecipe( - card="cards.voxpopuli.en", + "voxpopuli": DatasetRecipe( + card="cards.esb.voxpopuli", format="formats.chat_api", ), - "ami_ihm": DatasetRecipe( - card="cards.ami.ihm", + "ami": DatasetRecipe( + card="cards.esb.ami", format="formats.chat_api", ), - "ami_sdm": DatasetRecipe( - card="cards.ami.sdm", + "gigaspeech": DatasetRecipe( + card="cards.esb.gigaspeech", format="formats.chat_api", ), - "gigaspeech_xs": DatasetRecipe( - card="cards.gigaspeech.xs", + "librispeech": DatasetRecipe( + card="cards.esb.librispeech", format="formats.chat_api", ), - "librispeech_test_clean": DatasetRecipe( - card="cards.librispeech.test_clean", + "spgispeech": DatasetRecipe( + card="cards.esb.spgispeech", format="formats.chat_api", ), - "librispeech_test": DatasetRecipe( - card="cards.librispeech.test", + "tedlium": DatasetRecipe( + card="cards.esb.tedlium", format="formats.chat_api", ), - "spgispeech_s": DatasetRecipe( - card="cards.spgispeech.s", - format="formats.chat_api", - ), - "tedlium_release1": DatasetRecipe( - card="cards.tedlium.release1", - format="formats.chat_api", - ), - "tedlium_release2": DatasetRecipe( - card="cards.tedlium.release2", - format="formats.chat_api", - ), - "tedlium_release3": DatasetRecipe( - card="cards.tedlium.release3", + "earnings22": DatasetRecipe( + card="cards.esb.earnings22", format="formats.chat_api", ), }, diff --git a/prepare/cards/ami.py b/prepare/cards/ami.py index 0f6102dbca..ada958eb3e 100644 --- a/prepare/cards/ami.py +++ b/prepare/cards/ami.py @@ -1,6 +1,7 @@ from unitxt.audio_operators import ToAudio -from unitxt.blocks import LoadHF, TaskCard +from unitxt.card import TaskCard from unitxt.catalog import add_to_catalog +from unitxt.loaders import LoadHF from unitxt.test_utils.card import test_card for subset in ["ihm", "sdm"]: diff --git a/prepare/cards/esb/ami.py b/prepare/cards/esb/ami.py new file mode 100644 index 0000000000..565bcea6be --- /dev/null +++ b/prepare/cards/esb/ami.py @@ -0,0 +1,40 @@ +from unitxt.audio_operators import ToAudio +from unitxt.card import TaskCard +from unitxt.catalog import add_to_catalog +from unitxt.loaders import LoadHF +from unitxt.test_utils.card import test_card + +card = TaskCard( + loader=LoadHF( + path="hf-audio/esb-datasets-test-only-sorted", + name="ami", + splits=["test"], + data_classification_policy=["public"], + streaming=True, + ), + preprocess_steps=[ + ToAudio(field="audio"), + ], + task="tasks.speech_recognition", + templates=[ + "templates.speech_recognition.default", + ], + __tags__={ + "license": "cc-by-4.0", + "language": "en", + "task_categories": ["automatic-speech-recognition"], + "size_categories": ["10K Date: Thu, 24 Jul 2025 12:33:12 +0300 Subject: [PATCH 04/36] Some updates Signed-off-by: elronbandel --- examples/evaluate_speech_recognition.py | 2 +- .../evaluate_speech_recognition_benchmark.py | 1 + .../benchmarks/speech_recognition.json | 43 ++++++------------- src/unitxt/metric_utils.py | 2 +- 4 files changed, 17 insertions(+), 31 deletions(-) diff --git a/examples/evaluate_speech_recognition.py b/examples/evaluate_speech_recognition.py index 9c648214c2..a4a3905abe 100644 --- a/examples/evaluate_speech_recognition.py +++ 
b/examples/evaluate_speech_recognition.py @@ -5,7 +5,7 @@ from unitxt.system_prompts import TextualSystemPrompt test_dataset = load_dataset( - card="cards.ami", + card="cards.esb.ami", split="test", format="formats.chat_api", max_test_instances=10, diff --git a/examples/evaluate_speech_recognition_benchmark.py b/examples/evaluate_speech_recognition_benchmark.py index 6bc6b28507..ee7bf05b15 100644 --- a/examples/evaluate_speech_recognition_benchmark.py +++ b/examples/evaluate_speech_recognition_benchmark.py @@ -7,6 +7,7 @@ dataset = load_dataset( "benchmarks.speech_recognition", max_samples_per_subset=5, + subset="ami", system_prompt=TextualSystemPrompt( text="Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant" ), diff --git a/src/unitxt/catalog/benchmarks/speech_recognition.json b/src/unitxt/catalog/benchmarks/speech_recognition.json index c7e8ae5050..6c9ed9eac6 100644 --- a/src/unitxt/catalog/benchmarks/speech_recognition.json +++ b/src/unitxt/catalog/benchmarks/speech_recognition.json @@ -1,54 +1,39 @@ { "__type__": "benchmark", "subsets": { - "voxpopuli_en": { + "voxpopuli": { "__type__": "dataset_recipe", - "card": "cards.voxpopuli.en", + "card": "cards.esb.voxpopuli", "format": "formats.chat_api" }, - "ami_ihm": { + "ami": { "__type__": "dataset_recipe", - "card": "cards.ami.ihm", + "card": "cards.esb.ami", "format": "formats.chat_api" }, - "ami_sdm": { + "gigaspeech": { "__type__": "dataset_recipe", - "card": "cards.ami.sdm", + "card": "cards.esb.gigaspeech", "format": "formats.chat_api" }, - "gigaspeech_xs": { + "librispeech": { "__type__": "dataset_recipe", - "card": "cards.gigaspeech.xs", + "card": "cards.esb.librispeech", "format": "formats.chat_api" }, - "librispeech_test_clean": { + "spgispeech": { "__type__": "dataset_recipe", - "card": "cards.librispeech.test_clean", + "card": "cards.esb.spgispeech", "format": "formats.chat_api" }, - "librispeech_test": { + "tedlium": { "__type__": "dataset_recipe", - "card": "cards.librispeech.test", + "card": "cards.esb.tedlium", "format": "formats.chat_api" }, - "spgispeech_s": { + "earnings22": { "__type__": "dataset_recipe", - "card": "cards.spgispeech.s", - "format": "formats.chat_api" - }, - "tedlium_release1": { - "__type__": "dataset_recipe", - "card": "cards.tedlium.release1", - "format": "formats.chat_api" - }, - "tedlium_release2": { - "__type__": "dataset_recipe", - "card": "cards.tedlium.release2", - "format": "formats.chat_api" - }, - "tedlium_release3": { - "__type__": "dataset_recipe", - "card": "cards.tedlium.release3", + "card": "cards.esb.earnings22", "format": "formats.chat_api" } } diff --git a/src/unitxt/metric_utils.py b/src/unitxt/metric_utils.py index 3bfc41094e..c41ea125b3 100644 --- a/src/unitxt/metric_utils.py +++ b/src/unitxt/metric_utils.py @@ -436,7 +436,7 @@ def __repr__(self): @property def summary(self): - df = self.to_df().round(2).fillna("") + df = self.to_df().round(4).fillna("") df = df.sort_index() df = df.drop("num_of_instances", axis=0) df = df.reset_index() From 1e8fd9efd07a56c30a20ccde24e0d1cc8205e5f1 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Thu, 24 Jul 2025 12:34:02 +0300 Subject: [PATCH 05/36] Comment out subset in benchmark Signed-off-by: elronbandel --- examples/evaluate_speech_recognition_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/evaluate_speech_recognition_benchmark.py b/examples/evaluate_speech_recognition_benchmark.py index ee7bf05b15..c08b571223 100644 
--- a/examples/evaluate_speech_recognition_benchmark.py +++ b/examples/evaluate_speech_recognition_benchmark.py @@ -7,7 +7,7 @@ dataset = load_dataset( "benchmarks.speech_recognition", max_samples_per_subset=5, - subset="ami", + # subset="ami", system_prompt=TextualSystemPrompt( text="Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant" ), From 029019e219a13a63b502c1905458102bd7fa5bb8 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Thu, 24 Jul 2025 13:20:42 +0300 Subject: [PATCH 06/36] Add fluers Signed-off-by: elronbandel --- prepare/cards/fleures.py | 145 ++++++++++++++++++ prepare/tasks/translation/speech.py | 17 ++ prepare/templates/translation/speech.py | 11 ++ src/unitxt/catalog/cards/fleurs/de_de.json | 34 ++++ src/unitxt/catalog/cards/fleurs/en_us.json | 34 ++++ src/unitxt/catalog/cards/fleurs/es_419.json | 34 ++++ src/unitxt/catalog/cards/fleurs/fr_fr.json | 34 ++++ src/unitxt/catalog/cards/fleurs/it_it.json | 34 ++++ src/unitxt/catalog/cards/fleurs/ja_jp.json | 34 ++++ src/unitxt/catalog/cards/fleurs/pt_br.json | 34 ++++ .../catalog/tasks/translation/speech.json | 14 ++ .../templates/translation/speech/default.json | 5 + 12 files changed, 430 insertions(+) create mode 100644 prepare/cards/fleures.py create mode 100644 prepare/tasks/translation/speech.py create mode 100644 prepare/templates/translation/speech.py create mode 100644 src/unitxt/catalog/cards/fleurs/de_de.json create mode 100644 src/unitxt/catalog/cards/fleurs/en_us.json create mode 100644 src/unitxt/catalog/cards/fleurs/es_419.json create mode 100644 src/unitxt/catalog/cards/fleurs/fr_fr.json create mode 100644 src/unitxt/catalog/cards/fleurs/it_it.json create mode 100644 src/unitxt/catalog/cards/fleurs/ja_jp.json create mode 100644 src/unitxt/catalog/cards/fleurs/pt_br.json create mode 100644 src/unitxt/catalog/tasks/translation/speech.json create mode 100644 src/unitxt/catalog/templates/translation/speech/default.json diff --git a/prepare/cards/fleures.py b/prepare/cards/fleures.py new file mode 100644 index 0000000000..5fc743515a --- /dev/null +++ b/prepare/cards/fleures.py @@ -0,0 +1,145 @@ +from unitxt.audio_operators import ToAudio +from unitxt.blocks import LoadHF, TaskCard +from unitxt.catalog import add_to_catalog +from unitxt.operators import Rename +from unitxt.test_utils.card import test_card + +all_subsets = [ + "af_za", + "am_et", + "ar_eg", + "as_in", + "ast_es", + "az_az", + "be_by", + "bg_bg", + "bn_in", + "bs_ba", + "ca_es", + "ceb_ph", + "ckb_iq", + "cmn_hans_cn", + "cs_cz", + "cy_gb", + "da_dk", + "de_de", + "el_gr", + "en_us", + "es_419", + "et_ee", + "fa_ir", + "ff_sn", + "fi_fi", + "fil_ph", + "fr_fr", + "ga_ie", + "gl_es", + "gu_in", + "ha_ng", + "he_il", + "hi_in", + "hr_hr", + "hu_hu", + "hy_am", + "id_id", + "ig_ng", + "is_is", + "it_it", + "ja_jp", + "jv_id", + "ka_ge", + "kam_ke", + "kea_cv", + "kk_kz", + "km_kh", + "kn_in", + "ko_kr", + "ky_kg", + "lb_lu", + "lg_ug", + "ln_cd", + "lo_la", + "lt_lt", + "luo_ke", + "lv_lv", + "mi_nz", + "mk_mk", + "ml_in", + "mn_mn", + "mr_in", + "ms_my", + "mt_mt", + "my_mm", + "nb_no", + "ne_np", + "nl_nl", + "nso_za", + "ny_mw", + "oc_fr", + "om_et", + "or_in", + "pa_in", + "pl_pl", + "ps_af", + "pt_br", + "ro_ro", + "ru_ru", + "sd_in", + "sk_sk", + "sl_si", + "sn_zw", + "so_so", + "sr_rs", + "sv_se", + "sw_ke", + "ta_in", + "te_in", + "tg_tj", + "th_th", + "tr_tr", + "uk_ua", + "umb_ao", + "ur_pk", + "uz_uz", + "vi_vn", + "wo_sn", + "xh_za", + "yo_ng", + "yue_hant_hk", 
+ "zu_za", +] + +target_subsets = [ # langs currently supported by sacrebleu + "de", + "es", + "fr", + "it", + "ja", + "pt", + "zh", + "en", +] + +subsets = [subset for subset in all_subsets if subset[:2] in target_subsets] + +for subset in subsets: + card = TaskCard( + loader=LoadHF( + path="google/fleurs", + revision="refs/convert/parquet", + data_dir=subset, + splits=["train", "validation", "test"], + ), + preprocess_steps=[ + ToAudio(field="audio"), + Rename(field="transcription", to_field="translation"), + Rename(field="language", to_field="target_language"), + ], + task="tasks.translation.speech", + templates=[ + "templates.translation.speech.default", + ], + ) + if subset == subsets[0]: + test_card(card, demos_taken_from="test", num_demos=0) + add_to_catalog(card, f"cards.fleurs.{subset}", overwrite=True) diff --git a/prepare/tasks/translation/speech.py b/prepare/tasks/translation/speech.py new file mode 100644 index 0000000000..81e298f569 --- /dev/null +++ b/prepare/tasks/translation/speech.py @@ -0,0 +1,17 @@ +from unitxt.blocks import Task +from unitxt.catalog import add_to_catalog +from unitxt.types import Audio + +add_to_catalog( + Task( + input_fields={ + "audio": Audio, + "target_language": str, + }, + reference_fields={"translation": str}, + prediction_type=str, + metrics=["metrics.normalized_sacrebleu"], + ), + "tasks.translation.speech", + overwrite=True, +) diff --git a/prepare/templates/translation/speech.py b/prepare/templates/translation/speech.py new file mode 100644 index 0000000000..a4d5cdb7ce --- /dev/null +++ b/prepare/templates/translation/speech.py @@ -0,0 +1,11 @@ +from unitxt.catalog import add_to_catalog +from unitxt.templates import InputOutputTemplate + +add_to_catalog( + InputOutputTemplate( + input_format="{audio}listen to the speech and translate it to {target_language}", + output_format="{translation}", + ), + "templates.translation.speech.default", + overwrite=True, +) diff --git a/src/unitxt/catalog/cards/fleurs/de_de.json b/src/unitxt/catalog/cards/fleurs/de_de.json new file mode 100644 index 0000000000..7076f896da --- /dev/null +++ b/src/unitxt/catalog/cards/fleurs/de_de.json @@ -0,0 +1,34 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_hf", + "path": "google/fleurs", + "revision": "refs/convert/parquet", + "data_dir": "de_de", + "splits": [ + "train", + "validation", + "test" + ] + }, + "preprocess_steps": [ + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "rename", + "field": "transcription", + "to_field": "translation" + }, + { + "__type__": "rename", + "field": "language", + "to_field": "target_language" + } + ], + "task": "tasks.translation.speech", + "templates": [ + "templates.translation.speech.default" + ] +} diff --git a/src/unitxt/catalog/cards/fleurs/en_us.json b/src/unitxt/catalog/cards/fleurs/en_us.json new file mode 100644 index 0000000000..c0820eccb5 --- /dev/null +++ b/src/unitxt/catalog/cards/fleurs/en_us.json @@ -0,0 +1,34 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_hf", + "path": "google/fleurs", + "revision": "refs/convert/parquet", + "data_dir": "en_us", + "splits": [ + "train", + "validation", + "test" + ] + }, + "preprocess_steps": [ + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "rename", + "field": "transcription", + "to_field": "translation" + }, + { + "__type__": "rename", + "field": "language", + "to_field": "target_language" + } + ], + "task": "tasks.translation.speech", + "templates": [ + "templates.translation.speech.default" + ] 
+} diff --git a/src/unitxt/catalog/cards/fleurs/es_419.json b/src/unitxt/catalog/cards/fleurs/es_419.json new file mode 100644 index 0000000000..13f79aed3e --- /dev/null +++ b/src/unitxt/catalog/cards/fleurs/es_419.json @@ -0,0 +1,34 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_hf", + "path": "google/fleurs", + "revision": "refs/convert/parquet", + "data_dir": "es_419", + "splits": [ + "train", + "validation", + "test" + ] + }, + "preprocess_steps": [ + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "rename", + "field": "transcription", + "to_field": "translation" + }, + { + "__type__": "rename", + "field": "language", + "to_field": "target_language" + } + ], + "task": "tasks.translation.speech", + "templates": [ + "templates.translation.speech.default" + ] +} diff --git a/src/unitxt/catalog/cards/fleurs/fr_fr.json b/src/unitxt/catalog/cards/fleurs/fr_fr.json new file mode 100644 index 0000000000..7720b70abc --- /dev/null +++ b/src/unitxt/catalog/cards/fleurs/fr_fr.json @@ -0,0 +1,34 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_hf", + "path": "google/fleurs", + "revision": "refs/convert/parquet", + "data_dir": "fr_fr", + "splits": [ + "train", + "validation", + "test" + ] + }, + "preprocess_steps": [ + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "rename", + "field": "transcription", + "to_field": "translation" + }, + { + "__type__": "rename", + "field": "language", + "to_field": "target_language" + } + ], + "task": "tasks.translation.speech", + "templates": [ + "templates.translation.speech.default" + ] +} diff --git a/src/unitxt/catalog/cards/fleurs/it_it.json b/src/unitxt/catalog/cards/fleurs/it_it.json new file mode 100644 index 0000000000..d05c2b4636 --- /dev/null +++ b/src/unitxt/catalog/cards/fleurs/it_it.json @@ -0,0 +1,34 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_hf", + "path": "google/fleurs", + "revision": "refs/convert/parquet", + "data_dir": "it_it", + "splits": [ + "train", + "validation", + "test" + ] + }, + "preprocess_steps": [ + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "rename", + "field": "transcription", + "to_field": "translation" + }, + { + "__type__": "rename", + "field": "language", + "to_field": "target_language" + } + ], + "task": "tasks.translation.speech", + "templates": [ + "templates.translation.speech.default" + ] +} diff --git a/src/unitxt/catalog/cards/fleurs/ja_jp.json b/src/unitxt/catalog/cards/fleurs/ja_jp.json new file mode 100644 index 0000000000..a8e6929913 --- /dev/null +++ b/src/unitxt/catalog/cards/fleurs/ja_jp.json @@ -0,0 +1,34 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_hf", + "path": "google/fleurs", + "revision": "refs/convert/parquet", + "data_dir": "ja_jp", + "splits": [ + "train", + "validation", + "test" + ] + }, + "preprocess_steps": [ + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "rename", + "field": "transcription", + "to_field": "translation" + }, + { + "__type__": "rename", + "field": "language", + "to_field": "target_language" + } + ], + "task": "tasks.translation.speech", + "templates": [ + "templates.translation.speech.default" + ] +} diff --git a/src/unitxt/catalog/cards/fleurs/pt_br.json b/src/unitxt/catalog/cards/fleurs/pt_br.json new file mode 100644 index 0000000000..4172f390b4 --- /dev/null +++ b/src/unitxt/catalog/cards/fleurs/pt_br.json @@ -0,0 +1,34 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_hf", + "path": 
"google/fleurs", + "revision": "refs/convert/parquet", + "data_dir": "pt_br", + "splits": [ + "train", + "validation", + "test" + ] + }, + "preprocess_steps": [ + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "rename", + "field": "transcription", + "to_field": "translation" + }, + { + "__type__": "rename", + "field": "language", + "to_field": "target_language" + } + ], + "task": "tasks.translation.speech", + "templates": [ + "templates.translation.speech.default" + ] +} diff --git a/src/unitxt/catalog/tasks/translation/speech.json b/src/unitxt/catalog/tasks/translation/speech.json new file mode 100644 index 0000000000..23153d9366 --- /dev/null +++ b/src/unitxt/catalog/tasks/translation/speech.json @@ -0,0 +1,14 @@ +{ + "__type__": "task", + "input_fields": { + "audio": "Audio", + "target_language": "str" + }, + "reference_fields": { + "translation": "str" + }, + "prediction_type": "str", + "metrics": [ + "metrics.normalized_sacrebleu" + ] +} diff --git a/src/unitxt/catalog/templates/translation/speech/default.json b/src/unitxt/catalog/templates/translation/speech/default.json new file mode 100644 index 0000000000..29f080829b --- /dev/null +++ b/src/unitxt/catalog/templates/translation/speech/default.json @@ -0,0 +1,5 @@ +{ + "__type__": "input_output_template", + "input_format": "{audio}listen to the speech and translate it to {target_language}", + "output_format": "{translation}" +} From 7e5414601f933f33e0ebe09e2377eea7b0309811 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Sun, 27 Jul 2025 11:53:02 +0300 Subject: [PATCH 07/36] Add speech translation benchmark Signed-off-by: elronbandel --- examples/evaluate_speech_recognition.py | 2 +- prepare/benchmarks/speech_recognition.py | 8 +- prepare/benchmarks/speech_translation.py | 34 ++ prepare/cards/fleures.py | 327 +++++++++++------- .../benchmarks/speech_recognition.json | 5 - .../benchmarks/speech_translation.json | 35 ++ .../catalog/cards/fleurs/en_us/de_de.json | 103 ++++++ .../catalog/cards/fleurs/en_us/es_419.json | 103 ++++++ .../catalog/cards/fleurs/en_us/fr_fr.json | 103 ++++++ .../catalog/cards/fleurs/en_us/it_it.json | 103 ++++++ .../catalog/cards/fleurs/en_us/ja_jp.json | 103 ++++++ .../catalog/cards/fleurs/en_us/pt_br.json | 103 ++++++ 12 files changed, 889 insertions(+), 140 deletions(-) create mode 100644 prepare/benchmarks/speech_translation.py create mode 100644 src/unitxt/catalog/benchmarks/speech_translation.json create mode 100644 src/unitxt/catalog/cards/fleurs/en_us/de_de.json create mode 100644 src/unitxt/catalog/cards/fleurs/en_us/es_419.json create mode 100644 src/unitxt/catalog/cards/fleurs/en_us/fr_fr.json create mode 100644 src/unitxt/catalog/cards/fleurs/en_us/it_it.json create mode 100644 src/unitxt/catalog/cards/fleurs/en_us/ja_jp.json create mode 100644 src/unitxt/catalog/cards/fleurs/en_us/pt_br.json diff --git a/examples/evaluate_speech_recognition.py b/examples/evaluate_speech_recognition.py index a4a3905abe..1bd17b9432 100644 --- a/examples/evaluate_speech_recognition.py +++ b/examples/evaluate_speech_recognition.py @@ -5,7 +5,7 @@ from unitxt.system_prompts import TextualSystemPrompt test_dataset = load_dataset( - card="cards.esb.ami", + card="cards.fleurs.en_us.pt_br", split="test", format="formats.chat_api", max_test_instances=10, diff --git a/prepare/benchmarks/speech_recognition.py b/prepare/benchmarks/speech_recognition.py index fc485b17bf..404f37ee20 100644 --- a/prepare/benchmarks/speech_recognition.py +++ b/prepare/benchmarks/speech_recognition.py @@ -12,10 +12,10 @@ 
card="cards.esb.ami", format="formats.chat_api", ), - "gigaspeech": DatasetRecipe( - card="cards.esb.gigaspeech", - format="formats.chat_api", - ), + # "gigaspeech": DatasetRecipe( + # card="cards.esb.gigaspeech", + # format="formats.chat_api", + # ), "librispeech": DatasetRecipe( card="cards.esb.librispeech", format="formats.chat_api", diff --git a/prepare/benchmarks/speech_translation.py b/prepare/benchmarks/speech_translation.py new file mode 100644 index 0000000000..c07d8d8606 --- /dev/null +++ b/prepare/benchmarks/speech_translation.py @@ -0,0 +1,34 @@ +from unitxt.benchmark import Benchmark +from unitxt.catalog import add_to_catalog +from unitxt.standard import DatasetRecipe + +benchmark = Benchmark( + subsets={ + "en_de": DatasetRecipe( + card="cards.fleurs.en_us.de_de", + format="formats.chat_api", + ), + "en_es": DatasetRecipe( + card="cards.fleurs.en_us.es_419", + format="formats.chat_api", + ), + "en_fr": DatasetRecipe( + card="cards.fleurs.en_us.fr_fr", + format="formats.chat_api", + ), + "en_it": DatasetRecipe( + card="cards.fleurs.en_us.it_it", + format="formats.chat_api", + ), + "en_ja": DatasetRecipe( + card="cards.fleurs.en_us.ja_jp", + format="formats.chat_api", + ), + "en_pt": DatasetRecipe( + card="cards.fleurs.en_us.pt_br", + format="formats.chat_api", + ), + }, +) + +add_to_catalog(benchmark, "benchmarks.speech_translation", overwrite=True) diff --git a/prepare/cards/fleures.py b/prepare/cards/fleures.py index 5fc743515a..82271ee439 100644 --- a/prepare/cards/fleures.py +++ b/prepare/cards/fleures.py @@ -1,145 +1,212 @@ from unitxt.audio_operators import ToAudio from unitxt.blocks import LoadHF, TaskCard from unitxt.catalog import add_to_catalog -from unitxt.operators import Rename +from unitxt.loaders import MultipleSourceLoader +from unitxt.operator import SourceSequentialOperator +from unitxt.operators import RemoveFields, Rename +from unitxt.splitters import RenameSplits +from unitxt.stream_operators import JoinStreams from unitxt.test_utils.card import test_card -all_subsets = [ - "af_za", - "am_et", - "ar_eg", - "as_in", - "ast_es", - "az_az", - "be_by", - "bg_bg", - "bn_in", - "bs_ba", - "ca_es", - "ceb_ph", - "ckb_iq", - "cmn_hans_cn", - "cs_cz", - "cy_gb", - "da_dk", +# all_subsets = [ +# "af_za", +# "am_et", +# "ar_eg", +# "as_in", +# "ast_es", +# "az_az", +# "be_by", +# "bg_bg", +# "bn_in", +# "bs_ba", +# "ca_es", +# "ceb_ph", +# "ckb_iq", +# "cmn_hans_cn", +# "cs_cz", +# "cy_gb", +# "da_dk", +# "de_de", +# "el_gr", +# "en_us", +# "es_419", +# "et_ee", +# "fa_ir", +# "ff_sn", +# "fi_fi", +# "fil_ph", +# "fr_fr", +# "ga_ie", +# "gl_es", +# "gu_in", +# "ha_ng", +# "he_il", +# "hi_in", +# "hr_hr", +# "hu_hu", +# "hy_am", +# "id_id", +# "ig_ng", +# "is_is", +# "it_it", +# "ja_jp", +# "jv_id", +# "ka_ge", +# "kam_ke", +# "kea_cv", +# "kk_kz", +# "km_kh", +# "kn_in", +# "ko_kr", +# "ky_kg", +# "lb_lu", +# "lg_ug", +# "ln_cd", +# "lo_la", +# "lt_lt", +# "luo_ke", +# "lv_lv", +# "mi_nz", +# "mk_mk", +# "ml_in", +# "mn_mn", +# "mr_in", +# "ms_my", +# "mt_mt", +# "my_mm", +# "nb_no", +# "ne_np", +# "nl_nl", +# "nso_za", +# "ny_mw", +# "oc_fr", +# "om_et", +# "or_in", +# "pa_in", +# "pl_pl", +# "ps_af", +# "pt_br", +# "ro_ro", +# "ru_ru", +# "sd_in", +# "sk_sk", +# "sl_si", +# "sn_zw", +# "so_so", +# "sr_rs", +# "sv_se", +# "sw_ke", +# "ta_in", +# "te_in", +# "tg_tj", +# "th_th", +# "tr_tr", +# "uk_ua", +# "umb_ao", +# "ur_pk", +# "uz_uz", +# "vi_vn", +# "wo_sn", +# "xh_za", +# "yo_ng", +# "yue_hant_hk", +# "zu_za", +# ] + +target_subsets = [ "de_de", - "el_gr", - 
"en_us", "es_419", - "et_ee", - "fa_ir", - "ff_sn", - "fi_fi", - "fil_ph", "fr_fr", - "ga_ie", - "gl_es", - "gu_in", - "ha_ng", - "he_il", - "hi_in", - "hr_hr", - "hu_hu", - "hy_am", - "id_id", - "ig_ng", - "is_is", "it_it", "ja_jp", - "jv_id", - "ka_ge", - "kam_ke", - "kea_cv", - "kk_kz", - "km_kh", - "kn_in", - "ko_kr", - "ky_kg", - "lb_lu", - "lg_ug", - "ln_cd", - "lo_la", - "lt_lt", - "luo_ke", - "lv_lv", - "mi_nz", - "mk_mk", - "ml_in", - "mn_mn", - "mr_in", - "ms_my", - "mt_mt", - "my_mm", - "nb_no", - "ne_np", - "nl_nl", - "nso_za", - "ny_mw", - "oc_fr", - "om_et", - "or_in", - "pa_in", - "pl_pl", - "ps_af", "pt_br", - "ro_ro", - "ru_ru", - "sd_in", - "sk_sk", - "sl_si", - "sn_zw", - "so_so", - "sr_rs", - "sv_se", - "sw_ke", - "ta_in", - "te_in", - "tg_tj", - "th_th", - "tr_tr", - "uk_ua", - "umb_ao", - "ur_pk", - "uz_uz", - "vi_vn", - "wo_sn", - "xh_za", - "yo_ng", - "yue_hant_hk", - "zu_za", ] -target_subsets = [ # langs currently supported by sacrebleu - "de", - "es", - "fr", - "it", - "ja", - "pt", - "zh", - "en", +source_subsets = [ + "en_us", ] -subsets = [subset for subset in all_subsets if subset[:2] in target_subsets] - -for subset in subsets: - card = TaskCard( - loader=LoadHF( - path="google/fleurs", - revision="refs/convert/parquet", - data_dir=subset, - splits=["train", "validation", "test"], - ), - preprocess_steps=[ - ToAudio(field="audio"), - Rename(field="transcription", to_field="translation"), - Rename(field="language", to_field="target_language"), - ], - task="tasks.translation.speech", - templates=[ - "templates.translation.speech.default", - ], - ) - if subset == subsets[0]: - test_card(card, demos_taken_from="test", num_demos=0) - add_to_catalog(card, f"cards.fleurs.{subset}", overwrite=True) +first = True +for source_subset in source_subsets: + for target_subset in target_subsets: + card = TaskCard( + loader=MultipleSourceLoader( + sources=[ + SourceSequentialOperator( + steps=[ + LoadHF( + path="google/fleurs", + revision="refs/convert/parquet", + data_dir=source_subset, + splits=["test"], + ), + RemoveFields( + fields=[ + "num_samples", + "path", + "raw_transcription", + "transcription", + "gender", + "lang_id", + "language", + "lang_group_id", + ] + ), + RenameSplits( + { + "test": "test_input", + } + ), + ] + ), + SourceSequentialOperator( + steps=[ + LoadHF( + path="google/fleurs", + revision="refs/convert/parquet", + data_dir=target_subset, + splits=["test"], + ), + RemoveFields( + fields=[ + "num_samples", + "path", + "audio", + "raw_transcription", + "gender", + "lang_id", + "lang_group_id", + ] + ), + RenameSplits( + { + "test": "test_target", + } + ), + ] + ), + ] + ), + preprocess_steps=[ + JoinStreams( + left_stream="test_input", + right_stream="test_target", + how="inner", + on=["id"], + new_stream_name="test", + ), + ToAudio(field="audio"), + Rename(field="transcription", to_field="translation"), + Rename(field="language", to_field="target_language"), + ], + task="tasks.translation.speech", + templates=[ + "templates.translation.speech.default", + ], + ) + if first: + test_card(card, demos_taken_from="test", num_demos=0) + first = False + add_to_catalog( + card, f"cards.fleurs.{source_subset}.{target_subset}", overwrite=True + ) diff --git a/src/unitxt/catalog/benchmarks/speech_recognition.json b/src/unitxt/catalog/benchmarks/speech_recognition.json index 6c9ed9eac6..847235b566 100644 --- a/src/unitxt/catalog/benchmarks/speech_recognition.json +++ b/src/unitxt/catalog/benchmarks/speech_recognition.json @@ -11,11 +11,6 @@ "card": "cards.esb.ami", 
"format": "formats.chat_api" }, - "gigaspeech": { - "__type__": "dataset_recipe", - "card": "cards.esb.gigaspeech", - "format": "formats.chat_api" - }, "librispeech": { "__type__": "dataset_recipe", "card": "cards.esb.librispeech", diff --git a/src/unitxt/catalog/benchmarks/speech_translation.json b/src/unitxt/catalog/benchmarks/speech_translation.json new file mode 100644 index 0000000000..9b1db3ed4a --- /dev/null +++ b/src/unitxt/catalog/benchmarks/speech_translation.json @@ -0,0 +1,35 @@ +{ + "__type__": "benchmark", + "subsets": { + "en_de": { + "__type__": "dataset_recipe", + "card": "cards.fleurs.en_us.de_de", + "format": "formats.chat_api" + }, + "en_es": { + "__type__": "dataset_recipe", + "card": "cards.fleurs.en_us.es_419", + "format": "formats.chat_api" + }, + "en_fr": { + "__type__": "dataset_recipe", + "card": "cards.fleurs.en_us.fr_fr", + "format": "formats.chat_api" + }, + "en_it": { + "__type__": "dataset_recipe", + "card": "cards.fleurs.en_us.it_it", + "format": "formats.chat_api" + }, + "en_ja": { + "__type__": "dataset_recipe", + "card": "cards.fleurs.en_us.ja_jp", + "format": "formats.chat_api" + }, + "en_pt": { + "__type__": "dataset_recipe", + "card": "cards.fleurs.en_us.pt_br", + "format": "formats.chat_api" + } + } +} diff --git a/src/unitxt/catalog/cards/fleurs/en_us/de_de.json b/src/unitxt/catalog/cards/fleurs/en_us/de_de.json new file mode 100644 index 0000000000..8b00286a50 --- /dev/null +++ b/src/unitxt/catalog/cards/fleurs/en_us/de_de.json @@ -0,0 +1,103 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "multiple_source_loader", + "sources": [ + { + "__type__": "source_sequential_operator", + "steps": [ + { + "__type__": "load_hf", + "path": "google/fleurs", + "revision": "refs/convert/parquet", + "data_dir": "en_us", + "splits": [ + "test" + ] + }, + { + "__type__": "remove_fields", + "fields": [ + "num_samples", + "path", + "raw_transcription", + "transcription", + "gender", + "lang_id", + "language", + "lang_group_id" + ] + }, + { + "__type__": "rename_splits", + "mapper": { + "test": "test_input" + } + } + ] + }, + { + "__type__": "source_sequential_operator", + "steps": [ + { + "__type__": "load_hf", + "path": "google/fleurs", + "revision": "refs/convert/parquet", + "data_dir": "de_de", + "splits": [ + "test" + ] + }, + { + "__type__": "remove_fields", + "fields": [ + "num_samples", + "path", + "audio", + "raw_transcription", + "gender", + "lang_id", + "lang_group_id" + ] + }, + { + "__type__": "rename_splits", + "mapper": { + "test": "test_target" + } + } + ] + } + ] + }, + "preprocess_steps": [ + { + "__type__": "join_streams", + "left_stream": "test_input", + "right_stream": "test_target", + "how": "inner", + "on": [ + "id" + ], + "new_stream_name": "test" + }, + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "rename", + "field": "transcription", + "to_field": "translation" + }, + { + "__type__": "rename", + "field": "language", + "to_field": "target_language" + } + ], + "task": "tasks.translation.speech", + "templates": [ + "templates.translation.speech.default" + ] +} diff --git a/src/unitxt/catalog/cards/fleurs/en_us/es_419.json b/src/unitxt/catalog/cards/fleurs/en_us/es_419.json new file mode 100644 index 0000000000..3097fd8421 --- /dev/null +++ b/src/unitxt/catalog/cards/fleurs/en_us/es_419.json @@ -0,0 +1,103 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "multiple_source_loader", + "sources": [ + { + "__type__": "source_sequential_operator", + "steps": [ + { + "__type__": "load_hf", + "path": 
"google/fleurs", + "revision": "refs/convert/parquet", + "data_dir": "en_us", + "splits": [ + "test" + ] + }, + { + "__type__": "remove_fields", + "fields": [ + "num_samples", + "path", + "raw_transcription", + "transcription", + "gender", + "lang_id", + "language", + "lang_group_id" + ] + }, + { + "__type__": "rename_splits", + "mapper": { + "test": "test_input" + } + } + ] + }, + { + "__type__": "source_sequential_operator", + "steps": [ + { + "__type__": "load_hf", + "path": "google/fleurs", + "revision": "refs/convert/parquet", + "data_dir": "es_419", + "splits": [ + "test" + ] + }, + { + "__type__": "remove_fields", + "fields": [ + "num_samples", + "path", + "audio", + "raw_transcription", + "gender", + "lang_id", + "lang_group_id" + ] + }, + { + "__type__": "rename_splits", + "mapper": { + "test": "test_target" + } + } + ] + } + ] + }, + "preprocess_steps": [ + { + "__type__": "join_streams", + "left_stream": "test_input", + "right_stream": "test_target", + "how": "inner", + "on": [ + "id" + ], + "new_stream_name": "test" + }, + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "rename", + "field": "transcription", + "to_field": "translation" + }, + { + "__type__": "rename", + "field": "language", + "to_field": "target_language" + } + ], + "task": "tasks.translation.speech", + "templates": [ + "templates.translation.speech.default" + ] +} diff --git a/src/unitxt/catalog/cards/fleurs/en_us/fr_fr.json b/src/unitxt/catalog/cards/fleurs/en_us/fr_fr.json new file mode 100644 index 0000000000..503f099fe0 --- /dev/null +++ b/src/unitxt/catalog/cards/fleurs/en_us/fr_fr.json @@ -0,0 +1,103 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "multiple_source_loader", + "sources": [ + { + "__type__": "source_sequential_operator", + "steps": [ + { + "__type__": "load_hf", + "path": "google/fleurs", + "revision": "refs/convert/parquet", + "data_dir": "en_us", + "splits": [ + "test" + ] + }, + { + "__type__": "remove_fields", + "fields": [ + "num_samples", + "path", + "raw_transcription", + "transcription", + "gender", + "lang_id", + "language", + "lang_group_id" + ] + }, + { + "__type__": "rename_splits", + "mapper": { + "test": "test_input" + } + } + ] + }, + { + "__type__": "source_sequential_operator", + "steps": [ + { + "__type__": "load_hf", + "path": "google/fleurs", + "revision": "refs/convert/parquet", + "data_dir": "fr_fr", + "splits": [ + "test" + ] + }, + { + "__type__": "remove_fields", + "fields": [ + "num_samples", + "path", + "audio", + "raw_transcription", + "gender", + "lang_id", + "lang_group_id" + ] + }, + { + "__type__": "rename_splits", + "mapper": { + "test": "test_target" + } + } + ] + } + ] + }, + "preprocess_steps": [ + { + "__type__": "join_streams", + "left_stream": "test_input", + "right_stream": "test_target", + "how": "inner", + "on": [ + "id" + ], + "new_stream_name": "test" + }, + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "rename", + "field": "transcription", + "to_field": "translation" + }, + { + "__type__": "rename", + "field": "language", + "to_field": "target_language" + } + ], + "task": "tasks.translation.speech", + "templates": [ + "templates.translation.speech.default" + ] +} diff --git a/src/unitxt/catalog/cards/fleurs/en_us/it_it.json b/src/unitxt/catalog/cards/fleurs/en_us/it_it.json new file mode 100644 index 0000000000..fa1f545a90 --- /dev/null +++ b/src/unitxt/catalog/cards/fleurs/en_us/it_it.json @@ -0,0 +1,103 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": 
"multiple_source_loader", + "sources": [ + { + "__type__": "source_sequential_operator", + "steps": [ + { + "__type__": "load_hf", + "path": "google/fleurs", + "revision": "refs/convert/parquet", + "data_dir": "en_us", + "splits": [ + "test" + ] + }, + { + "__type__": "remove_fields", + "fields": [ + "num_samples", + "path", + "raw_transcription", + "transcription", + "gender", + "lang_id", + "language", + "lang_group_id" + ] + }, + { + "__type__": "rename_splits", + "mapper": { + "test": "test_input" + } + } + ] + }, + { + "__type__": "source_sequential_operator", + "steps": [ + { + "__type__": "load_hf", + "path": "google/fleurs", + "revision": "refs/convert/parquet", + "data_dir": "it_it", + "splits": [ + "test" + ] + }, + { + "__type__": "remove_fields", + "fields": [ + "num_samples", + "path", + "audio", + "raw_transcription", + "gender", + "lang_id", + "lang_group_id" + ] + }, + { + "__type__": "rename_splits", + "mapper": { + "test": "test_target" + } + } + ] + } + ] + }, + "preprocess_steps": [ + { + "__type__": "join_streams", + "left_stream": "test_input", + "right_stream": "test_target", + "how": "inner", + "on": [ + "id" + ], + "new_stream_name": "test" + }, + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "rename", + "field": "transcription", + "to_field": "translation" + }, + { + "__type__": "rename", + "field": "language", + "to_field": "target_language" + } + ], + "task": "tasks.translation.speech", + "templates": [ + "templates.translation.speech.default" + ] +} diff --git a/src/unitxt/catalog/cards/fleurs/en_us/ja_jp.json b/src/unitxt/catalog/cards/fleurs/en_us/ja_jp.json new file mode 100644 index 0000000000..31f00a3148 --- /dev/null +++ b/src/unitxt/catalog/cards/fleurs/en_us/ja_jp.json @@ -0,0 +1,103 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "multiple_source_loader", + "sources": [ + { + "__type__": "source_sequential_operator", + "steps": [ + { + "__type__": "load_hf", + "path": "google/fleurs", + "revision": "refs/convert/parquet", + "data_dir": "en_us", + "splits": [ + "test" + ] + }, + { + "__type__": "remove_fields", + "fields": [ + "num_samples", + "path", + "raw_transcription", + "transcription", + "gender", + "lang_id", + "language", + "lang_group_id" + ] + }, + { + "__type__": "rename_splits", + "mapper": { + "test": "test_input" + } + } + ] + }, + { + "__type__": "source_sequential_operator", + "steps": [ + { + "__type__": "load_hf", + "path": "google/fleurs", + "revision": "refs/convert/parquet", + "data_dir": "ja_jp", + "splits": [ + "test" + ] + }, + { + "__type__": "remove_fields", + "fields": [ + "num_samples", + "path", + "audio", + "raw_transcription", + "gender", + "lang_id", + "lang_group_id" + ] + }, + { + "__type__": "rename_splits", + "mapper": { + "test": "test_target" + } + } + ] + } + ] + }, + "preprocess_steps": [ + { + "__type__": "join_streams", + "left_stream": "test_input", + "right_stream": "test_target", + "how": "inner", + "on": [ + "id" + ], + "new_stream_name": "test" + }, + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "rename", + "field": "transcription", + "to_field": "translation" + }, + { + "__type__": "rename", + "field": "language", + "to_field": "target_language" + } + ], + "task": "tasks.translation.speech", + "templates": [ + "templates.translation.speech.default" + ] +} diff --git a/src/unitxt/catalog/cards/fleurs/en_us/pt_br.json b/src/unitxt/catalog/cards/fleurs/en_us/pt_br.json new file mode 100644 index 0000000000..a63cddb177 --- /dev/null +++ 
b/src/unitxt/catalog/cards/fleurs/en_us/pt_br.json @@ -0,0 +1,103 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "multiple_source_loader", + "sources": [ + { + "__type__": "source_sequential_operator", + "steps": [ + { + "__type__": "load_hf", + "path": "google/fleurs", + "revision": "refs/convert/parquet", + "data_dir": "en_us", + "splits": [ + "test" + ] + }, + { + "__type__": "remove_fields", + "fields": [ + "num_samples", + "path", + "raw_transcription", + "transcription", + "gender", + "lang_id", + "language", + "lang_group_id" + ] + }, + { + "__type__": "rename_splits", + "mapper": { + "test": "test_input" + } + } + ] + }, + { + "__type__": "source_sequential_operator", + "steps": [ + { + "__type__": "load_hf", + "path": "google/fleurs", + "revision": "refs/convert/parquet", + "data_dir": "pt_br", + "splits": [ + "test" + ] + }, + { + "__type__": "remove_fields", + "fields": [ + "num_samples", + "path", + "audio", + "raw_transcription", + "gender", + "lang_id", + "lang_group_id" + ] + }, + { + "__type__": "rename_splits", + "mapper": { + "test": "test_target" + } + } + ] + } + ] + }, + "preprocess_steps": [ + { + "__type__": "join_streams", + "left_stream": "test_input", + "right_stream": "test_target", + "how": "inner", + "on": [ + "id" + ], + "new_stream_name": "test" + }, + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "rename", + "field": "transcription", + "to_field": "translation" + }, + { + "__type__": "rename", + "field": "language", + "to_field": "target_language" + } + ], + "task": "tasks.translation.speech", + "templates": [ + "templates.translation.speech.default" + ] +} From c7576b42601ab4c5ad2535ac69af5a8dfb6c437e Mon Sep 17 00:00:00 2001 From: elronbandel Date: Sun, 27 Jul 2025 14:57:11 +0300 Subject: [PATCH 08/36] Make sure not to aggregate data in evaluation pipeline Signed-off-by: elronbandel --- src/unitxt/metric_utils.py | 333 ++++++++++++++++++++++++----- src/unitxt/metrics.py | 10 + tests/library/test_metric_utils.py | 67 +++--- 3 files changed, 316 insertions(+), 94 deletions(-) diff --git a/src/unitxt/metric_utils.py b/src/unitxt/metric_utils.py index c41ea125b3..d104a00eca 100644 --- a/src/unitxt/metric_utils.py +++ b/src/unitxt/metric_utils.py @@ -4,13 +4,15 @@ from collections import defaultdict from functools import lru_cache from statistics import mean -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Set import pandas as pd +from datasets import Dataset as HfDataset from datasets import Features, Value +from datasets import IterableDataset as HfIterableDataset -from .dataclass import Dataclass -from .error_utils import Documentation, UnitxtError, error_context +from .dataclass import Dataclass, Field +from .error_utils import Documentation, UnitxtError from .operator import ( InstanceOperator, MultiStreamOperator, @@ -22,9 +24,7 @@ ApplyMetric, ApplyOperatorsField, ArtifactFetcherMixin, - FlattenInstances, RecursiveCopy, - Rename, ) from .register import _reset_env_local_catalogs, register_all_artifacts from .schema import UNITXT_DATASET_SCHEMA @@ -145,6 +145,19 @@ def group_str(json_str): return ",".join(f"{k}:{v}" for k, v in data.items()) +@lru_cache(maxsize=None) +def subset_stream_name(stream_name, subset, subset_depth): + return ( + stream_name + DEFAULT_STREAM_SUBSET_SEPARATOR + "/".join(subset[:subset_depth]) + ) + + +@lru_cache(maxsize=None) +def subset_group_stream_name(stream_name, subset, subset_depth, group): + subset_name = 
subset_stream_name(stream_name, subset, subset_depth) + return subset_name + "?" + group_str(group) + + class SplitSubsetsAndGroups(MultiStreamOperator): """Splits a MultiStream that is small - for metrics, hence: whole stream can sit in memory, split by the value of field 'group'. @@ -157,35 +170,69 @@ class SplitSubsetsAndGroups(MultiStreamOperator): subsets_depth specifies the depth of the prefix by which to split the stream. """ - subsets_field: str = "subset" - groups_field: str = "groups" + subset_groups: Dict[str, Set[str]] subset_depth: Optional[int] = None - def process(self, multi_stream: MultiStream) -> MultiStream: - result = defaultdict(list) - - for stream_name, stream in multi_stream.items(): - for i, instance in enumerate(stream): - instance["__idx__"] = i - - for field in [self.subsets_field, self.groups_field]: - if field not in instance: - raise ValueError( - f"Field {field} is missing from instance {instance}" - ) - - subset_stream_name = ( - stream_name - + DEFAULT_STREAM_SUBSET_SEPARATOR - + "/".join(instance[self.subsets_field][: self.subset_depth]) + def get_new_streams(self, stream_name): + new_streams = [] + for subset, groups in self.subset_groups.items(): + new_streams.append( + subset_stream_name( + stream_name=stream_name, + subset=subset, + subset_depth=self.subset_depth, ) + ) + for group in groups: + new_streams.append( + subset_group_stream_name( + stream_name=stream_name, + subset=subset, + subset_depth=self.subset_depth, + group=group, + ) + ) + return new_streams - result[subset_stream_name].append(instance) + def is_instance_included(self, instance, stream_name, new_stream_name) -> bool: + subset = tuple(instance.get("subset")) - for group in instance[self.groups_field]: - result[subset_stream_name + "?" + group_str(group)].append(instance) + if new_stream_name == subset_stream_name( + stream_name=stream_name, subset=subset, subset_depth=self.subset_depth + ): + return True - return MultiStream.from_iterables(result, copying=True) + for group in instance.get("groups"): + if new_stream_name == subset_group_stream_name( + stream_name=stream_name, + subset=subset, + subset_depth=self.subset_depth, + group=group, + ): + return True + + return False + + def filter_stream(self, stream, stream_name, new_stream_name): + for i, instance in enumerate(stream): + instance["__idx__"] = i + if self.is_instance_included(instance, stream_name, new_stream_name): + yield instance + + def process(self, multi_stream: MultiStream) -> MultiStream: + streams = {} + + for stream_name, stream in multi_stream.items(): + for new_stream_name in self.get_new_streams(stream_name): + streams[new_stream_name] = DynamicStream( + generator=self.filter_stream, + gen_kwargs={ + "stream": stream, + "stream_name": stream_name, + "new_stream_name": new_stream_name, + }, + ) + return MultiStream(streams) @lru_cache(maxsize=None) @@ -235,7 +282,15 @@ def process(self, multi_stream: MultiStream) -> MultiStream: idx = instance.pop("__idx__") if idx not in instances[origin]: - instances[origin][idx] = instance + instances[origin][idx] = {"score": instance["score"]} + if "processed_prediction" in instance: + instances[origin][idx]["processed_prediction"] = instance[ + "processed_prediction" + ] + if "processed_references" in instance: + instances[origin][idx]["processed_references"] = instance[ + "processed_references" + ] # from here below setting the global scores from that stream # can be done with first instance only @@ -340,36 +395,37 @@ def _inference_post_process( return 
[instance["processed_prediction"] for instance in multi_stream[split_name]] -class MetricRecipe(SequentialOperatorInitializer): +class PreProcessForEvaluation(SequentialOperatorInitializer): + def prepare(self): + register_all_artifacts() + self.steps = [ + FromPredictionsAndOriginalData(), + LoadJson(field="task_data"), + RecursiveCopy( + field="source", + to_field="task_data/source", + ), + ] + + +class MainEvaluationPipeline(SequentialOperator): calc_confidence_intervals: bool = True subset_depth: int = 2 + subset_groups: Dict[str, Set[str]] = Field(default_factory=dict) def prepare(self): register_all_artifacts() self.steps = [ - FromPredictionsAndOriginalData(), - LoadJson(field="task_data"), _post_process_steps, SplitSubsetsAndGroups( subset_depth=self.subset_depth, + subset_groups=self.subset_groups, ), ApplyMetric( "metrics", calc_confidence_intervals=self.calc_confidence_intervals, ), JoinSubsetsAndGroups(), - Rename( - field="raw_prediction", - to_field="prediction", - ), - Rename( - field="raw_references", - to_field="references", - ), - RecursiveCopy( - field="source", - to_field="task_data/source", - ), ] @@ -744,16 +800,139 @@ def __repr__(self): class EvaluationResults(list): - def __init__(self, *args, metadata=None, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, stream, metadata=None): + """Initialize EvaluationResults with lazy evaluation. + + Args: + stream: An iterator or generator + metadata: Optional metadata dictionary + """ + super().__init__() + self._generator = iter(stream) + self._realized = False self.metadata = metadata if metadata is not None else {} + def _realize_up_to(self, index): + """Realize elements from generator up to the given index.""" + if self._realized: + return + + current_len = super().__len__() + + # For negative indices, we need to realize everything + if index < 0: + self._realize_all() + return + + # Calculate how many more elements we need + needed = index + 1 - current_len + + # Realize the needed elements + try: + for _ in range(needed): + item = next(self._generator) + super().append(item) + except StopIteration: + self._realized = True + self._generator = None + + def _realize_all(self): + """Realize all remaining elements from the generator.""" + if self._realized: + return + + try: + while True: + item = next(self._generator) + super().append(item) + except StopIteration: + self._realized = True + self._generator = None + + def __getitem__(self, index): + if isinstance(index, slice): + # For slices, we need to realize up to the maximum index + start, stop, step = index.indices(len(self)) + if stop > 0: + self._realize_up_to(stop - 1) + else: + self._realize_all() + else: + self._realize_up_to(index) + return super().__getitem__(index) + + def __len__(self): + # If not fully realized, we need to realize everything to know the length + if not self._realized: + self._realize_all() + return super().__len__() + + def __iter__(self): + # Yield already realized elements + for i in range(super().__len__()): + yield super().__getitem__(i) + + # Continue with unrealized elements + if self._generator is not None: + try: + while True: + item = next(self._generator) + super().append(item) + yield item + except StopIteration: + self._realized = True + self._generator = None + + def append(self, item): + self._realize_all() + super().append(item) + + def extend(self, iterable): + self._realize_all() + super().extend(iterable) + + def insert(self, index, item): + self._realize_up_to(index) + super().insert(index, item) + + def 
remove(self, value): + self._realize_all() + super().remove(value) + + def pop(self, index=-1): + self._realize_up_to(index) + return super().pop(index) + + def index(self, value, start=0, stop=None): + if stop is None: + self._realize_all() + else: + self._realize_up_to(stop) + return super().index(value, start, stop) + + def count(self, value): + self._realize_all() + return super().count(value) + + def sort(self, key=None, reverse=False): + self._realize_all() + super().sort(key=key, reverse=reverse) + + def reverse(self): + self._realize_all() + super().reverse() + + def clear(self): + super().clear() + self._generator = None + self._realized = True + @property def global_scores(self): return GlobalScores(self[0]["score"]["global"]) @property - def instance_scores(self) -> InstanceScores: + def instance_scores(self): return InstanceScores(self) @property @@ -775,6 +954,35 @@ def subsets_scores(self): return SubsetsScores(self[0]["score"]["subsets"]) +def extract_subset_groups(dataset): + # Column-wise access types + subset_groups = {} + if not isinstance(dataset, (HfDataset, HfIterableDataset, pd.DataFrame)): + dataset = { + "subset": (item.get("subset", []) for item in dataset), + "groups": (item.get("groups", []) for item in dataset), + } + + for subset, groups in zip(dataset["subset"], dataset["groups"]): + if len(subset) == 0: + subset = "" + else: + subset = tuple(subset) + + if subset not in subset_groups: + subset_groups[subset] = set() + for group in groups: + subset_groups[subset].add(group) + + return subset_groups + + +def merge_evaluation_results(results, dataset): + for result, instance in zip(results, dataset): + instance.update(result) + yield instance + + def _compute( predictions: List[Any], references: Iterable, @@ -784,19 +992,30 @@ def _compute( ): _reset_env_local_catalogs() register_all_artifacts() - recipe = MetricRecipe(calc_confidence_intervals=calc_confidence_intervals) - with error_context(stage="Metric Processing"): - multi_stream = recipe( - predictions=predictions, references=references, split_name=split_name - ) + if not isinstance(references, (HfDataset, HfIterableDataset, pd.DataFrame, list)): + raise ValueError(f"Unsupported data type: {type(references)}") - if flatten: - operator = FlattenInstances() - multi_stream = operator(multi_stream) + subset_groups = extract_subset_groups(references) - stream = multi_stream[split_name] - return EvaluationResults(stream) + preprocess = PreProcessForEvaluation() + + preprocess_multi_stream = preprocess( + predictions=predictions, references=references, split_name=split_name + ) + + evaluate = MainEvaluationPipeline( + calc_confidence_intervals=calc_confidence_intervals, subset_groups=subset_groups + ) + + results_multi_stream = evaluate(preprocess_multi_stream) + + return EvaluationResults( + merge_evaluation_results( + results_multi_stream[split_name], + preprocess_multi_stream[split_name], + ) + ) """ diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py index 4302836941..62aeb94f8b 100644 --- a/src/unitxt/metrics.py +++ b/src/unitxt/metrics.py @@ -3528,6 +3528,16 @@ def reduce(self, intermediates: List[Tuple[float, float]]) -> Dict[str, Any]: score, p_value = self.spearmanr(a=list_a, b=list_b) + try: + score = float(score) + except: + pass + + try: + p_value = float(p_value) + except: + pass + return { self.main_score: score, "spearmanr_p_value": p_value, diff --git a/tests/library/test_metric_utils.py b/tests/library/test_metric_utils.py index c1bc844483..6f0a53defa 100644 --- 
a/tests/library/test_metric_utils.py +++ b/tests/library/test_metric_utils.py @@ -9,7 +9,7 @@ class TestMetricUtils(UnitxtTestCase): def test_split_none(self): - operator = SplitSubsetsAndGroups() + operator = SplitSubsetsAndGroups(subset_groups={"": {}}) ms = MultiStream.from_iterables( { @@ -61,7 +61,15 @@ def test_split_none(self): self.assertEqual({k: list(v) for k, v in result.items()}, target) def test_split_groups(self): - operator = SplitSubsetsAndGroups() + operator = SplitSubsetsAndGroups( + subset_groups={ + "": { + '{"template":"templates.t1"}', + '{"num_demos": 1}', + '{"template":"templates.t2"}', + } + } + ) ms = MultiStream.from_iterables( { @@ -143,7 +151,7 @@ def test_split_groups(self): self.assertEqual({k: list(v) for k, v in result.items()}, target) def test_split_subsets(self): - operator = SplitSubsetsAndGroups() + operator = SplitSubsetsAndGroups(subset_groups={("mnli",): {}, ("squad",): {}}) ms = MultiStream.from_iterables( { @@ -197,7 +205,16 @@ def test_split_subsets(self): self.assertEqual({k: list(v) for k, v in result.items()}, target) def test_split_subset_and_groups(self): - operator = SplitSubsetsAndGroups() + operator = SplitSubsetsAndGroups( + subset_groups={ + ("mnli",): { + '{"template":"templates.t1"}', + '{"template":"templates.t2"}', + '{"num_demos": 1}', + }, + ("squad",): {'{"template":"templates.t1"}', '{"num_demos": 1}'}, + } + ) ms = MultiStream.from_iterables( { @@ -339,9 +356,6 @@ def test_join_none(self): list(result["test"]), [ { - "subset": [], - "groups": [], - "media": {"audios": [], "images": []}, "score": { "instance": { "accuracy": 1.0, @@ -358,9 +372,6 @@ def test_join_none(self): }, }, { - "subset": [], - "groups": [], - "media": {"audios": [], "images": []}, "score": { "instance": { "accuracy": 0.0, @@ -571,8 +582,6 @@ def test_join_groups(self): list(result["test"]), [ { - "subset": [], - "groups": ['{"template":"templates.t1"}', '{"num_demos": 1}'], "score": { "instance": { "accuracy": 1.0, @@ -613,8 +622,6 @@ def test_join_groups(self): }, }, { - "subset": [], - "groups": ['{"template":"templates.t2"}', '{"num_demos": 1}'], "score": { "instance": { "accuracy": 0.0, @@ -655,8 +662,6 @@ def test_join_groups(self): }, }, { - "subset": [], - "groups": ['{"template":"templates.t1"}', '{"num_demos": 1}'], "score": { "instance": { "f1": 1.0, @@ -776,9 +781,9 @@ def test_join_subsets(self): list(result["test"]), [ { - "subset": ["mnli"], - "groups": [], - "media": {"audios": [], "images": []}, + # "subset": ["mnli"], + # "groups": [], + # "media": {"audios": [], "images": []}, "score": { "instance": { "accuracy": 1.0, @@ -810,9 +815,9 @@ def test_join_subsets(self): }, }, { - "subset": ["mnli"], - "groups": [], - "media": {"audios": [], "images": []}, + # "subset": ["mnli"], + # "groups": [], + # "media": {"audios": [], "images": []}, "score": { "instance": { "accuracy": 0.0, @@ -844,9 +849,9 @@ def test_join_subsets(self): }, }, { - "subset": ["squad"], - "groups": [], - "media": {"audios": [], "images": []}, + # "subset": ["squad"], + # "groups": [], + # "media": {"audios": [], "images": []}, "score": { "instance": { "f1": 1.0, @@ -977,9 +982,6 @@ def test_join_nested_subsets(self): list(result["test"]), [ { - "subset": ["mnli", "first"], - "groups": [], - "media": {"audios": [], "images": []}, "score": { "instance": { "accuracy": 1.0, @@ -1015,9 +1017,6 @@ def test_join_nested_subsets(self): }, }, { - "subset": ["mnli", "first"], - "groups": [], - "media": {"audios": [], "images": []}, "score": { "instance": { "accuracy": 0.0, @@ 
-1254,8 +1253,6 @@ def test_join_subsets_and_groups(self): list(result["test"]), [ { - "subset": ["mnli"], - "groups": ['{"template":"templates.t1"}', '{"num_demos": 1}'], "score": { "instance": { "accuracy": 1.0, @@ -1329,8 +1326,6 @@ def test_join_subsets_and_groups(self): }, }, { - "subset": ["mnli"], - "groups": ['{"template":"templates.t2"}', '{"num_demos": 1}'], "score": { "instance": { "accuracy": 0.0, @@ -1404,8 +1399,6 @@ def test_join_subsets_and_groups(self): }, }, { - "subset": ["squad"], - "groups": ['{"template":"templates.t1"}', '{"num_demos": 1}'], "score": { "instance": { "f1": 1.0, From 1efb526cef9916d49b5a3d128c328d3901adb34a Mon Sep 17 00:00:00 2001 From: elronbandel Date: Mon, 4 Aug 2025 12:56:48 +0300 Subject: [PATCH 09/36] Update catalog Signed-off-by: elronbandel --- .../catalog/cards/debug/librispeech.json | 32 ----------------- src/unitxt/catalog/cards/fleurs/de_de.json | 34 ------------------- src/unitxt/catalog/cards/fleurs/en_us.json | 34 ------------------- src/unitxt/catalog/cards/fleurs/es_419.json | 34 ------------------- src/unitxt/catalog/cards/fleurs/fr_fr.json | 34 ------------------- src/unitxt/catalog/cards/fleurs/it_it.json | 34 ------------------- src/unitxt/catalog/cards/fleurs/ja_jp.json | 34 ------------------- src/unitxt/catalog/cards/fleurs/pt_br.json | 34 ------------------- 8 files changed, 270 deletions(-) delete mode 100644 src/unitxt/catalog/cards/debug/librispeech.json delete mode 100644 src/unitxt/catalog/cards/fleurs/de_de.json delete mode 100644 src/unitxt/catalog/cards/fleurs/en_us.json delete mode 100644 src/unitxt/catalog/cards/fleurs/es_419.json delete mode 100644 src/unitxt/catalog/cards/fleurs/fr_fr.json delete mode 100644 src/unitxt/catalog/cards/fleurs/it_it.json delete mode 100644 src/unitxt/catalog/cards/fleurs/ja_jp.json delete mode 100644 src/unitxt/catalog/cards/fleurs/pt_br.json diff --git a/src/unitxt/catalog/cards/debug/librispeech.json b/src/unitxt/catalog/cards/debug/librispeech.json deleted file mode 100644 index c6095bb9dc..0000000000 --- a/src/unitxt/catalog/cards/debug/librispeech.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "__type__": "task_card", - "loader": { - "__type__": "load_hf", - "path": "hf-audio/esb-datasets-test-only-sorted", - "name": "librispeech", - "splits": [ - "test.clean" - ], - "data_classification_policy": [ - "public" - ], - "streaming": true - }, - "preprocess_steps": [ - { - "__type__": "rename_splits", - "mapper": { - "test.clean": "test" - } - }, - { - "__type__": "to_audio", - "field": "audio" - } - ], - "task": "tasks.speech_recognition", - "templates": [ - "templates.speech_recognition.default" - ], - "__title__": "Debug-LibriSpeech" -} diff --git a/src/unitxt/catalog/cards/fleurs/de_de.json b/src/unitxt/catalog/cards/fleurs/de_de.json deleted file mode 100644 index 7076f896da..0000000000 --- a/src/unitxt/catalog/cards/fleurs/de_de.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "__type__": "task_card", - "loader": { - "__type__": "load_hf", - "path": "google/fleurs", - "revision": "refs/convert/parquet", - "data_dir": "de_de", - "splits": [ - "train", - "validation", - "test" - ] - }, - "preprocess_steps": [ - { - "__type__": "to_audio", - "field": "audio" - }, - { - "__type__": "rename", - "field": "transcription", - "to_field": "translation" - }, - { - "__type__": "rename", - "field": "language", - "to_field": "target_language" - } - ], - "task": "tasks.translation.speech", - "templates": [ - "templates.translation.speech.default" - ] -} diff --git a/src/unitxt/catalog/cards/fleurs/en_us.json 
b/src/unitxt/catalog/cards/fleurs/en_us.json deleted file mode 100644 index c0820eccb5..0000000000 --- a/src/unitxt/catalog/cards/fleurs/en_us.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "__type__": "task_card", - "loader": { - "__type__": "load_hf", - "path": "google/fleurs", - "revision": "refs/convert/parquet", - "data_dir": "en_us", - "splits": [ - "train", - "validation", - "test" - ] - }, - "preprocess_steps": [ - { - "__type__": "to_audio", - "field": "audio" - }, - { - "__type__": "rename", - "field": "transcription", - "to_field": "translation" - }, - { - "__type__": "rename", - "field": "language", - "to_field": "target_language" - } - ], - "task": "tasks.translation.speech", - "templates": [ - "templates.translation.speech.default" - ] -} diff --git a/src/unitxt/catalog/cards/fleurs/es_419.json b/src/unitxt/catalog/cards/fleurs/es_419.json deleted file mode 100644 index 13f79aed3e..0000000000 --- a/src/unitxt/catalog/cards/fleurs/es_419.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "__type__": "task_card", - "loader": { - "__type__": "load_hf", - "path": "google/fleurs", - "revision": "refs/convert/parquet", - "data_dir": "es_419", - "splits": [ - "train", - "validation", - "test" - ] - }, - "preprocess_steps": [ - { - "__type__": "to_audio", - "field": "audio" - }, - { - "__type__": "rename", - "field": "transcription", - "to_field": "translation" - }, - { - "__type__": "rename", - "field": "language", - "to_field": "target_language" - } - ], - "task": "tasks.translation.speech", - "templates": [ - "templates.translation.speech.default" - ] -} diff --git a/src/unitxt/catalog/cards/fleurs/fr_fr.json b/src/unitxt/catalog/cards/fleurs/fr_fr.json deleted file mode 100644 index 7720b70abc..0000000000 --- a/src/unitxt/catalog/cards/fleurs/fr_fr.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "__type__": "task_card", - "loader": { - "__type__": "load_hf", - "path": "google/fleurs", - "revision": "refs/convert/parquet", - "data_dir": "fr_fr", - "splits": [ - "train", - "validation", - "test" - ] - }, - "preprocess_steps": [ - { - "__type__": "to_audio", - "field": "audio" - }, - { - "__type__": "rename", - "field": "transcription", - "to_field": "translation" - }, - { - "__type__": "rename", - "field": "language", - "to_field": "target_language" - } - ], - "task": "tasks.translation.speech", - "templates": [ - "templates.translation.speech.default" - ] -} diff --git a/src/unitxt/catalog/cards/fleurs/it_it.json b/src/unitxt/catalog/cards/fleurs/it_it.json deleted file mode 100644 index d05c2b4636..0000000000 --- a/src/unitxt/catalog/cards/fleurs/it_it.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "__type__": "task_card", - "loader": { - "__type__": "load_hf", - "path": "google/fleurs", - "revision": "refs/convert/parquet", - "data_dir": "it_it", - "splits": [ - "train", - "validation", - "test" - ] - }, - "preprocess_steps": [ - { - "__type__": "to_audio", - "field": "audio" - }, - { - "__type__": "rename", - "field": "transcription", - "to_field": "translation" - }, - { - "__type__": "rename", - "field": "language", - "to_field": "target_language" - } - ], - "task": "tasks.translation.speech", - "templates": [ - "templates.translation.speech.default" - ] -} diff --git a/src/unitxt/catalog/cards/fleurs/ja_jp.json b/src/unitxt/catalog/cards/fleurs/ja_jp.json deleted file mode 100644 index a8e6929913..0000000000 --- a/src/unitxt/catalog/cards/fleurs/ja_jp.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "__type__": "task_card", - "loader": { - "__type__": "load_hf", - "path": "google/fleurs", - "revision": 
"refs/convert/parquet", - "data_dir": "ja_jp", - "splits": [ - "train", - "validation", - "test" - ] - }, - "preprocess_steps": [ - { - "__type__": "to_audio", - "field": "audio" - }, - { - "__type__": "rename", - "field": "transcription", - "to_field": "translation" - }, - { - "__type__": "rename", - "field": "language", - "to_field": "target_language" - } - ], - "task": "tasks.translation.speech", - "templates": [ - "templates.translation.speech.default" - ] -} diff --git a/src/unitxt/catalog/cards/fleurs/pt_br.json b/src/unitxt/catalog/cards/fleurs/pt_br.json deleted file mode 100644 index 4172f390b4..0000000000 --- a/src/unitxt/catalog/cards/fleurs/pt_br.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "__type__": "task_card", - "loader": { - "__type__": "load_hf", - "path": "google/fleurs", - "revision": "refs/convert/parquet", - "data_dir": "pt_br", - "splits": [ - "train", - "validation", - "test" - ] - }, - "preprocess_steps": [ - { - "__type__": "to_audio", - "field": "audio" - }, - { - "__type__": "rename", - "field": "transcription", - "to_field": "translation" - }, - { - "__type__": "rename", - "field": "language", - "to_field": "target_language" - } - ], - "task": "tasks.translation.speech", - "templates": [ - "templates.translation.speech.default" - ] -} From d78b610ba1aa992edacd02c8abb6d991874cd399 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Mon, 4 Aug 2025 13:02:05 +0300 Subject: [PATCH 10/36] Update imports Signed-off-by: elronbandel --- src/unitxt/dataset.py | 1 + src/unitxt/metric.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/unitxt/dataset.py b/src/unitxt/dataset.py index 94529f42ff..9e21b71d4d 100644 --- a/src/unitxt/dataset.py +++ b/src/unitxt/dataset.py @@ -5,6 +5,7 @@ from .api import __file__ as _ from .artifact import __file__ as _ +from .audio_operators import __file__ as _ from .augmentors import __file__ as _ from .base_metric import __file__ as _ from .benchmark import __file__ as _ diff --git a/src/unitxt/metric.py b/src/unitxt/metric.py index 822340fbcd..67ad255a0e 100644 --- a/src/unitxt/metric.py +++ b/src/unitxt/metric.py @@ -4,6 +4,7 @@ from .api import __file__ as _ from .artifact import __file__ as _ +from .audio_operators import __file__ as _ from .augmentors import __file__ as _ from .base_metric import __file__ as _ from .benchmark import __file__ as _ From 4ce961021280d8263fb1fdbd179fb7de15881e49 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Mon, 4 Aug 2025 14:03:12 +0300 Subject: [PATCH 11/36] Fix test Signed-off-by: elronbandel --- tests/library/test_formats.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/library/test_formats.py b/tests/library/test_formats.py index 5f6ee3bc98..60738d6d11 100644 --- a/tests/library/test_formats.py +++ b/tests/library/test_formats.py @@ -147,9 +147,10 @@ def test_openai_format_with_images(self): "target_prefix": "The answer is ", "system_prompt": "You are a smart assistant.", "media": { + "audios": [], "images": [ {"image": create_random_jpeg_image(2, 2, 1), "format": "JPEG"} - ] + ], }, }, ] @@ -259,7 +260,7 @@ def test_openai_format_with_images(self): }, ], "demos": demo_instances, - "media": {"images": []}, + "media": {"audios": [], "images": []}, }, ] From bb0be6f89121f40f4cade94c32cf20172d209611 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Mon, 4 Aug 2025 14:04:25 +0300 Subject: [PATCH 12/36] Fix formatting minds 14 Signed-off-by: elronbandel --- prepare/cards/minds_14.py | 73 +++++++++++++++++++++++++-------------- 1 file changed, 48 insertions(+), 25 
deletions(-) diff --git a/prepare/cards/minds_14.py b/prepare/cards/minds_14.py index 0b7b8c7794..a81a58d920 100644 --- a/prepare/cards/minds_14.py +++ b/prepare/cards/minds_14.py @@ -28,43 +28,66 @@ SplitRandomMix( {"train": "train[90%]", "validation": "train[5%]", "test": "train[5%]"} ), - MapInstanceValues(mappers={"intent_class": {str(i): label for i, label in enumerate(classes)}}), + MapInstanceValues( + mappers={"intent_class": {str(i): label for i, label in enumerate(classes)}} + ), Rename(field="intent_class", to_field="label"), Set( fields={ "text_type": "sentence", "type_of_class": "intent", - "classes": classes + "classes": classes, } ), ToAudio(field="audio", to_field="text"), - ], task="tasks.classification.multi_class", templates="templates.classification.multi_class.all", __tags__={ - "annotations_creators": [ - "expert-generated", - "crowdsourced", - "machine-generated" - ], - "language_creators": [ - "crowdsourced", - "expert-generated" - ], - "language": [ - "en", "fr", "it", "es", "pt", "de", "nl", "ru", "pl", "cs", "ko", "zh" - ], - "license": "cc-by-4.0", - "multilinguality": "multilingual", - "size_categories": "10K Date: Mon, 4 Aug 2025 14:35:50 +0300 Subject: [PATCH 13/36] Remove normalized wer Signed-off-by: elronbandel --- prepare/metrics/wer.py | 42 +----------------------------------------- src/unitxt/metrics.py | 37 ------------------------------------- 2 files changed, 1 insertion(+), 78 deletions(-) diff --git a/prepare/metrics/wer.py b/prepare/metrics/wer.py index 675965cb05..40f0c29273 100644 --- a/prepare/metrics/wer.py +++ b/prepare/metrics/wer.py @@ -1,5 +1,5 @@ from unitxt import add_to_catalog -from unitxt.metrics import NormalizedWer, WerFast +from unitxt.metrics import WerFast from unitxt.test_utils.metrics import test_metric metric = WerFast() @@ -40,43 +40,3 @@ ) add_to_catalog(metric, "metrics.wer", overwrite=True) - - -metric = NormalizedWer() - -predictions = ["this is the prediction", "there is an other sample"] -references = [["this is the reference"], ["there is another sample"]] - -instance_targets = [ - { - "wer": 0.25, - "score": 0.25, - "score_name": "wer", - }, # 1 errors: reokace 'prediction' with 'reference' - { - "wer": 0.5, - "score": 0.5, - "score_name": "wer", - }, # 2 errors: remove 'an' and replace 'other' with 'another' -] - -global_target = { - "num_of_instances": 2, - "score": 0.38, - "score_ci_high": 0.5, - "score_ci_low": 0.25, - "score_name": "wer", - "wer": 0.38, - "wer_ci_high": 0.5, - "wer_ci_low": 0.25, -} - -outputs = test_metric( - metric=metric, - predictions=predictions, - references=references, - instance_targets=instance_targets, - global_target=global_target, -) - -add_to_catalog(metric, "metrics.normalized_wer", overwrite=True) diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py index 4ae7adcf49..ce8a4b1525 100644 --- a/src/unitxt/metrics.py +++ b/src/unitxt/metrics.py @@ -3418,43 +3418,6 @@ def reduce(self, intermediates: List[float]) -> Dict[str, Any]: return {self.main_score: incorrect / total if total > 0 else np.nan} -class NormalizedWer(MapReduceMetric[Tuple[float, float], float]): - """Computes mean squared error between predictions and references. - - Range: [0, ∞) (lower is better) - Measures average squared differences between predicted and true values. 
- """ - - main_score = "normalized_wer" - prediction_type = str - single_reference_per_prediction = True - _requirements_list = ["jiwer>=3.0.0"] # added process_words function - - def prepare(self): - super().prepare() - import jiwer - from transformers import WhisperTokenizer - - self.tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-base") - - self._metric = jiwer.process_words - self._normalize = self.tokenizer.normalize - - def map( - self, prediction: str, references: List[str], task_data: Dict[str, Any] - ) -> Tuple[float, float]: - normalized_reference = self._normalize(references[0]) - normalized_prediction = self._normalize(prediction) - measures = self._metric(normalized_reference, normalized_prediction) - incorrect = measures.substitutions + measures.deletions + measures.insertions - total = measures.substitutions + measures.deletions + measures.hits - return incorrect, total - - def reduce(self, intermediates: List[float]) -> Dict[str, Any]: - incorrect, total = map(sum, zip(*intermediates)) - return {self.main_score: incorrect / total} - - class MeanSquaredError(MapReduceMetric[float, float]): """Computes mean squared error between predictions and references. From 27846f65e3857c032c71091d0cc2c39ca9f54b4f Mon Sep 17 00:00:00 2001 From: elronbandel Date: Mon, 4 Aug 2025 15:10:34 +0300 Subject: [PATCH 14/36] Fix post processing pipeline for cases of eager execution Signed-off-by: elronbandel --- src/unitxt/metric_utils.py | 26 +++++++++++++++++++++++++- src/unitxt/operators.py | 8 +++++++- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/src/unitxt/metric_utils.py b/src/unitxt/metric_utils.py index d104a00eca..c780ba8c8e 100644 --- a/src/unitxt/metric_utils.py +++ b/src/unitxt/metric_utils.py @@ -25,6 +25,8 @@ ApplyOperatorsField, ArtifactFetcherMixin, RecursiveCopy, + RemoveFields, + Rename, ) from .register import _reset_env_local_catalogs, register_all_artifacts from .schema import UNITXT_DATASET_SCHEMA @@ -408,6 +410,24 @@ def prepare(self): ] +class PostProcessAfterEvaluation(SequentialOperator): + steps = [ + Rename( + field="raw_prediction", + to_field="prediction", + ), + Rename( + field="raw_references", + to_field="references", + ), + RecursiveCopy( + field="source", + to_field="task_data/source", + ), + RemoveFields(fields=["__idx__"], not_exist_ok=True), + ] + + class MainEvaluationPipeline(SequentialOperator): calc_confidence_intervals: bool = True subset_depth: int = 2 @@ -1010,10 +1030,14 @@ def _compute( results_multi_stream = evaluate(preprocess_multi_stream) + post_process = PostProcessAfterEvaluation() + + data_multi_stream = post_process(preprocess_multi_stream) + return EvaluationResults( merge_evaluation_results( results_multi_stream[split_name], - preprocess_multi_stream[split_name], + data_multi_stream[split_name], ) ) diff --git a/src/unitxt/operators.py b/src/unitxt/operators.py index 1e4df404d9..7799fd4b12 100644 --- a/src/unitxt/operators.py +++ b/src/unitxt/operators.py @@ -365,15 +365,21 @@ class RemoveFields(InstanceOperator): Args: fields (List[str]): The fields to remove from each instance. + not_exist_ok (bool): If True, do not raise an error if a field does not exist. Defaults to False. 
""" fields: List[str] + not_exist_ok: bool = False def process( self, instance: Dict[str, Any], stream_name: Optional[str] = None ) -> Dict[str, Any]: for field_name in self.fields: - del instance[field_name] + try: + del instance[field_name] + except: + if not self.not_exist_ok: + raise return instance From 92dedffd6e06c0da628c305a8bda98da4e0be2f7 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Wed, 6 Aug 2025 11:20:16 +0300 Subject: [PATCH 15/36] Change speech format to old open ai style Signed-off-by: elronbandel --- src/unitxt/formats.py | 4 ++-- src/unitxt/inference.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/unitxt/formats.py b/src/unitxt/formats.py index 448d570b71..1b9a5ee997 100644 --- a/src/unitxt/formats.py +++ b/src/unitxt/formats.py @@ -439,8 +439,8 @@ def to_content(self, text: str, media: Dict[str, Any]) -> Union[str, List[Conten data_url = src contents.append( { - "type": "audio", - "audio": {"data": data_url, "mime_type": "audio/wav"}, + "type": "input_audio", + "input_audio": {"data": data_url, "format": "wav"}, } ) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index c88e0b50cd..89b7430d64 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -1091,10 +1091,10 @@ def _get_input(self, instance): if isinstance(turn["content"], list): turn_content = "" for content in turn["content"]: - if content["type"] == "audio": + if content["type"] == "input_audio": audios.append( base64_to_audio( - content["audio"]["data"], + content["input_audio"]["data"], sampling_rate=self.sampling_rate, ) ) From bd1743524223bc61216e1d986823c64f6f986ae3 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Wed, 6 Aug 2025 13:30:24 +0300 Subject: [PATCH 16/36] Fix evaluation postprocessing pipeline Signed-off-by: elronbandel --- src/unitxt/metric_utils.py | 4 ++++ src/unitxt/operators.py | 7 ++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/unitxt/metric_utils.py b/src/unitxt/metric_utils.py index c780ba8c8e..31aebce81d 100644 --- a/src/unitxt/metric_utils.py +++ b/src/unitxt/metric_utils.py @@ -415,10 +415,14 @@ class PostProcessAfterEvaluation(SequentialOperator): Rename( field="raw_prediction", to_field="prediction", + not_exist_ok=True, + not_exist_do_nothing=True, ), Rename( field="raw_references", to_field="references", + not_exist_ok=True, + not_exist_do_nothing=True, ), RecursiveCopy( field="source", diff --git a/src/unitxt/operators.py b/src/unitxt/operators.py index 7799fd4b12..ad74036bdd 100644 --- a/src/unitxt/operators.py +++ b/src/unitxt/operators.py @@ -613,7 +613,12 @@ def process( if (not is_subpath(from_field, to_field)) and ( not is_subpath(to_field, from_field) ): - dict_delete(res, from_field, remove_empty_ancestors=True) + dict_delete( + res, + from_field, + remove_empty_ancestors=True, + not_exist_ok=self.not_exist_ok, + ) return res From 3db6a38990cbd4065b3c0ce3fddf3e2d768966f9 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Wed, 6 Aug 2025 13:39:22 +0300 Subject: [PATCH 17/36] Update tests dependencies Signed-off-by: elronbandel --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5a7db6c150..57478c0e08 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,7 +108,8 @@ tests = [ "sqlparse", "diskcache", "pydantic", - "jsonschema_rs" + "jsonschema_rs", + "torchcodec", ] ui = [ "gradio", From efb00d156629544f3206eb8af9f70388b9faf507 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Wed, 6 Aug 2025 13:56:43 +0300 
Subject: [PATCH 18/36] Install audio requirements Signed-off-by: elronbandel --- .github/workflows/catalog_preparation.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/catalog_preparation.yml b/.github/workflows/catalog_preparation.yml index a4e78cb414..6c25d444fb 100644 --- a/.github/workflows/catalog_preparation.yml +++ b/.github/workflows/catalog_preparation.yml @@ -37,9 +37,14 @@ jobs: python-version: '3.9' cache: 'pip' + - name: Install FFmpeg + run: | + sudo apt-get update + sudo apt-get install -y ffmpeg + - run: echo "blis==0" > constraints.txt - run: curl -LsSf https://astral.sh/uv/install.sh | sh - - run: uv pip install --upgrade --system torch --index-url https://download.pytorch.org/whl/cpu + - run: uv pip install --upgrade --system torch torchcodec --index-url https://download.pytorch.org/whl/cpu - run: uv pip install --system -c constraints.txt -e ".[tests]" - run: | pip install --only-binary :all: spacy From b80166e1d7a705b83750d342f02a1b9497846433 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Wed, 6 Aug 2025 14:12:35 +0300 Subject: [PATCH 19/36] Another try Signed-off-by: elronbandel --- .github/workflows/catalog_preparation.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/catalog_preparation.yml b/.github/workflows/catalog_preparation.yml index 6c25d444fb..e8d6d8c8e5 100644 --- a/.github/workflows/catalog_preparation.yml +++ b/.github/workflows/catalog_preparation.yml @@ -44,7 +44,7 @@ jobs: - run: echo "blis==0" > constraints.txt - run: curl -LsSf https://astral.sh/uv/install.sh | sh - - run: uv pip install --upgrade --system torch torchcodec --index-url https://download.pytorch.org/whl/cpu + - run: uv pip install --upgrade --system torch torchcodec - run: uv pip install --system -c constraints.txt -e ".[tests]" - run: | pip install --only-binary :all: spacy From 2e2e096987a622bc4cbba5cca259b47db4baba09 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Wed, 6 Aug 2025 16:20:51 +0300 Subject: [PATCH 20/36] Update tests dependencies Signed-off-by: elronbandel --- pyproject.toml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 57478c0e08..a5556969d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,6 +110,12 @@ tests = [ "pydantic", "jsonschema_rs", "torchcodec", + "soundfile" +] +audio = [ + "torchcodec", + "soundfile" + ] ui = [ "gradio", From 7412701175131c7e55e034eb22f952a8f2557471 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Wed, 6 Aug 2025 16:43:12 +0300 Subject: [PATCH 21/36] Update tests dependencies Signed-off-by: elronbandel --- pyproject.toml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a5556969d2..e0a818fa34 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,12 +110,13 @@ tests = [ "pydantic", "jsonschema_rs", "torchcodec", - "soundfile" + "soundfile", + "librosa" ] audio = [ "torchcodec", - "soundfile" - + "soundfile", + "librosa" ] ui = [ "gradio", From ed5eaef9c597797b9393db82179fd1d4fa2d91db Mon Sep 17 00:00:00 2001 From: elronbandel Date: Wed, 6 Aug 2025 17:12:16 +0300 Subject: [PATCH 22/36] Add torch audio to inference tests dependencies Signed-off-by: elronbandel --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e0a818fa34..8ccc78c2c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -136,7 +136,8 @@ inference-tests = [ "tenacity", "diskcache", "numpy==1.26.4", - 
"ollama" + "ollama", + "torchaudio" ] assistant = [ "streamlit", From 2be030724e02d4a2b00bd7484fb4a6ccf0b7d4de Mon Sep 17 00:00:00 2001 From: elronbandel Date: Thu, 7 Aug 2025 17:06:58 +0300 Subject: [PATCH 23/36] Add torchaudio depndency Signed-off-by: elronbandel --- pyproject.toml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8ccc78c2c4..e707794521 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,12 +111,14 @@ tests = [ "jsonschema_rs", "torchcodec", "soundfile", - "librosa" + "librosa", + "torchaudio" ] audio = [ "torchcodec", "soundfile", - "librosa" + "librosa", + "torchaudio" ] ui = [ "gradio", From 1f1c5363a365ce3713465ea684ef9d2ec2680f22 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Thu, 7 Aug 2025 17:28:13 +0300 Subject: [PATCH 24/36] Update dependencies Signed-off-by: elronbandel --- utils/install.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/install.sh b/utils/install.sh index 761731d5ad..8f90d58d99 100644 --- a/utils/install.sh +++ b/utils/install.sh @@ -2,7 +2,7 @@ sudo apt-get update sudo apt-get install -y ffmpeg echo "blis==0" > constraints.txt curl -LsSf https://astral.sh/uv/install.sh | sh -uv pip install --upgrade --system torch --index-url https://download.pytorch.org/whl/cpu +uv pip install --upgrade --system torchcodec torchaudio torch --index-url https://download.pytorch.org/whl/cpu uv pip install --system -c constraints.txt -e ".[tests]" pip install --only-binary :all: spacy pip install coverage[toml] From 1d7cb81e5e17d85d3c72abe962b9e5b6b6a6b772 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Thu, 7 Aug 2025 18:02:15 +0300 Subject: [PATCH 25/36] Add reqs Signed-off-by: elronbandel --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e707794521..d21cd4c132 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,7 +112,8 @@ tests = [ "torchcodec", "soundfile", "librosa", - "torchaudio" + "torchaudio", + "protobuf" ] audio = [ "torchcodec", From 3ea315552f6f51d98cc7b31123a4fdf69d773f4b Mon Sep 17 00:00:00 2001 From: elronbandel Date: Sun, 10 Aug 2025 11:46:14 +0300 Subject: [PATCH 26/36] Try specific revision to reduce memory used Signed-off-by: elronbandel --- examples/evaluate_speech_recognition.py | 1 + examples/evaluate_speech_recognition_benchmark.py | 1 + src/unitxt/inference.py | 2 ++ tests/inference/test_inference_engine.py | 1 + 4 files changed, 5 insertions(+) diff --git a/examples/evaluate_speech_recognition.py b/examples/evaluate_speech_recognition.py index 1bd17b9432..7d5af7cac5 100644 --- a/examples/evaluate_speech_recognition.py +++ b/examples/evaluate_speech_recognition.py @@ -16,6 +16,7 @@ model = HFGraniteSpeechInferenceEngine( model_name="ibm-granite/granite-speech-3.3-2b", + revision="granite-speech-3.3.2-2b", max_new_tokens=200, ) diff --git a/examples/evaluate_speech_recognition_benchmark.py b/examples/evaluate_speech_recognition_benchmark.py index c08b571223..a7c24c6fbc 100644 --- a/examples/evaluate_speech_recognition_benchmark.py +++ b/examples/evaluate_speech_recognition_benchmark.py @@ -16,6 +16,7 @@ model = HFGraniteSpeechInferenceEngine( model_name="ibm-granite/granite-speech-3.3-2b", + revision="granite-speech-3.3.2-2b", max_new_tokens=200, ) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index 1c36a60411..f81919c8c4 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -1044,6 +1044,7 @@ def _infer_log_probs( class 
HFGraniteSpeechInferenceEngine(HFInferenceEngineBase): + revision: str = None lazy_load: bool = True label: str = "hf_granite_speech" audio_token: str = "<|audio|>" @@ -1075,6 +1076,7 @@ def _init_model(self): self.model = AutoModelForSpeechSeq2Seq.from_pretrained( self.model_name, + revision=self.revision, torch_dtype=self._get_torch_dtype(), low_cpu_mem_usage=self.low_cpu_mem_usage, device_map=self.device_map, diff --git a/tests/inference/test_inference_engine.py b/tests/inference/test_inference_engine.py index a81754da05..07c23e67de 100644 --- a/tests/inference/test_inference_engine.py +++ b/tests/inference/test_inference_engine.py @@ -199,6 +199,7 @@ def test_llava_inference_engine(self): def test_granite_speech_inference_engine(self): model = HFGraniteSpeechInferenceEngine( model_name="ibm-granite/granite-speech-3.3-2b", + revision="granite-speech-3.3.2-2b", max_new_tokens=10, temperature=0.0, ) From 95f6f755cf060bc42952027c226533e925e01145 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Wed, 13 Aug 2025 13:52:33 +0300 Subject: [PATCH 27/36] Allow skipping tests Signed-off-by: elronbandel --- .github/workflows/examples_tests.yml | 1 + .github/workflows/inference_tests.yml | 2 ++ examples/evaluate_speech_recognition.py | 5 +++++ examples/evaluate_speech_recognition_benchmark.py | 6 ++++++ tests/inference/test_inference_engine.py | 3 +++ 5 files changed, 17 insertions(+) diff --git a/.github/workflows/examples_tests.yml b/.github/workflows/examples_tests.yml index 14a3c1f23b..ead3f306bb 100644 --- a/.github/workflows/examples_tests.yml +++ b/.github/workflows/examples_tests.yml @@ -30,6 +30,7 @@ jobs: WML_PROJECT_ID: ${{ secrets.WML_PROJECT_ID }} WML_APIKEY: ${{ secrets.WML_APIKEY }} GENAI_KEY: ${{ secrets.GENAI_KEY }} + SKIP_HEAVY_LOCAL: "True" steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/inference_tests.yml b/.github/workflows/inference_tests.yml index 2407bdf1ba..014926cfe4 100644 --- a/.github/workflows/inference_tests.yml +++ b/.github/workflows/inference_tests.yml @@ -32,6 +32,8 @@ jobs: WX_PROJECT_ID: ${{ secrets.WML_PROJECT_ID }} # Similar to WML_PROJECT_ID WX_API_KEY: ${{ secrets.WML_APIKEY }} # Similar to WML_APIKEY GENAI_KEY: ${{ secrets.GENAI_KEY }} + SKIP_HEAVY_LOCAL: "True" + steps: - uses: actions/checkout@v4 diff --git a/examples/evaluate_speech_recognition.py b/examples/evaluate_speech_recognition.py index 7d5af7cac5..b6c33310e5 100644 --- a/examples/evaluate_speech_recognition.py +++ b/examples/evaluate_speech_recognition.py @@ -1,3 +1,5 @@ +import os + from unitxt import evaluate, load_dataset from unitxt.inference import ( HFGraniteSpeechInferenceEngine, @@ -14,6 +16,9 @@ ), ) +if os.environ.get("SKIP_HEAVY_LOCAL", False): + exit() + model = HFGraniteSpeechInferenceEngine( model_name="ibm-granite/granite-speech-3.3-2b", revision="granite-speech-3.3.2-2b", diff --git a/examples/evaluate_speech_recognition_benchmark.py b/examples/evaluate_speech_recognition_benchmark.py index a7c24c6fbc..24c6d6d3f8 100644 --- a/examples/evaluate_speech_recognition_benchmark.py +++ b/examples/evaluate_speech_recognition_benchmark.py @@ -1,3 +1,5 @@ +import os + from unitxt import evaluate, load_dataset from unitxt.inference import ( HFGraniteSpeechInferenceEngine, @@ -14,6 +16,10 @@ split="test", ) + +if os.environ.get("SKIP_HEAVY_LOCAL", False): + exit() + model = HFGraniteSpeechInferenceEngine( model_name="ibm-granite/granite-speech-3.3-2b", revision="granite-speech-3.3.2-2b", diff --git a/tests/inference/test_inference_engine.py 
b/tests/inference/test_inference_engine.py index 9d4c547753..adce996e33 100644 --- a/tests/inference/test_inference_engine.py +++ b/tests/inference/test_inference_engine.py @@ -197,6 +197,9 @@ def test_llava_inference_engine(self): ) def test_granite_speech_inference_engine(self): + if os.environ.get("SKIP_HEAVY_LOCAL"): + return + model = HFGraniteSpeechInferenceEngine( model_name="ibm-granite/granite-speech-3.3-2b", revision="granite-speech-3.3.2-2b", From c2f9962930f6f85e0b92b3a6fb1f30b961897309 Mon Sep 17 00:00:00 2001 From: aharonsatt Date: Thu, 14 Aug 2025 11:42:38 +0300 Subject: [PATCH 28/36] Added support for speech recognition and speech translation (#1919) Signed-off-by: AHARON SATT AHARONSA@il.ibm.com Signed-off-by: elronbandel Co-authored-by: AHARON SATT AHARONSA@il.ibm.com Co-authored-by: elronbandel --- examples/evaluate_speech_recognition.py | 47 ++++-- .../evaluate_speech_recognition_benchmark.py | 17 ++- .../evaluate_speech_translation_benchmark.py | 38 +++++ .../evaluate_speech_translation_covost2.py | 58 +++++++ .../evaluate_speech_translation_fleurs.py | 58 +++++++ prepare/benchmarks/speech_recognition.py | 24 ++- prepare/benchmarks/speech_translation.py | 43 +++++- prepare/cards/commonvoice.py | 50 ++++++ prepare/cards/covost2.py | 142 ++++++++++++++++++ prepare/cards/{fleures.py => fleurs.py} | 55 ++++++- prepare/metrics/normalized_sacrebleu.py | 4 + prepare/processors/processors.py | 7 + prepare/tasks/speech_recognition.py | 18 +++ prepare/tasks/translation/speech.py | 7 + .../templates/speech_recognition/templates.py | 10 ++ prepare/templates/translation/speech.py | 15 +- .../benchmarks/speech_recognition.json | 25 +++ .../benchmarks/speech_translation.json | 47 +++++- src/unitxt/catalog/cards/commonvoice/de.json | 35 +++++ src/unitxt/catalog/cards/commonvoice/en.json | 35 +++++ src/unitxt/catalog/cards/commonvoice/es.json | 35 +++++ src/unitxt/catalog/cards/commonvoice/fr.json | 35 +++++ src/unitxt/catalog/cards/commonvoice/pt.json | 35 +++++ .../catalog/cards/covost2/from_en/en_de.json | 28 ++++ .../catalog/cards/covost2/from_en/en_ja.json | 28 ++++ .../catalog/cards/covost2/to_en/de_en.json | 28 ++++ .../catalog/cards/covost2/to_en/es_en.json | 28 ++++ .../catalog/cards/covost2/to_en/fr_en.json | 28 ++++ .../catalog/cards/covost2/to_en/pt_en.json | 28 ++++ .../cards/fleurs/en_us/cmn_hans_cn.json | 112 ++++++++++++++ .../catalog/cards/fleurs/en_us/de_de.json | 17 ++- .../catalog/cards/fleurs/en_us/es_419.json | 17 ++- .../catalog/cards/fleurs/en_us/fr_fr.json | 17 ++- .../catalog/cards/fleurs/en_us/it_it.json | 17 ++- .../catalog/cards/fleurs/en_us/ja_jp.json | 17 ++- .../catalog/cards/fleurs/en_us/pt_br.json | 17 ++- .../catalog/metrics/normalized_sacrebleu.json | 70 ++++++--- .../normalize_text_basic_with_whisper.json | 3 + .../normalize_text_basic_with_whisper.json | 8 + .../speech_recognition_multilingual.json | 21 +++ .../speech_recognition/multilingual.json | 8 + src/unitxt/inference.py | 6 + src/unitxt/processors.py | 16 ++ src/unitxt/stream_operators.py | 93 ++++++++++++ src/unitxt/string_operators.py | 14 ++ 45 files changed, 1381 insertions(+), 80 deletions(-) create mode 100644 examples/evaluate_speech_translation_benchmark.py create mode 100644 examples/evaluate_speech_translation_covost2.py create mode 100644 examples/evaluate_speech_translation_fleurs.py create mode 100644 prepare/cards/commonvoice.py create mode 100644 prepare/cards/covost2.py rename prepare/cards/{fleures.py => fleurs.py} (69%) create mode 100644 
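One behavioural note on the SKIP_HEAVY_LOCAL guard introduced in PATCH 27 above, shown as a tiny sketch: os.environ.get returns the raw string, so any non-empty value (even "False") is truthy, which means the heavy local runs are skipped whenever the variable is set at all, regardless of its value.

import os

os.environ["SKIP_HEAVY_LOCAL"] = "False"
# Still truthy: the guard keys off the variable being present, not its contents.
print(bool(os.environ.get("SKIP_HEAVY_LOCAL", False)))  # True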
src/unitxt/catalog/cards/commonvoice/de.json create mode 100644 src/unitxt/catalog/cards/commonvoice/en.json create mode 100644 src/unitxt/catalog/cards/commonvoice/es.json create mode 100644 src/unitxt/catalog/cards/commonvoice/fr.json create mode 100644 src/unitxt/catalog/cards/commonvoice/pt.json create mode 100644 src/unitxt/catalog/cards/covost2/from_en/en_de.json create mode 100644 src/unitxt/catalog/cards/covost2/from_en/en_ja.json create mode 100644 src/unitxt/catalog/cards/covost2/to_en/de_en.json create mode 100644 src/unitxt/catalog/cards/covost2/to_en/es_en.json create mode 100644 src/unitxt/catalog/cards/covost2/to_en/fr_en.json create mode 100644 src/unitxt/catalog/cards/covost2/to_en/pt_en.json create mode 100644 src/unitxt/catalog/cards/fleurs/en_us/cmn_hans_cn.json create mode 100644 src/unitxt/catalog/operators/normalize_text_basic_with_whisper.json create mode 100644 src/unitxt/catalog/processors/normalize_text_basic_with_whisper.json create mode 100644 src/unitxt/catalog/tasks/speech_recognition_multilingual.json create mode 100644 src/unitxt/catalog/templates/speech_recognition/multilingual.json diff --git a/examples/evaluate_speech_recognition.py b/examples/evaluate_speech_recognition.py index b6c33310e5..ff1c58c516 100644 --- a/examples/evaluate_speech_recognition.py +++ b/examples/evaluate_speech_recognition.py @@ -1,32 +1,61 @@ +# this python script shows an example of running speech recognition evaluation for Granite Speech using the Hugging Face ESB datasets and the CommonVoice datasets + import os from unitxt import evaluate, load_dataset from unitxt.inference import ( + CrossProviderInferenceEngine, HFGraniteSpeechInferenceEngine, ) from unitxt.system_prompts import TextualSystemPrompt +USE_RITS = True # whether to use RITS service or running the model locally + test_dataset = load_dataset( - card="cards.fleurs.en_us.pt_br", + # select (uncomment) only one of the following cards (datasets) + # for evaluating a benchmark with multiple cards - see evaluate_speech_recognition_benchmark.py in the same directory (examples) + card="cards.esb.ami", + # card="cards.esb.voxpopuli", + # card="cards.esb.librispeech", + # card="cards.esb.spgispeech", + # card="cards.esb.earnings22", + # card="cards.esb.tedlium", + # card="cards.commonvoice.en" + # card="cards.commonvoice.de" + # card="cards.commonvoice.fr" + # card="cards.commonvoice.es" + # card="cards.commonvoice.pt" split="test", format="formats.chat_api", - max_test_instances=10, + max_test_instances=5, # to tun limited part of the test set system_prompt=TextualSystemPrompt( - text="Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant" + text="Knowledge Cutoff Date: April 2024.\nToday's Date: April 9, 2025.\nYou are Granite, developed by IBM. 
You are a helpful AI assistant" ), ) if os.environ.get("SKIP_HEAVY_LOCAL", False): exit() -model = HFGraniteSpeechInferenceEngine( - model_name="ibm-granite/granite-speech-3.3-2b", - revision="granite-speech-3.3.2-2b", - max_new_tokens=200, -) +if not USE_RITS: + # locally running the model, it needs GPU to run properly + model = HFGraniteSpeechInferenceEngine( + model_name="ibm-granite/granite-speech-3.3-8b", # two options for Granite Speech 3.3: 2b and 8b + revision="granite-speech-3.3.2-2b", + max_new_tokens=120, # 200 for 2b, 120 for 8b + ) +if USE_RITS: + # using the RITS remote service for inferencing + model = CrossProviderInferenceEngine( + model="granite-speech-3-3-8b", # in RITS only the 8b version of Granite Speech is available + provider="rits", + # provider_specific_args={"rits": {"max_new_tokens": 120}}, + max_new_tokens=120, + ) predictions = model(test_dataset) -results = evaluate(predictions=predictions, data=test_dataset) +results = evaluate( + predictions=predictions, data=test_dataset, calc_confidence_intervals=False +) print("Global scores:") print(results.global_scores.summary) diff --git a/examples/evaluate_speech_recognition_benchmark.py b/examples/evaluate_speech_recognition_benchmark.py index 24c6d6d3f8..bfd18776db 100644 --- a/examples/evaluate_speech_recognition_benchmark.py +++ b/examples/evaluate_speech_recognition_benchmark.py @@ -1,3 +1,10 @@ +# this python script shows an example of running speech recognition benchmark evaluation for Granite Speech +# using the Hugging Face ESB datasets (English) and the multilingial CommonVoice datasets + +# to run on a single test set use subset=... below; the list of subsets is: +# voxpopuli, ami, librispeech, spgispeech, tedlium, earnings22, +# commonvoice_en, commonvoice_de, commonvoice_es, commonvoice_fr, commonvoice_pt + import os from unitxt import evaluate, load_dataset @@ -8,10 +15,10 @@ dataset = load_dataset( "benchmarks.speech_recognition", - max_samples_per_subset=5, - # subset="ami", + max_samples_per_subset=5, # while this is commented out, the entire test set is used + # subset="ami", #to tun only a single dataset system_prompt=TextualSystemPrompt( - text="Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant" + text="Knowledge Cutoff Date: April 2024.\nToday's Date: April 9, 2025.\nYou are Granite, developed by IBM. You are a helpful AI assistant" ), split="test", ) @@ -21,9 +28,9 @@ exit() model = HFGraniteSpeechInferenceEngine( - model_name="ibm-granite/granite-speech-3.3-2b", + model_name="ibm-granite/granite-speech-3.3-2b", # two options for Granite Speech 3.3: 2b and 8b revision="granite-speech-3.3.2-2b", - max_new_tokens=200, + max_new_tokens=200, # 200 for 2b, 120 for 8b ) predictions = model(dataset) diff --git a/examples/evaluate_speech_translation_benchmark.py b/examples/evaluate_speech_translation_benchmark.py new file mode 100644 index 0000000000..373ad62f67 --- /dev/null +++ b/examples/evaluate_speech_translation_benchmark.py @@ -0,0 +1,38 @@ +# this python script shows an example of running speech translation benchmark evaluation for Granite Speech +# using the Fleurs and Covost2 datasets + +# to run on a single test set use subset=... 
below; the list of subsets is: +# fleurs_en_de, fleurs_en_es, fleurs_en_fr, fleurs_en_it, fleurs_en_ja, fleurs_en_pt, fleurs_en_pt, +# covost2_en_de, covost2_en_ja, covost2_de_en, covost2_es_en, covost2_fr_en, covost2_pt_en + +from unitxt import evaluate, load_dataset +from unitxt.inference import ( + HFGraniteSpeechInferenceEngine, +) +from unitxt.system_prompts import TextualSystemPrompt + +dataset = load_dataset( + "benchmarks.speech_translation", + # max_samples_per_subset=100, # while this is commented out, the entire test set is used + # subset="fleurs_en_fr", #to run only a single test set + system_prompt=TextualSystemPrompt( + text="Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant" + ), + split="test", +) + +model = HFGraniteSpeechInferenceEngine( + model_name="ibm-granite/granite-speech-3.3-8b", # two options for Granite Speech 3.3: 2b and 8b + max_new_tokens=120, # 200 for 2b, 120 for 8b +) + +predictions = model(dataset) +results = evaluate( + predictions=predictions, data=dataset, calc_confidence_intervals=False +) + +print("Global scores:") +print(results.global_scores.summary) + +print("Subsets scores:") +print(results.subsets_scores.summary) diff --git a/examples/evaluate_speech_translation_covost2.py b/examples/evaluate_speech_translation_covost2.py new file mode 100644 index 0000000000..15016fc788 --- /dev/null +++ b/examples/evaluate_speech_translation_covost2.py @@ -0,0 +1,58 @@ +# this python script shows an example of running speech translation evaluation for Granite Speech + +from unitxt import evaluate, load_dataset +from unitxt.inference import ( + HFGraniteSpeechInferenceEngine, +) +from unitxt.system_prompts import TextualSystemPrompt + +debug = True # True for extra printing, set to False when commenting out max_test_instances below +max_test_instances = 8 + +# the available calanguages for the covost2 dataset dataset, are: +# translation from English to target language: +# de German +# ja Japanese +# translation from source language to English: +# de German +# es Spanish +# fr French +# pt Portuguese +test_dataset = load_dataset( # select (un-comment) one of the test sets below + card="cards.covost2.from_en.en_de", + # card="cards.covost2.from_en.en_ja", + # card="cards.covost2.to_en.de_en", + # card="cards.covost2.to_en.es_en", + # card="cards.covost2.to_en.fr_en", + # card="cards.covost2.to_en.pt_en", + split="test", + format="formats.chat_api", + max_test_instances=max_test_instances, # comment out for running the entire test + system_prompt=TextualSystemPrompt( + text="Knowledge Cutoff Date: April 2024.\nToday's Date: April 9, 2025.\nYou are Granite, developed by IBM. 
You are a helpful AI assistant" + ), +) + +if debug: + print(">>>>>>>>>>>>>> first test references >>>>>>>>>>>>") + for idx in range(max_test_instances): + print(f">>>>>> references {idx}: ", test_dataset["references"][idx]) + +model = HFGraniteSpeechInferenceEngine( + model_name="ibm-granite/granite-speech-3.3-8b", # two options for Granite Speech 3.3: 2b and 8b + max_new_tokens=120, # 200 for 2b, 120 for 8b +) + +predictions = model(test_dataset) + +if debug: # print translation reference texts for debug and inspection + print(">>>>>>>>>>>>>> first predictions >>>>>>>>>>>>") + for idx in range(max_test_instances): + print(f">>>>>>>>>>> {idx}: ", predictions[idx]) + +results = evaluate( + predictions=predictions, data=test_dataset, calc_confidence_intervals=False +) + +print("Global scores:") +print(results.global_scores.summary) diff --git a/examples/evaluate_speech_translation_fleurs.py b/examples/evaluate_speech_translation_fleurs.py new file mode 100644 index 0000000000..b278c552d3 --- /dev/null +++ b/examples/evaluate_speech_translation_fleurs.py @@ -0,0 +1,58 @@ +# this python script shows an example of running speech translation evaluation for Granite Speech + +from unitxt import evaluate, load_dataset +from unitxt.inference import ( + HFGraniteSpeechInferenceEngine, +) +from unitxt.system_prompts import TextualSystemPrompt + +debug = False # True for extra printing, set to False when commenting out max_test_instances below +max_test_instances = 20 + +# the available cards for the fleurs dataset, reflecting the target language, are: +# de_de German +# es_419 Spanish, South America +# fr_fr French +# it_it Italian +# ja_jp Japanese +# pt_br Portuguese, Brazil +# cmn_hans_cn Chinese, Mandarin +test_dataset = load_dataset( # select (un-comment) one of the test sets below + # card="cards.fleurs.en_us.de_de", + # card="cards.fleurs.en_us.es_419", + # card="cards.fleurs.en_us.fr_fr", + # card="cards.fleurs.en_us.it_it", + # card="cards.fleurs.en_us.pt_br", + card="cards.fleurs.en_us.ja_jp", + # card="cards.fleurs.en_us.cmn_hans_cn", + split="test", + format="formats.chat_api", + # max_test_instances=max_test_instances, # comment out for running the entire test + system_prompt=TextualSystemPrompt( + text="Knowledge Cutoff Date: April 2024.\nToday's Date: April 9, 2025.\nYou are Granite, developed by IBM. 
You are a helpful AI assistant" + ), +) + +if debug: + print(">>>>>>>>>>>>>> test references >>>>>>>>>>>>") + for idx in range(max_test_instances): + print(f">>>>>> references {idx}: ", test_dataset["references"][idx]) + +model = HFGraniteSpeechInferenceEngine( + model_name="ibm-granite/granite-speech-3.3-8b", # two options for Granite Speech 3.3: 2b and 8b + max_new_tokens=120, # 200 for 2b, 120 for 8b +) + +predictions = model(test_dataset) + +if debug: # print translation reference texts for debug and inspection + print(">>>>>>>>>>>>>> model predictions >>>>>>>>>>>>") + for idx in range(max_test_instances): + print(f">>>>>>>>>>> {idx}: ", predictions[idx]) + +results = evaluate( + predictions=predictions, data=test_dataset, calc_confidence_intervals=False +) + +print("Global scores:") +print(results.global_scores.summary) diff --git a/prepare/benchmarks/speech_recognition.py b/prepare/benchmarks/speech_recognition.py index 404f37ee20..3ca5afd270 100644 --- a/prepare/benchmarks/speech_recognition.py +++ b/prepare/benchmarks/speech_recognition.py @@ -12,10 +12,6 @@ card="cards.esb.ami", format="formats.chat_api", ), - # "gigaspeech": DatasetRecipe( - # card="cards.esb.gigaspeech", - # format="formats.chat_api", - # ), "librispeech": DatasetRecipe( card="cards.esb.librispeech", format="formats.chat_api", @@ -32,6 +28,26 @@ card="cards.esb.earnings22", format="formats.chat_api", ), + "commonvoice_en": DatasetRecipe( + card="cards.commonvoice.en", + format="formats.chat_api", + ), + "commonvoice_de": DatasetRecipe( + card="cards.commonvoice.de", + format="formats.chat_api", + ), + "commonvoice_es": DatasetRecipe( + card="cards.commonvoice.es", + format="formats.chat_api", + ), + "commonvoice_fr": DatasetRecipe( + card="cards.commonvoice.fr", + format="formats.chat_api", + ), + "commonvoice_pt": DatasetRecipe( + card="cards.commonvoice.pt", + format="formats.chat_api", + ), }, ) diff --git a/prepare/benchmarks/speech_translation.py b/prepare/benchmarks/speech_translation.py index c07d8d8606..3fd68e47c8 100644 --- a/prepare/benchmarks/speech_translation.py +++ b/prepare/benchmarks/speech_translation.py @@ -2,32 +2,63 @@ from unitxt.catalog import add_to_catalog from unitxt.standard import DatasetRecipe +# running benchmarks with fleurs dataset, en-->xx +# running benchmarks with covost2 dataset, en-->xx +# running benchmarks with covost2 dataset, xx-->en benchmark = Benchmark( subsets={ - "en_de": DatasetRecipe( + "fleurs_en_de": DatasetRecipe( card="cards.fleurs.en_us.de_de", format="formats.chat_api", ), - "en_es": DatasetRecipe( + "fleurs_en_es": DatasetRecipe( card="cards.fleurs.en_us.es_419", format="formats.chat_api", ), - "en_fr": DatasetRecipe( + "fleurs_en_fr": DatasetRecipe( card="cards.fleurs.en_us.fr_fr", format="formats.chat_api", ), - "en_it": DatasetRecipe( + "fleurs_en_it": DatasetRecipe( card="cards.fleurs.en_us.it_it", format="formats.chat_api", ), - "en_ja": DatasetRecipe( + "fleurs_en_ja": DatasetRecipe( card="cards.fleurs.en_us.ja_jp", format="formats.chat_api", ), - "en_pt": DatasetRecipe( + "fleurs_en_pt": DatasetRecipe( card="cards.fleurs.en_us.pt_br", format="formats.chat_api", ), + "fleurs_en_zh": DatasetRecipe( + card="cards.fleurs.en_us.cmn_hans_cn", + format="formats.chat_api", + ), + "covost2_en_de": DatasetRecipe( + card="cards.covost2.from_en.en_de", + format="formats.chat_api", + ), + "covost2_en_ja": DatasetRecipe( + card="cards.covost2.from_en.en_ja", + format="formats.chat_api", + ), + "covost2_de_en": DatasetRecipe( + card="cards.covost2.to_en.de_en", + 
format="formats.chat_api", + ), + "covost2_es_en": DatasetRecipe( + card="cards.covost2.to_en.es_en", + format="formats.chat_api", + ), + "covost2_fr_en": DatasetRecipe( + card="cards.covost2.to_en.fr_en", + format="formats.chat_api", + ), + "covost2_pt_en": DatasetRecipe( + card="cards.covost2.to_en.pt_en", + format="formats.chat_api", + ), }, ) diff --git a/prepare/cards/commonvoice.py b/prepare/cards/commonvoice.py new file mode 100644 index 0000000000..6a0b828f5b --- /dev/null +++ b/prepare/cards/commonvoice.py @@ -0,0 +1,50 @@ +# This Python script is used to prepare cards for the CommonVoice ver. 17 dataset, used for evaluating multilingual speech recognition + +from unitxt.audio_operators import ToAudio +from unitxt.card import TaskCard +from unitxt.catalog import add_to_catalog +from unitxt.loaders import LoadHF +from unitxt.operators import Rename +from unitxt.string_operators import StripQuotation +from unitxt.test_utils.card import test_card + +subsets = ["en", "fr", "de", "es", "pt"] # languages to use +templates_ = { + "en": "templates.speech_recognition.default", + "fr": "templates.speech_recognition.multilingual", + "de": "templates.speech_recognition.multilingual", + "es": "templates.speech_recognition.multilingual", + "pt": "templates.speech_recognition.multilingual", +} +tasks_ = { + "en": "tasks.speech_recognition", + "fr": "tasks.speech_recognition_multilingual", + "de": "tasks.speech_recognition_multilingual", + "es": "tasks.speech_recognition_multilingual", + "pt": "tasks.speech_recognition_multilingual", +} + +first = True +for subset in subsets: + card = TaskCard( + loader=LoadHF( + path="mozilla-foundation/common_voice_17_0", + name=subset, + splits=["test"], + data_classification_policy=["public"], + streaming=True, + ), + preprocess_steps=[ + ToAudio(field="audio"), + Rename(field="sentence", to_field="text"), + StripQuotation(field="text"), + ], + task=tasks_[subset], + templates=[templates_[subset]], + __title__="CommonVoice-17-" + subset, + ) + + if first: + test_card(card, strict=False) + first = False + add_to_catalog(card, f"cards.commonvoice.{subset}", overwrite=True) diff --git a/prepare/cards/covost2.py b/prepare/cards/covost2.py new file mode 100644 index 0000000000..3bbf32b622 --- /dev/null +++ b/prepare/cards/covost2.py @@ -0,0 +1,142 @@ +# This Python script is used to prepare card for the covost2 dataset, used for evaluating speech translation + +from unitxt.audio_operators import ToAudio +from unitxt.blocks import LoadHF, TaskCard +from unitxt.catalog import add_to_catalog +from unitxt.operators import Set +from unitxt.test_utils.card import test_card + +# The covost2 dataset supports translation from 21 languages into English, and from English to 15 languages +# As opposed to the fluers dataset, there is no support from any to any + +# Entire list of supported language pairs (codes) for reference: +# all_subsets_from_en = [ +# "de", +# "tr", +# "fa", +# "sv-SE", +# "mn", +# "zh-CN", +# "cy", +# "ca", +# "sl", +# "et", +# "id", +# "ar", +# "ta", +# "lv", +# "ja" +# ] + +# all_subsets_to_en = [ +# "fr", +# "de", +# "es", +# "ca", +# "it", +# "ru", +# "zh-CN", +# "pt", +# "fa", +# "et", +# "mn", +# "nl", +# "tr", +# "ar", +# "sv-SE", +# "lv", +# "sl", +# "ta", +# "ja", +# "id", +# "cy" +# ] + +# Currently we only use covost2 for evaluating translating to and from english for a limited list as below; additions may follow +# We use the (basic) Whisper text normalization; before extending to other languages, check if Whisper basic normalizer supports 
them + +subsets_from_en = ["de", "ja"] + +subsets_to_en = ["fr", "de", "es", "pt"] + +lang_name = { + "de": "German", + "ja": "Japanese", + "es": "Spanish", + "fr": "French", + "pt": "Portuguese", + "en": "English", +} + +# An example of how to load the covost2 dataset using Hugging Face's loader: +# dataset = datasets.load_dataset('facebook/covost2', 'en_de', data_dir='/dccstor/aharonsatt/data/covost2/en', split='test') +# +# For each language pair, the source dataset (the one that provides the audio) needs to be available locally +# We are temporarily using the data in the CCC, at the path /dataset/speechdata/CoVoST2/xx where xx is the source language code + +task_types = ["from_en", "to_en"] +local_data_path = "/dataset/speechdata/CoVoST2/" # the CommonVoice ver. 4 datasets are stored locally in the CCC; this local store is required for using HF CoVost 2 dataset + +first = True + +# from English +for lang in subsets_from_en: + card = TaskCard( + loader=LoadHF( + path="facebook/covost2", + name="en_" + lang, + data_dir=local_data_path + "en", + split="test", + streaming=True, + ), + preprocess_steps=[ + ToAudio(field="audio"), + Set( + fields={ + "source_language": lang_name["en"], + "target_language": lang_name[lang], + } + ), + ], + task="tasks.translation.speech", + templates=[ + "templates.translation.speech.default", + ], + ) + + if first: + test_card(card, debug=True) + first = False + + add_to_catalog(card, f"cards.covost2.from_en.en_{lang}", overwrite=True) + +# to English +for lang in subsets_to_en: + card = TaskCard( + loader=LoadHF( + path="facebook/covost2", + name=lang + "_en", + data_dir=local_data_path + lang, + split="test", + streaming=True, + ), + preprocess_steps=[ + ToAudio(field="audio"), + Set( + fields={ + "source_language": lang_name[lang], + "target_language": lang_name["en"], + } + ), + ], + task="tasks.translation.speech", + templates=[ + "templates.translation.speech.default", + ], + ) + + if first: + test_card(card, demos_taken_from="test", num_demos=0) + first = False + + add_to_catalog(card, f"cards.covost2.to_en.{lang}_en", overwrite=True) diff --git a/prepare/cards/fleures.py b/prepare/cards/fleurs.py similarity index 69% rename from prepare/cards/fleures.py rename to prepare/cards/fleurs.py index 82271ee439..f30187107b 100644 --- a/prepare/cards/fleures.py +++ b/prepare/cards/fleurs.py @@ -1,13 +1,16 @@ +# This Python script is used to prepare cards for the fleurs dataset, used for evaluating speech translation + from unitxt.audio_operators import ToAudio from unitxt.blocks import LoadHF, TaskCard from unitxt.catalog import add_to_catalog from unitxt.loaders import MultipleSourceLoader from unitxt.operator import SourceSequentialOperator -from unitxt.operators import RemoveFields, Rename +from unitxt.operators import AddFields, RemoveFields, Rename from unitxt.splitters import RenameSplits -from unitxt.stream_operators import JoinStreams +from unitxt.stream_operators import JoinStreamsFleurs from unitxt.test_utils.card import test_card +# Entire list of languages (codes) supported by fleurs, for reference: # all_subsets = [ # "af_za", # "am_et", @@ -113,6 +116,11 @@ # "zu_za", # ] +# Currently we only use fleurs for evaluating translating from english_us to the following 6 languages; additions may follow +# The fleurs dataset can be used from any language to any language within its supported list of languages +# We use the (basic) Whisper text normalization; before extending to other languages, check if Whisper basic normalizer supports them + +# "to" 
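For orientation, a sketch of the instance these covost2 cards hand to tasks.translation.speech after the ToAudio and Set steps above. The field names are inferred from this patch and the concrete values are made up; this is not a definitive schema.

# Illustrative instance after covost2 preprocessing (de_en card, made-up values):
instance = {
    "audio": {"audio": {"array": [0.0, 0.01, -0.02], "sampling_rate": 16000}},  # wrapped by ToAudio
    "source_language": "German",   # injected by Set
    "target_language": "English",  # interpolated into "{audio}translate the speech to {target_language}"
    "translation": "The weather is nice today.",  # reference scored by metrics.normalized_sacrebleu
}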
language: target_subsets = [ "de_de", "es_419", @@ -120,12 +128,36 @@ "it_it", "ja_jp", "pt_br", + "cmn_hans_cn", ] +# "from" language: source_subsets = [ "en_us", ] +lang_code = { + "en_us": "en", + "de_de": "de", + "es_419": "es", + "fr_fr": "fr", + "it_it": "it", + "ja_jp": "ja", + "pt_br": "pt", + "cmn_hans_cn": "zh", +} + +lang_name = { + "en_us": "English", + "de_de": "German", + "es_419": "Spanish", + "fr_fr": "French", + "it_it": "Italian", + "ja_jp": "Japanese", + "pt_br": "Portuguese", + "cmn_hans_cn": "Chinese", +} + first = True for source_subset in source_subsets: for target_subset in target_subsets: @@ -139,6 +171,7 @@ revision="refs/convert/parquet", data_dir=source_subset, splits=["test"], + data_classification_policy=["public"], ), RemoveFields( fields=[ @@ -166,6 +199,7 @@ revision="refs/convert/parquet", data_dir=target_subset, splits=["test"], + data_classification_policy=["public"], ), RemoveFields( fields=[ @@ -175,6 +209,7 @@ "raw_transcription", "gender", "lang_id", + "language", "lang_group_id", ] ), @@ -188,7 +223,14 @@ ] ), preprocess_steps=[ - JoinStreams( + # The following JointStream operator is unique for fleurs + # It removes redundant repetitions in the target ("right") stream - repetitions of the same translation text targets + # The repetitions are inherent to the structure of the fleurs dataset + # On the other hand, we do preserve the repeated recornings of the input speech (1-3 recordings per target instance), + # as they contain recordings by different speakers and different recording conditions, therefore considered as + # separate instannces + # The unique JointStream operator also adds '.' at the end of text instances, to improve the behavior of the metric (sacrebleu) + JoinStreamsFleurs( left_stream="test_input", right_stream="test_target", how="inner", @@ -197,7 +239,12 @@ ), ToAudio(field="audio"), Rename(field="transcription", to_field="translation"), - Rename(field="language", to_field="target_language"), + AddFields( + { + "source_language": lang_name[source_subset], + "target_language": lang_name[target_subset], + } + ), ], task="tasks.translation.speech", templates=[ diff --git a/prepare/metrics/normalized_sacrebleu.py b/prepare/metrics/normalized_sacrebleu.py index b45fd3a6d4..a5243d6486 100644 --- a/prepare/metrics/normalized_sacrebleu.py +++ b/prepare/metrics/normalized_sacrebleu.py @@ -3,6 +3,8 @@ from unitxt.test_utils.metrics import test_metric language_to_tokenizer = { + "italian": None, + "it": None, "german": None, "deutch": None, "de": None, @@ -22,6 +24,8 @@ "ko": "ko-mecab", "japanese": "ja-mecab", "ja": "ja-mecab", + "chinese": "zh", + "zh": "zh", } metric = NormalizedSacrebleu(language_to_tokenizer=language_to_tokenizer) diff --git a/prepare/processors/processors.py b/prepare/processors/processors.py index fa0e47d018..45e1fcb1ba 100644 --- a/prepare/processors/processors.py +++ b/prepare/processors/processors.py @@ -21,6 +21,7 @@ Lower, LowerCaseTillPunc, MatchClosestOption, + NormalizeTextBasicWithWhisper, NormalizeTextWithWhisper, PostProcess, RegexParser, @@ -309,3 +310,9 @@ def add_processor_and_operator_to_catalog( operator=NormalizeTextWithWhisper(), overwrite=True, ) + +add_processor_and_operator_to_catalog( + artifact_name="normalize_text_basic_with_whisper", + operator=NormalizeTextBasicWithWhisper(), + overwrite=True, +) diff --git a/prepare/tasks/speech_recognition.py b/prepare/tasks/speech_recognition.py index 9187687f7b..1ba0d3c8a7 100644 --- a/prepare/tasks/speech_recognition.py +++ 
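The comment above describes what makes the fleurs join special; as a conceptual sketch only (pandas-based, assuming the real JoinStreamsFleurs works on unitxt streams rather than DataFrames and only appends the period when it is missing), the effect is roughly:

import pandas as pd

# English recordings: several rows may share one sentence id (different speakers/conditions).
inputs = pd.DataFrame({"id": [7, 7, 9], "audio": ["rec_a", "rec_b", "rec_c"]})
# Target-language side: the raw split can repeat the same translation text for an id.
targets = pd.DataFrame({"id": [7, 7, 9], "transcription": ["Hallo Welt", "Hallo Welt", "Danke"]})

# Drop the duplicated targets, keep every recording, then inner-join on id.
targets = targets.drop_duplicates(subset="id")
joined = inputs.merge(targets, on="id", how="inner")

# Append a final period so the sacrebleu-based metric behaves better, as the comment explains.
joined["transcription"] = joined["transcription"].str.rstrip() + "."
print(joined)  # rec_a and rec_b map to "Hallo Welt.", rec_c maps to "Danke."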
b/prepare/tasks/speech_recognition.py @@ -20,3 +20,21 @@ "tasks.speech_recognition", overwrite=True, ) + +add_to_catalog( + Task( + input_fields={ + "audio": Audio, + }, + reference_fields={"text": str}, + prediction_type=str, + metrics=["metrics.wer"], + default_template=InputOutputTemplate( + input_format="{audio}can you transcribe the speech into a written format?", + output_format="{text}", + postprocessors=["processors.normalize_text_basic_with_whisper"], + ), + ), + "tasks.speech_recognition_multilingual", + overwrite=True, +) diff --git a/prepare/tasks/translation/speech.py b/prepare/tasks/translation/speech.py index 81e298f569..aa803f73f7 100644 --- a/prepare/tasks/translation/speech.py +++ b/prepare/tasks/translation/speech.py @@ -1,5 +1,6 @@ from unitxt.blocks import Task from unitxt.catalog import add_to_catalog +from unitxt.templates import InputOutputTemplate from unitxt.types import Audio add_to_catalog( @@ -11,6 +12,12 @@ reference_fields={"translation": str}, prediction_type=str, metrics=["metrics.normalized_sacrebleu"], + default_template=InputOutputTemplate( + # input_format="{audio}listen to the speech and translate it to {target_language}", + input_format="{audio}translate the speech to {target_language}", + output_format="{translation}", + postprocessors=["processors.normalize_text_basic_with_whisper"], + ), ), "tasks.translation.speech", overwrite=True, diff --git a/prepare/templates/speech_recognition/templates.py b/prepare/templates/speech_recognition/templates.py index 54df88edc3..64f1aebc5c 100644 --- a/prepare/templates/speech_recognition/templates.py +++ b/prepare/templates/speech_recognition/templates.py @@ -10,3 +10,13 @@ "templates.speech_recognition.default", overwrite=True, ) + +add_to_catalog( + InputOutputTemplate( + input_format="{audio}can you transcribe the speech into a written format?", + output_format="{text}", + postprocessors=["processors.normalize_text_basic_with_whisper"], + ), + "templates.speech_recognition.multilingual", + overwrite=True, +) diff --git a/prepare/templates/translation/speech.py b/prepare/templates/translation/speech.py index a4d5cdb7ce..508de51c4a 100644 --- a/prepare/templates/translation/speech.py +++ b/prepare/templates/translation/speech.py @@ -3,9 +3,22 @@ add_to_catalog( InputOutputTemplate( - input_format="{audio}listen to the speech and translate it to {target_language}", + # input_format="{audio}listen to the speech and translate it to {target_language}", + input_format="{audio}translate the speech to {target_language}", output_format="{translation}", + postprocessors=["processors.normalize_text_basic_with_whisper"], ), "templates.translation.speech.default", overwrite=True, ) + +add_to_catalog( + InputOutputTemplate( + # input_format="{audio}listen to the speech and translate it to {target_language}", + input_format="{audio}translate the speech to {target_language}", + output_format="{translation}", + # postprocessors=["processors.normalize_text_basic_with_whisper"], + ), + "templates.translation.speech.no_norm", + overwrite=True, +) diff --git a/src/unitxt/catalog/benchmarks/speech_recognition.json b/src/unitxt/catalog/benchmarks/speech_recognition.json index 847235b566..5f44fb76b3 100644 --- a/src/unitxt/catalog/benchmarks/speech_recognition.json +++ b/src/unitxt/catalog/benchmarks/speech_recognition.json @@ -30,6 +30,31 @@ "__type__": "dataset_recipe", "card": "cards.esb.earnings22", "format": "formats.chat_api" + }, + "commonvoice_en": { + "__type__": "dataset_recipe", + "card": "cards.commonvoice.en", + 
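Since both the multilingual speech-recognition task above and the translation templates route outputs through processors.normalize_text_basic_with_whisper, here is a conceptual sketch of what that normalization buys before scoring. It assumes the openai-whisper and jiwer packages; the actual NormalizeTextBasicWithWhisper operator may wrap a different backend, so treat the import paths as assumptions.

from jiwer import wer
from whisper.normalizers import BasicTextNormalizer

normalize = BasicTextNormalizer()  # language-agnostic: lowercases and strips punctuation and symbols

reference = "C'est la vie, n'est-ce pas ?"
prediction = "c est la vie n est ce pas"

# Both references and predictions are normalized, so purely orthographic differences
# (case, punctuation, apostrophes) no longer inflate the word error rate.
print(wer(normalize(reference), normalize(prediction)))  # expected to be 0.0 here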
"format": "formats.chat_api" + }, + "commonvoice_de": { + "__type__": "dataset_recipe", + "card": "cards.commonvoice.de", + "format": "formats.chat_api" + }, + "commonvoice_es": { + "__type__": "dataset_recipe", + "card": "cards.commonvoice.es", + "format": "formats.chat_api" + }, + "commonvoice_fr": { + "__type__": "dataset_recipe", + "card": "cards.commonvoice.fr", + "format": "formats.chat_api" + }, + "commonvoice_pt": { + "__type__": "dataset_recipe", + "card": "cards.commonvoice.pt", + "format": "formats.chat_api" } } } diff --git a/src/unitxt/catalog/benchmarks/speech_translation.json b/src/unitxt/catalog/benchmarks/speech_translation.json index 9b1db3ed4a..922e1a4f49 100644 --- a/src/unitxt/catalog/benchmarks/speech_translation.json +++ b/src/unitxt/catalog/benchmarks/speech_translation.json @@ -1,35 +1,70 @@ { "__type__": "benchmark", "subsets": { - "en_de": { + "fleurs_en_de": { "__type__": "dataset_recipe", "card": "cards.fleurs.en_us.de_de", "format": "formats.chat_api" }, - "en_es": { + "fleurs_en_es": { "__type__": "dataset_recipe", "card": "cards.fleurs.en_us.es_419", "format": "formats.chat_api" }, - "en_fr": { + "fleurs_en_fr": { "__type__": "dataset_recipe", "card": "cards.fleurs.en_us.fr_fr", "format": "formats.chat_api" }, - "en_it": { + "fleurs_en_it": { "__type__": "dataset_recipe", "card": "cards.fleurs.en_us.it_it", "format": "formats.chat_api" }, - "en_ja": { + "fleurs_en_ja": { "__type__": "dataset_recipe", "card": "cards.fleurs.en_us.ja_jp", "format": "formats.chat_api" }, - "en_pt": { + "fleurs_en_pt": { "__type__": "dataset_recipe", "card": "cards.fleurs.en_us.pt_br", "format": "formats.chat_api" + }, + "fleurs_en_zh": { + "__type__": "dataset_recipe", + "card": "cards.fleurs.en_us.cmn_hans_cn", + "format": "formats.chat_api" + }, + "covost2_en_de": { + "__type__": "dataset_recipe", + "card": "cards.covost2.from_en.en_de", + "format": "formats.chat_api" + }, + "covost2_en_ja": { + "__type__": "dataset_recipe", + "card": "cards.covost2.from_en.en_ja", + "format": "formats.chat_api" + }, + "covost2_de_en": { + "__type__": "dataset_recipe", + "card": "cards.covost2.to_en.de_en", + "format": "formats.chat_api" + }, + "covost2_es_en": { + "__type__": "dataset_recipe", + "card": "cards.covost2.to_en.es_en", + "format": "formats.chat_api" + }, + "covost2_fr_en": { + "__type__": "dataset_recipe", + "card": "cards.covost2.to_en.fr_en", + "format": "formats.chat_api" + }, + "covost2_pt_en": { + "__type__": "dataset_recipe", + "card": "cards.covost2.to_en.pt_en", + "format": "formats.chat_api" } } } diff --git a/src/unitxt/catalog/cards/commonvoice/de.json b/src/unitxt/catalog/cards/commonvoice/de.json new file mode 100644 index 0000000000..1ad227341b --- /dev/null +++ b/src/unitxt/catalog/cards/commonvoice/de.json @@ -0,0 +1,35 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_hf", + "path": "mozilla-foundation/common_voice_17_0", + "name": "de", + "splits": [ + "test" + ], + "data_classification_policy": [ + "public" + ], + "streaming": true + }, + "preprocess_steps": [ + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "rename", + "field": "sentence", + "to_field": "text" + }, + { + "__type__": "strip_quotation", + "field": "text" + } + ], + "task": "tasks.speech_recognition_multilingual", + "templates": [ + "templates.speech_recognition.multilingual" + ], + "__title__": "CommonVoice-17-de" +} diff --git a/src/unitxt/catalog/cards/commonvoice/en.json b/src/unitxt/catalog/cards/commonvoice/en.json new file mode 100644 index 
0000000000..6b3471ab4d --- /dev/null +++ b/src/unitxt/catalog/cards/commonvoice/en.json @@ -0,0 +1,35 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_hf", + "path": "mozilla-foundation/common_voice_17_0", + "name": "en", + "splits": [ + "test" + ], + "data_classification_policy": [ + "public" + ], + "streaming": true + }, + "preprocess_steps": [ + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "rename", + "field": "sentence", + "to_field": "text" + }, + { + "__type__": "strip_quotation", + "field": "text" + } + ], + "task": "tasks.speech_recognition", + "templates": [ + "templates.speech_recognition.default" + ], + "__title__": "CommonVoice-17-en" +} diff --git a/src/unitxt/catalog/cards/commonvoice/es.json b/src/unitxt/catalog/cards/commonvoice/es.json new file mode 100644 index 0000000000..94134bd543 --- /dev/null +++ b/src/unitxt/catalog/cards/commonvoice/es.json @@ -0,0 +1,35 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_hf", + "path": "mozilla-foundation/common_voice_17_0", + "name": "es", + "splits": [ + "test" + ], + "data_classification_policy": [ + "public" + ], + "streaming": true + }, + "preprocess_steps": [ + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "rename", + "field": "sentence", + "to_field": "text" + }, + { + "__type__": "strip_quotation", + "field": "text" + } + ], + "task": "tasks.speech_recognition_multilingual", + "templates": [ + "templates.speech_recognition.multilingual" + ], + "__title__": "CommonVoice-17-es" +} diff --git a/src/unitxt/catalog/cards/commonvoice/fr.json b/src/unitxt/catalog/cards/commonvoice/fr.json new file mode 100644 index 0000000000..fa8d91566c --- /dev/null +++ b/src/unitxt/catalog/cards/commonvoice/fr.json @@ -0,0 +1,35 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_hf", + "path": "mozilla-foundation/common_voice_17_0", + "name": "fr", + "splits": [ + "test" + ], + "data_classification_policy": [ + "public" + ], + "streaming": true + }, + "preprocess_steps": [ + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "rename", + "field": "sentence", + "to_field": "text" + }, + { + "__type__": "strip_quotation", + "field": "text" + } + ], + "task": "tasks.speech_recognition_multilingual", + "templates": [ + "templates.speech_recognition.multilingual" + ], + "__title__": "CommonVoice-17-fr" +} diff --git a/src/unitxt/catalog/cards/commonvoice/pt.json b/src/unitxt/catalog/cards/commonvoice/pt.json new file mode 100644 index 0000000000..efd5bb9d96 --- /dev/null +++ b/src/unitxt/catalog/cards/commonvoice/pt.json @@ -0,0 +1,35 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_hf", + "path": "mozilla-foundation/common_voice_17_0", + "name": "pt", + "splits": [ + "test" + ], + "data_classification_policy": [ + "public" + ], + "streaming": true + }, + "preprocess_steps": [ + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "rename", + "field": "sentence", + "to_field": "text" + }, + { + "__type__": "strip_quotation", + "field": "text" + } + ], + "task": "tasks.speech_recognition_multilingual", + "templates": [ + "templates.speech_recognition.multilingual" + ], + "__title__": "CommonVoice-17-pt" +} diff --git a/src/unitxt/catalog/cards/covost2/from_en/en_de.json b/src/unitxt/catalog/cards/covost2/from_en/en_de.json new file mode 100644 index 0000000000..b4c6601b8f --- /dev/null +++ b/src/unitxt/catalog/cards/covost2/from_en/en_de.json @@ -0,0 +1,28 @@ +{ + "__type__": "task_card", + "loader": 
{ + "__type__": "load_hf", + "path": "facebook/covost2", + "name": "en_de", + "data_dir": "/dataset/speechdata/CoVoST2/en", + "split": "test", + "streaming": true + }, + "preprocess_steps": [ + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "set", + "fields": { + "source_language": "English", + "target_language": "German" + } + } + ], + "task": "tasks.translation.speech", + "templates": [ + "templates.translation.speech.default" + ] +} diff --git a/src/unitxt/catalog/cards/covost2/from_en/en_ja.json b/src/unitxt/catalog/cards/covost2/from_en/en_ja.json new file mode 100644 index 0000000000..8b264f5a27 --- /dev/null +++ b/src/unitxt/catalog/cards/covost2/from_en/en_ja.json @@ -0,0 +1,28 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_hf", + "path": "facebook/covost2", + "name": "en_ja", + "data_dir": "/dataset/speechdata/CoVoST2/en", + "split": "test", + "streaming": true + }, + "preprocess_steps": [ + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "set", + "fields": { + "source_language": "English", + "target_language": "Japanese" + } + } + ], + "task": "tasks.translation.speech", + "templates": [ + "templates.translation.speech.default" + ] +} diff --git a/src/unitxt/catalog/cards/covost2/to_en/de_en.json b/src/unitxt/catalog/cards/covost2/to_en/de_en.json new file mode 100644 index 0000000000..ee4f72634b --- /dev/null +++ b/src/unitxt/catalog/cards/covost2/to_en/de_en.json @@ -0,0 +1,28 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_hf", + "path": "facebook/covost2", + "name": "de_en", + "data_dir": "/dataset/speechdata/CoVoST2/de", + "split": "test", + "streaming": true + }, + "preprocess_steps": [ + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "set", + "fields": { + "source_language": "German", + "target_language": "English" + } + } + ], + "task": "tasks.translation.speech", + "templates": [ + "templates.translation.speech.default" + ] +} diff --git a/src/unitxt/catalog/cards/covost2/to_en/es_en.json b/src/unitxt/catalog/cards/covost2/to_en/es_en.json new file mode 100644 index 0000000000..9db2555e3e --- /dev/null +++ b/src/unitxt/catalog/cards/covost2/to_en/es_en.json @@ -0,0 +1,28 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_hf", + "path": "facebook/covost2", + "name": "es_en", + "data_dir": "/dataset/speechdata/CoVoST2/es", + "split": "test", + "streaming": true + }, + "preprocess_steps": [ + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "set", + "fields": { + "source_language": "Spanish", + "target_language": "English" + } + } + ], + "task": "tasks.translation.speech", + "templates": [ + "templates.translation.speech.default" + ] +} diff --git a/src/unitxt/catalog/cards/covost2/to_en/fr_en.json b/src/unitxt/catalog/cards/covost2/to_en/fr_en.json new file mode 100644 index 0000000000..cbc65de860 --- /dev/null +++ b/src/unitxt/catalog/cards/covost2/to_en/fr_en.json @@ -0,0 +1,28 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_hf", + "path": "facebook/covost2", + "name": "fr_en", + "data_dir": "/dataset/speechdata/CoVoST2/fr", + "split": "test", + "streaming": true + }, + "preprocess_steps": [ + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "set", + "fields": { + "source_language": "French", + "target_language": "English" + } + } + ], + "task": "tasks.translation.speech", + "templates": [ + "templates.translation.speech.default" + ] +} diff --git 
a/src/unitxt/catalog/cards/covost2/to_en/pt_en.json b/src/unitxt/catalog/cards/covost2/to_en/pt_en.json new file mode 100644 index 0000000000..14a9af54d6 --- /dev/null +++ b/src/unitxt/catalog/cards/covost2/to_en/pt_en.json @@ -0,0 +1,28 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_hf", + "path": "facebook/covost2", + "name": "pt_en", + "data_dir": "/dataset/speechdata/CoVoST2/pt", + "split": "test", + "streaming": true + }, + "preprocess_steps": [ + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "set", + "fields": { + "source_language": "Portuguese", + "target_language": "English" + } + } + ], + "task": "tasks.translation.speech", + "templates": [ + "templates.translation.speech.default" + ] +} diff --git a/src/unitxt/catalog/cards/fleurs/en_us/cmn_hans_cn.json b/src/unitxt/catalog/cards/fleurs/en_us/cmn_hans_cn.json new file mode 100644 index 0000000000..4b48bf09b3 --- /dev/null +++ b/src/unitxt/catalog/cards/fleurs/en_us/cmn_hans_cn.json @@ -0,0 +1,112 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "multiple_source_loader", + "sources": [ + { + "__type__": "source_sequential_operator", + "steps": [ + { + "__type__": "load_hf", + "path": "google/fleurs", + "revision": "refs/convert/parquet", + "data_dir": "en_us", + "splits": [ + "test" + ], + "data_classification_policy": [ + "public" + ] + }, + { + "__type__": "remove_fields", + "fields": [ + "num_samples", + "path", + "raw_transcription", + "transcription", + "gender", + "lang_id", + "language", + "lang_group_id" + ] + }, + { + "__type__": "rename_splits", + "mapper": { + "test": "test_input" + } + } + ] + }, + { + "__type__": "source_sequential_operator", + "steps": [ + { + "__type__": "load_hf", + "path": "google/fleurs", + "revision": "refs/convert/parquet", + "data_dir": "cmn_hans_cn", + "splits": [ + "test" + ], + "data_classification_policy": [ + "public" + ] + }, + { + "__type__": "remove_fields", + "fields": [ + "num_samples", + "path", + "audio", + "raw_transcription", + "gender", + "lang_id", + "language", + "lang_group_id" + ] + }, + { + "__type__": "rename_splits", + "mapper": { + "test": "test_target" + } + } + ] + } + ] + }, + "preprocess_steps": [ + { + "__type__": "join_streams_fleurs", + "left_stream": "test_input", + "right_stream": "test_target", + "how": "inner", + "on": [ + "id" + ], + "new_stream_name": "test" + }, + { + "__type__": "to_audio", + "field": "audio" + }, + { + "__type__": "rename", + "field": "transcription", + "to_field": "translation" + }, + { + "__type__": "add_fields", + "fields": { + "source_language": "English", + "target_language": "Chinese" + } + } + ], + "task": "tasks.translation.speech", + "templates": [ + "templates.translation.speech.default" + ] +} diff --git a/src/unitxt/catalog/cards/fleurs/en_us/de_de.json b/src/unitxt/catalog/cards/fleurs/en_us/de_de.json index 8b00286a50..ac82201a9e 100644 --- a/src/unitxt/catalog/cards/fleurs/en_us/de_de.json +++ b/src/unitxt/catalog/cards/fleurs/en_us/de_de.json @@ -13,6 +13,9 @@ "data_dir": "en_us", "splits": [ "test" + ], + "data_classification_policy": [ + "public" ] }, { @@ -46,6 +49,9 @@ "data_dir": "de_de", "splits": [ "test" + ], + "data_classification_policy": [ + "public" ] }, { @@ -57,6 +63,7 @@ "raw_transcription", "gender", "lang_id", + "language", "lang_group_id" ] }, @@ -72,7 +79,7 @@ }, "preprocess_steps": [ { - "__type__": "join_streams", + "__type__": "join_streams_fleurs", "left_stream": "test_input", "right_stream": "test_target", "how": "inner", @@ -91,9 +98,11 @@ 
"to_field": "translation" }, { - "__type__": "rename", - "field": "language", - "to_field": "target_language" + "__type__": "add_fields", + "fields": { + "source_language": "English", + "target_language": "German" + } } ], "task": "tasks.translation.speech", diff --git a/src/unitxt/catalog/cards/fleurs/en_us/es_419.json b/src/unitxt/catalog/cards/fleurs/en_us/es_419.json index 3097fd8421..236f7b68c7 100644 --- a/src/unitxt/catalog/cards/fleurs/en_us/es_419.json +++ b/src/unitxt/catalog/cards/fleurs/en_us/es_419.json @@ -13,6 +13,9 @@ "data_dir": "en_us", "splits": [ "test" + ], + "data_classification_policy": [ + "public" ] }, { @@ -46,6 +49,9 @@ "data_dir": "es_419", "splits": [ "test" + ], + "data_classification_policy": [ + "public" ] }, { @@ -57,6 +63,7 @@ "raw_transcription", "gender", "lang_id", + "language", "lang_group_id" ] }, @@ -72,7 +79,7 @@ }, "preprocess_steps": [ { - "__type__": "join_streams", + "__type__": "join_streams_fleurs", "left_stream": "test_input", "right_stream": "test_target", "how": "inner", @@ -91,9 +98,11 @@ "to_field": "translation" }, { - "__type__": "rename", - "field": "language", - "to_field": "target_language" + "__type__": "add_fields", + "fields": { + "source_language": "English", + "target_language": "Spanish" + } } ], "task": "tasks.translation.speech", diff --git a/src/unitxt/catalog/cards/fleurs/en_us/fr_fr.json b/src/unitxt/catalog/cards/fleurs/en_us/fr_fr.json index 503f099fe0..1df10ca566 100644 --- a/src/unitxt/catalog/cards/fleurs/en_us/fr_fr.json +++ b/src/unitxt/catalog/cards/fleurs/en_us/fr_fr.json @@ -13,6 +13,9 @@ "data_dir": "en_us", "splits": [ "test" + ], + "data_classification_policy": [ + "public" ] }, { @@ -46,6 +49,9 @@ "data_dir": "fr_fr", "splits": [ "test" + ], + "data_classification_policy": [ + "public" ] }, { @@ -57,6 +63,7 @@ "raw_transcription", "gender", "lang_id", + "language", "lang_group_id" ] }, @@ -72,7 +79,7 @@ }, "preprocess_steps": [ { - "__type__": "join_streams", + "__type__": "join_streams_fleurs", "left_stream": "test_input", "right_stream": "test_target", "how": "inner", @@ -91,9 +98,11 @@ "to_field": "translation" }, { - "__type__": "rename", - "field": "language", - "to_field": "target_language" + "__type__": "add_fields", + "fields": { + "source_language": "English", + "target_language": "French" + } } ], "task": "tasks.translation.speech", diff --git a/src/unitxt/catalog/cards/fleurs/en_us/it_it.json b/src/unitxt/catalog/cards/fleurs/en_us/it_it.json index fa1f545a90..9b8e09c535 100644 --- a/src/unitxt/catalog/cards/fleurs/en_us/it_it.json +++ b/src/unitxt/catalog/cards/fleurs/en_us/it_it.json @@ -13,6 +13,9 @@ "data_dir": "en_us", "splits": [ "test" + ], + "data_classification_policy": [ + "public" ] }, { @@ -46,6 +49,9 @@ "data_dir": "it_it", "splits": [ "test" + ], + "data_classification_policy": [ + "public" ] }, { @@ -57,6 +63,7 @@ "raw_transcription", "gender", "lang_id", + "language", "lang_group_id" ] }, @@ -72,7 +79,7 @@ }, "preprocess_steps": [ { - "__type__": "join_streams", + "__type__": "join_streams_fleurs", "left_stream": "test_input", "right_stream": "test_target", "how": "inner", @@ -91,9 +98,11 @@ "to_field": "translation" }, { - "__type__": "rename", - "field": "language", - "to_field": "target_language" + "__type__": "add_fields", + "fields": { + "source_language": "English", + "target_language": "Italian" + } } ], "task": "tasks.translation.speech", diff --git a/src/unitxt/catalog/cards/fleurs/en_us/ja_jp.json b/src/unitxt/catalog/cards/fleurs/en_us/ja_jp.json index 
31f00a3148..a750d69325 100644 --- a/src/unitxt/catalog/cards/fleurs/en_us/ja_jp.json +++ b/src/unitxt/catalog/cards/fleurs/en_us/ja_jp.json @@ -13,6 +13,9 @@ "data_dir": "en_us", "splits": [ "test" + ], + "data_classification_policy": [ + "public" ] }, { @@ -46,6 +49,9 @@ "data_dir": "ja_jp", "splits": [ "test" + ], + "data_classification_policy": [ + "public" ] }, { @@ -57,6 +63,7 @@ "raw_transcription", "gender", "lang_id", + "language", "lang_group_id" ] }, @@ -72,7 +79,7 @@ }, "preprocess_steps": [ { - "__type__": "join_streams", + "__type__": "join_streams_fleurs", "left_stream": "test_input", "right_stream": "test_target", "how": "inner", @@ -91,9 +98,11 @@ "to_field": "translation" }, { - "__type__": "rename", - "field": "language", - "to_field": "target_language" + "__type__": "add_fields", + "fields": { + "source_language": "English", + "target_language": "Japanese" + } } ], "task": "tasks.translation.speech", diff --git a/src/unitxt/catalog/cards/fleurs/en_us/pt_br.json b/src/unitxt/catalog/cards/fleurs/en_us/pt_br.json index a63cddb177..1531e5ab20 100644 --- a/src/unitxt/catalog/cards/fleurs/en_us/pt_br.json +++ b/src/unitxt/catalog/cards/fleurs/en_us/pt_br.json @@ -13,6 +13,9 @@ "data_dir": "en_us", "splits": [ "test" + ], + "data_classification_policy": [ + "public" ] }, { @@ -46,6 +49,9 @@ "data_dir": "pt_br", "splits": [ "test" + ], + "data_classification_policy": [ + "public" ] }, { @@ -57,6 +63,7 @@ "raw_transcription", "gender", "lang_id", + "language", "lang_group_id" ] }, @@ -72,7 +79,7 @@ }, "preprocess_steps": [ { - "__type__": "join_streams", + "__type__": "join_streams_fleurs", "left_stream": "test_input", "right_stream": "test_target", "how": "inner", @@ -91,9 +98,11 @@ "to_field": "translation" }, { - "__type__": "rename", - "field": "language", - "to_field": "target_language" + "__type__": "add_fields", + "fields": { + "source_language": "English", + "target_language": "Portuguese" + } } ], "task": "tasks.translation.speech", diff --git a/src/unitxt/catalog/metrics/normalized_sacrebleu.json b/src/unitxt/catalog/metrics/normalized_sacrebleu.json index 10ca810322..7017655e2a 100644 --- a/src/unitxt/catalog/metrics/normalized_sacrebleu.json +++ b/src/unitxt/catalog/metrics/normalized_sacrebleu.json @@ -1,24 +1,52 @@ { - "__type__": "normalized_sacrebleu", - "language_to_tokenizer": { - "german": null, - "deutch": null, - "de": null, - "french": null, - "fr": null, - "romanian": null, - "ro": null, - "english": null, - "en": null, - "spanish": null, - "es": null, - "portuguese": null, - "pt": null, - "arabic": "intl", - "ar": "intl", - "korean": "ko-mecab", - "ko": "ko-mecab", - "japanese": "ja-mecab", - "ja": "ja-mecab" + "__type__": "metric_pipeline", + "main_score": "sacrebleu", + "prediction_type": "str", + "preprocess_steps": [ + { + "__type__": "copy", + "field": "task_data/target_language", + "to_field": "task_data/tokenize", + "not_exist_ok": true, + "get_default": "en" + }, + { + "__type__": "lower", + "field": "task_data/tokenize" + }, + { + "__type__": "map_instance_values", + "mappers": { + "task_data/tokenize": { + "italian": null, + "it": null, + "german": null, + "deutch": null, + "de": null, + "french": null, + "fr": null, + "romanian": null, + "ro": null, + "english": null, + "en": null, + "spanish": null, + "es": null, + "portuguese": null, + "pt": null, + "arabic": "intl", + "ar": "intl", + "korean": "ko-mecab", + "ko": "ko-mecab", + "japanese": "ja-mecab", + "ja": "ja-mecab", + "chinese": "zh", + "zh": "zh" + } + }, + "strict": true + } + ], + 
"metric": { + "__type__": "normalized_sacrebleu" } } diff --git a/src/unitxt/catalog/operators/normalize_text_basic_with_whisper.json b/src/unitxt/catalog/operators/normalize_text_basic_with_whisper.json new file mode 100644 index 0000000000..10c4ace7ae --- /dev/null +++ b/src/unitxt/catalog/operators/normalize_text_basic_with_whisper.json @@ -0,0 +1,3 @@ +{ + "__type__": "normalize_text_basic_with_whisper" +} diff --git a/src/unitxt/catalog/processors/normalize_text_basic_with_whisper.json b/src/unitxt/catalog/processors/normalize_text_basic_with_whisper.json new file mode 100644 index 0000000000..f36e910372 --- /dev/null +++ b/src/unitxt/catalog/processors/normalize_text_basic_with_whisper.json @@ -0,0 +1,8 @@ +{ + "__type__": "post_process", + "process_references": true, + "process_prediction": true, + "operator": { + "__type__": "normalize_text_basic_with_whisper" + } +} diff --git a/src/unitxt/catalog/tasks/speech_recognition_multilingual.json b/src/unitxt/catalog/tasks/speech_recognition_multilingual.json new file mode 100644 index 0000000000..645aa0b8fd --- /dev/null +++ b/src/unitxt/catalog/tasks/speech_recognition_multilingual.json @@ -0,0 +1,21 @@ +{ + "__type__": "task", + "input_fields": { + "audio": "Audio" + }, + "reference_fields": { + "text": "str" + }, + "prediction_type": "str", + "metrics": [ + "metrics.wer" + ], + "default_template": { + "__type__": "input_output_template", + "input_format": "{audio}can you transcribe the speech into a written format?", + "output_format": "{text}", + "postprocessors": [ + "processors.normalize_text_basic_with_whisper" + ] + } +} diff --git a/src/unitxt/catalog/templates/speech_recognition/multilingual.json b/src/unitxt/catalog/templates/speech_recognition/multilingual.json new file mode 100644 index 0000000000..d37c2af89e --- /dev/null +++ b/src/unitxt/catalog/templates/speech_recognition/multilingual.json @@ -0,0 +1,8 @@ +{ + "__type__": "input_output_template", + "input_format": "{audio}can you transcribe the speech into a written format?", + "output_format": "{text}", + "postprocessors": [ + "processors.normalize_text_basic_with_whisper" + ] +} diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index f81919c8c4..4e9a61f242 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -70,6 +70,7 @@ class StandardAPIParamsMixin(Artifact): frequency_penalty: Optional[float] = None presence_penalty: Optional[float] = None max_tokens: Optional[int] = None + max_new_tokens: Optional[int] = None seed: Optional[int] = None stop: Union[Optional[str], List[str]] = None temperature: Optional[float] = None @@ -1727,6 +1728,7 @@ class OpenAiInferenceEngineParamsMixin(Artifact): frequency_penalty: Optional[float] = None presence_penalty: Optional[float] = None max_tokens: Optional[int] = None + max_new_tokens: Optional[int] = None seed: Optional[int] = None stop: Union[Optional[str], List[str]] = None temperature: Optional[float] = None @@ -1744,6 +1746,7 @@ class OpenAiInferenceEngineParams(Artifact): frequency_penalty: Optional[float] = None presence_penalty: Optional[float] = None max_tokens: Optional[int] = None + max_new_tokens: Optional[int] = None seed: Optional[int] = None stop: Union[Optional[str], List[str]] = None temperature: Optional[float] = None @@ -2067,6 +2070,7 @@ def _get_model_name_for_endpoint(cls, model_name: str): class TogetherAiInferenceEngineParamsMixin(Artifact): max_tokens: Optional[int] = None + max_new_tokens: Optional[int] = None stop: Optional[List[str]] = None temperature: Optional[float] = 
None top_p: Optional[float] = None @@ -2227,6 +2231,7 @@ class WMLChatParamsMixin(Artifact): response_format: Optional[Dict[str, Any]] = None temperature: Optional[float] = None max_tokens: Optional[int] = None + max_new_tokens: Optional[int] = None time_limit: Optional[int] = None top_p: Optional[float] = None n: Optional[int] = None @@ -3184,6 +3189,7 @@ class VLLMParamsMixin(Artifact): bad_words: Optional[List[str]] = None ignore_eos: bool = False max_tokens: Optional[int] = 16 + max_new_tokens: Optional[int] = None min_tokens: int = 0 logprobs: Optional[int] = None prompt_logprobs: Optional[int] = None diff --git a/src/unitxt/processors.py b/src/unitxt/processors.py index ef92a9d778..87aaeb0f2e 100644 --- a/src/unitxt/processors.py +++ b/src/unitxt/processors.py @@ -582,3 +582,19 @@ def prepare(self): def process_value(self, value: str) -> str: return self._normalize(value) + + +class NormalizeTextBasicWithWhisper(FieldOperator): + """A processor that uses the Whisper multilingual basic normalizer.""" + + _requirements_list = ["transformers"] + + def prepare(self): + super().prepare() + from transformers import WhisperTokenizer + + self.tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-base") + self._normalize = self.tokenizer.basic_normalize + + def process_value(self, value: str) -> str: + return self._normalize(value) diff --git a/src/unitxt/stream_operators.py b/src/unitxt/stream_operators.py index b769a5d3b0..c675891cf1 100644 --- a/src/unitxt/stream_operators.py +++ b/src/unitxt/stream_operators.py @@ -127,6 +127,99 @@ def process(self, multi_stream: MultiStream) -> MultiStream: return multi_stream + +class JoinStreamsFleurs(MultiStreamOperator): + """Join multiple streams into a single stream - special version for the FLEURS speech translation dataset. + + left_stream contains the input speech in English, 1-3 spoken samples per unique text instance. + The inference and the evaluation are done on all the spoken samples, including the repetitions, as they + represent different speakers and different recording conditions, potentially resulting in different + output translations. + right_stream contains the target translation text, 1-3 repetitions of the same target text translation. + The code below removes the redundant target translations, to avoid duplications during later merging + of the two parts - the left and the right. + + Args: + left_stream (str): the left stream. + right_stream (str): the right stream. + how: use "inner". + on: use "id". + new_stream_name (str): The name of the new stream resulting from the merge.
+ """ + + left_stream: str + right_stream: str + how: Literal["inner"] + on: Optional[List[str]] = None + left_on: Optional[List[str]] = None + right_on: Optional[List[str]] = None + new_stream_name: str + + def merge(self, multi_stream) -> List: + assert self.right_stream in multi_stream and self.left_stream in multi_stream + stream_dict = dict(multi_stream.items()) + left_stream = list(stream_dict[self.left_stream]) + right_stream = list(stream_dict[self.right_stream]) + left_stream_df = pd.DataFrame(left_stream) + # right_stream_df = pd.DataFrame(right_stream) + + # remove duplications from the right stream, this is intended for the FLEURS dataset (spoken translation) + # it removes unnecessary repetitions of the target translations, before merging the input speech with the target translations + right_deduplicate = [] + seen_ids = set() + for item in right_stream: + if item["id"] not in seen_ids: + seen_ids.add(item["id"]) + text = item["transcription"].strip() + if text[-1] not in [".", "!", "?"]: + item["transcription"] = ( + text + "." + ) # add '.' at the end of the reference translations + right_deduplicate.append(item) + right_stream_df = pd.DataFrame(right_deduplicate) + + merged_df = pd.merge( + left_stream_df, + right_stream_df, + how=self.how, + on=self.on, + left_on=self.left_on, + right_on=self.right_on, + ) + + def assert_col_values_are_identical(df: pd.DataFrame, col_name): + (col_name_1, col_name_2) = (f"{col_name}_x", f"{col_name}_y") + if not df.apply( + lambda row: str(row[col_name_1]) == str(row[col_name_2]), + axis=1, + ).all(): + raise UnitxtError( + f"'{col_name}' field is not identical in both left and right instances merged in JoinStreams." + ) + + # If 2 streams / Dataframes contains column with the same names, which are not the columns the join is operated + # on they will be renamed to "[column_name]_x" and "[column_name]_y". Some of these columns are metadsta + # columns that unitxt adds, which must be kept the same. This code verify that all datasets have + # the same metadata values and rename the columns accordingly. + common_cols_to_verify = ["data_classification_policy", "recipe_metadata"] + for common_col in common_cols_to_verify: + assert_col_values_are_identical(merged_df, common_col) + merged_df[common_col] = merged_df[f"{common_col}_x"] + merged_df = merged_df.drop( + columns=[f"{common_col}_x", f"{common_col}_y"], errors="ignore" + ) + + if len(merged_df) == 0: + raise UnitxtError( + f"JoinStreams resulted in an empty stream. It means that that keys in fields '{self.on}' on the left and on right streams do not match the merge policy of '{self.how}'." + ) + return merged_df.to_dict(orient="records") + + def process(self, multi_stream: MultiStream) -> MultiStream: + merged_records = self.merge(multi_stream) + multi_stream[self.new_stream_name] = ListStream(instances_list=merged_records) + return multi_stream + + class DeleteSplits(MultiStreamOperator): """Operator which delete splits in stream. 
diff --git a/src/unitxt/string_operators.py b/src/unitxt/string_operators.py index 03a1c3b2b3..fd20e17b64 100644 --- a/src/unitxt/string_operators.py +++ b/src/unitxt/string_operators.py @@ -93,6 +93,20 @@ def process_value(self, value: str) -> str: return value.strip() +class StripQuotation(FieldOperator): + def process_value(self, value: str) -> str: + if value.startswith('"') and value.endswith('"'): + return value.strip('"') + return value + + +class AddFullStop(FieldOperator): + def process_value(self, value: str) -> str: + if value[-1] not in [".", "?", "!"]: + return value + "." + return value + + class Replace(FieldOperator): old: str new: str From e10922dfb69b3b6bb12edfdbc8af64fbd4295215 Mon Sep 17 00:00:00 2001 From: "AHARON SATT AHARONSA@il.ibm.com" Date: Thu, 14 Aug 2025 11:58:22 +0300 Subject: [PATCH 29/36] updates --- .../catalog/metrics/normalized_sacrebleu.json | 74 +++++++------------ .../catalog/tasks/translation/speech.json | 10 ++- .../templates/translation/speech/default.json | 7 +- .../templates/translation/speech/no_norm.json | 5 ++ 4 files changed, 44 insertions(+), 52 deletions(-) create mode 100644 src/unitxt/catalog/templates/translation/speech/no_norm.json diff --git a/src/unitxt/catalog/metrics/normalized_sacrebleu.json b/src/unitxt/catalog/metrics/normalized_sacrebleu.json index 7017655e2a..c1cd14c5a5 100644 --- a/src/unitxt/catalog/metrics/normalized_sacrebleu.json +++ b/src/unitxt/catalog/metrics/normalized_sacrebleu.json @@ -1,52 +1,28 @@ { - "__type__": "metric_pipeline", - "main_score": "sacrebleu", - "prediction_type": "str", - "preprocess_steps": [ - { - "__type__": "copy", - "field": "task_data/target_language", - "to_field": "task_data/tokenize", - "not_exist_ok": true, - "get_default": "en" - }, - { - "__type__": "lower", - "field": "task_data/tokenize" - }, - { - "__type__": "map_instance_values", - "mappers": { - "task_data/tokenize": { - "italian": null, - "it": null, - "german": null, - "deutch": null, - "de": null, - "french": null, - "fr": null, - "romanian": null, - "ro": null, - "english": null, - "en": null, - "spanish": null, - "es": null, - "portuguese": null, - "pt": null, - "arabic": "intl", - "ar": "intl", - "korean": "ko-mecab", - "ko": "ko-mecab", - "japanese": "ja-mecab", - "ja": "ja-mecab", - "chinese": "zh", - "zh": "zh" - } - }, - "strict": true - } - ], - "metric": { - "__type__": "normalized_sacrebleu" + "__type__": "normalized_sacrebleu", + "language_to_tokenizer": { + "italian": null, + "it": null, + "german": null, + "deutch": null, + "de": null, + "french": null, + "fr": null, + "romanian": null, + "ro": null, + "english": null, + "en": null, + "spanish": null, + "es": null, + "portuguese": null, + "pt": null, + "arabic": "intl", + "ar": "intl", + "korean": "ko-mecab", + "ko": "ko-mecab", + "japanese": "ja-mecab", + "ja": "ja-mecab", + "chinese": "zh", + "zh": "zh" } } diff --git a/src/unitxt/catalog/tasks/translation/speech.json b/src/unitxt/catalog/tasks/translation/speech.json index 23153d9366..0fcb8a0c39 100644 --- a/src/unitxt/catalog/tasks/translation/speech.json +++ b/src/unitxt/catalog/tasks/translation/speech.json @@ -10,5 +10,13 @@ "prediction_type": "str", "metrics": [ "metrics.normalized_sacrebleu" - ] + ], + "default_template": { + "__type__": "input_output_template", + "input_format": "{audio}translate the speech to {target_language}", + "output_format": "{translation}", + "postprocessors": [ + "processors.normalize_text_basic_with_whisper" + ] + } } diff --git 
a/src/unitxt/catalog/templates/translation/speech/default.json b/src/unitxt/catalog/templates/translation/speech/default.json index 29f080829b..2b475ea91a 100644 --- a/src/unitxt/catalog/templates/translation/speech/default.json +++ b/src/unitxt/catalog/templates/translation/speech/default.json @@ -1,5 +1,8 @@ { "__type__": "input_output_template", - "input_format": "{audio}listen to the speech and translate it to {target_language}", - "output_format": "{translation}" + "input_format": "{audio}translate the speech to {target_language}", + "output_format": "{translation}", + "postprocessors": [ + "processors.normalize_text_basic_with_whisper" + ] } diff --git a/src/unitxt/catalog/templates/translation/speech/no_norm.json b/src/unitxt/catalog/templates/translation/speech/no_norm.json new file mode 100644 index 0000000000..570ff18c29 --- /dev/null +++ b/src/unitxt/catalog/templates/translation/speech/no_norm.json @@ -0,0 +1,5 @@ +{ + "__type__": "input_output_template", + "input_format": "{audio}translate the speech to {target_language}", + "output_format": "{translation}" +} From 4d501037d61a5f57598821de5cc7da6560477288 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Tue, 19 Aug 2025 09:12:45 +0300 Subject: [PATCH 30/36] make covost tests run only when files exist Signed-off-by: elronbandel --- prepare/cards/covost2.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/prepare/cards/covost2.py b/prepare/cards/covost2.py index 3bbf32b622..044bc2387d 100644 --- a/prepare/cards/covost2.py +++ b/prepare/cards/covost2.py @@ -1,5 +1,7 @@ # This Python script is used to prepare card for the covost2 dataset, used for evaluating speech translation +import os + from unitxt.audio_operators import ToAudio from unitxt.blocks import LoadHF, TaskCard from unitxt.catalog import add_to_catalog @@ -88,6 +90,7 @@ data_dir=local_data_path + "en", split="test", streaming=True, + requirements=["datasets<4.0.0"], ), preprocess_steps=[ ToAudio(field="audio"), @@ -135,7 +138,7 @@ ], ) - if first: + if first and os.path.isdir(local_data_path): test_card(card, demos_taken_from="test", num_demos=0) first = False From 3ff1fa6fc684d9ce6f4724d2da4abffa34943270 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Tue, 19 Aug 2025 19:08:39 +0300 Subject: [PATCH 31/36] Add directory check before testing card Signed-off-by: elronbandel --- prepare/cards/covost2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/prepare/cards/covost2.py b/prepare/cards/covost2.py index 044bc2387d..b713bbd873 100644 --- a/prepare/cards/covost2.py +++ b/prepare/cards/covost2.py @@ -90,7 +90,7 @@ data_dir=local_data_path + "en", split="test", streaming=True, - requirements=["datasets<4.0.0"], + # requirements=["datasets<4.0.0"], ), preprocess_steps=[ ToAudio(field="audio"), @@ -107,7 +107,7 @@ ], ) - if first: + if first and os.path.isdir(local_data_path): test_card(card, debug=True) first = False From a3a3cecd2bbb701a8e8b3fcfd986796acad552c2 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Wed, 20 Aug 2025 14:14:42 +0300 Subject: [PATCH 32/36] Use parquet files for speed Signed-off-by: elronbandel --- prepare/cards/commonvoice.py | 3 ++- src/unitxt/catalog/cards/commonvoice/de.json | 3 ++- src/unitxt/catalog/cards/commonvoice/en.json | 3 ++- src/unitxt/catalog/cards/commonvoice/es.json | 3 ++- src/unitxt/catalog/cards/commonvoice/fr.json | 3 ++- src/unitxt/catalog/cards/commonvoice/pt.json | 3 ++- 6 files changed, 12 insertions(+), 6 deletions(-) diff --git a/prepare/cards/commonvoice.py 
b/prepare/cards/commonvoice.py index 6a0b828f5b..f04e2d1e65 100644 --- a/prepare/cards/commonvoice.py +++ b/prepare/cards/commonvoice.py @@ -29,7 +29,8 @@ card = TaskCard( loader=LoadHF( path="mozilla-foundation/common_voice_17_0", - name=subset, + revision="refs/convert/parquet", + data_dir=subset, splits=["test"], data_classification_policy=["public"], streaming=True, diff --git a/src/unitxt/catalog/cards/commonvoice/de.json b/src/unitxt/catalog/cards/commonvoice/de.json index 1ad227341b..39046a5268 100644 --- a/src/unitxt/catalog/cards/commonvoice/de.json +++ b/src/unitxt/catalog/cards/commonvoice/de.json @@ -3,7 +3,8 @@ "loader": { "__type__": "load_hf", "path": "mozilla-foundation/common_voice_17_0", - "name": "de", + "revision": "refs/convert/parquet", + "data_dir": "de", "splits": [ "test" ], diff --git a/src/unitxt/catalog/cards/commonvoice/en.json b/src/unitxt/catalog/cards/commonvoice/en.json index 6b3471ab4d..96cab40c14 100644 --- a/src/unitxt/catalog/cards/commonvoice/en.json +++ b/src/unitxt/catalog/cards/commonvoice/en.json @@ -3,7 +3,8 @@ "loader": { "__type__": "load_hf", "path": "mozilla-foundation/common_voice_17_0", - "name": "en", + "revision": "refs/convert/parquet", + "data_dir": "en", "splits": [ "test" ], diff --git a/src/unitxt/catalog/cards/commonvoice/es.json b/src/unitxt/catalog/cards/commonvoice/es.json index 94134bd543..7c147bb8bd 100644 --- a/src/unitxt/catalog/cards/commonvoice/es.json +++ b/src/unitxt/catalog/cards/commonvoice/es.json @@ -3,7 +3,8 @@ "loader": { "__type__": "load_hf", "path": "mozilla-foundation/common_voice_17_0", - "name": "es", + "revision": "refs/convert/parquet", + "data_dir": "es", "splits": [ "test" ], diff --git a/src/unitxt/catalog/cards/commonvoice/fr.json b/src/unitxt/catalog/cards/commonvoice/fr.json index fa8d91566c..763b2a798a 100644 --- a/src/unitxt/catalog/cards/commonvoice/fr.json +++ b/src/unitxt/catalog/cards/commonvoice/fr.json @@ -3,7 +3,8 @@ "loader": { "__type__": "load_hf", "path": "mozilla-foundation/common_voice_17_0", - "name": "fr", + "revision": "refs/convert/parquet", + "data_dir": "fr", "splits": [ "test" ], diff --git a/src/unitxt/catalog/cards/commonvoice/pt.json b/src/unitxt/catalog/cards/commonvoice/pt.json index efd5bb9d96..d1d2c4d438 100644 --- a/src/unitxt/catalog/cards/commonvoice/pt.json +++ b/src/unitxt/catalog/cards/commonvoice/pt.json @@ -3,7 +3,8 @@ "loader": { "__type__": "load_hf", "path": "mozilla-foundation/common_voice_17_0", - "name": "pt", + "revision": "refs/convert/parquet", + "data_dir": "pt", "splits": [ "test" ], From f2655cd8fd5bb636b2f59083c8edf8d952c1ae79 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Wed, 20 Aug 2025 14:17:24 +0300 Subject: [PATCH 33/36] add missing splits for completeness Signed-off-by: elronbandel --- prepare/cards/commonvoice.py | 9 ++++++++- src/unitxt/catalog/cards/commonvoice/de.json | 11 ++++++++--- src/unitxt/catalog/cards/commonvoice/en.json | 11 ++++++++--- src/unitxt/catalog/cards/commonvoice/es.json | 11 ++++++++--- src/unitxt/catalog/cards/commonvoice/fr.json | 11 ++++++++--- src/unitxt/catalog/cards/commonvoice/pt.json | 11 ++++++++--- 6 files changed, 48 insertions(+), 16 deletions(-) diff --git a/prepare/cards/commonvoice.py b/prepare/cards/commonvoice.py index f04e2d1e65..2c149dce97 100644 --- a/prepare/cards/commonvoice.py +++ b/prepare/cards/commonvoice.py @@ -31,8 +31,15 @@ path="mozilla-foundation/common_voice_17_0", revision="refs/convert/parquet", data_dir=subset, - splits=["test"], data_classification_policy=["public"], + splits=[ + 
"invalidated", + "other", + "test", + "train", + "validated", + "validation", + ], streaming=True, ), preprocess_steps=[ diff --git a/src/unitxt/catalog/cards/commonvoice/de.json b/src/unitxt/catalog/cards/commonvoice/de.json index 39046a5268..721f7d8791 100644 --- a/src/unitxt/catalog/cards/commonvoice/de.json +++ b/src/unitxt/catalog/cards/commonvoice/de.json @@ -5,12 +5,17 @@ "path": "mozilla-foundation/common_voice_17_0", "revision": "refs/convert/parquet", "data_dir": "de", - "splits": [ - "test" - ], "data_classification_policy": [ "public" ], + "splits": [ + "invalidated", + "other", + "test", + "train", + "validated", + "validation" + ], "streaming": true }, "preprocess_steps": [ diff --git a/src/unitxt/catalog/cards/commonvoice/en.json b/src/unitxt/catalog/cards/commonvoice/en.json index 96cab40c14..1b6fc08ed8 100644 --- a/src/unitxt/catalog/cards/commonvoice/en.json +++ b/src/unitxt/catalog/cards/commonvoice/en.json @@ -5,12 +5,17 @@ "path": "mozilla-foundation/common_voice_17_0", "revision": "refs/convert/parquet", "data_dir": "en", - "splits": [ - "test" - ], "data_classification_policy": [ "public" ], + "splits": [ + "invalidated", + "other", + "test", + "train", + "validated", + "validation" + ], "streaming": true }, "preprocess_steps": [ diff --git a/src/unitxt/catalog/cards/commonvoice/es.json b/src/unitxt/catalog/cards/commonvoice/es.json index 7c147bb8bd..f46757b049 100644 --- a/src/unitxt/catalog/cards/commonvoice/es.json +++ b/src/unitxt/catalog/cards/commonvoice/es.json @@ -5,12 +5,17 @@ "path": "mozilla-foundation/common_voice_17_0", "revision": "refs/convert/parquet", "data_dir": "es", - "splits": [ - "test" - ], "data_classification_policy": [ "public" ], + "splits": [ + "invalidated", + "other", + "test", + "train", + "validated", + "validation" + ], "streaming": true }, "preprocess_steps": [ diff --git a/src/unitxt/catalog/cards/commonvoice/fr.json b/src/unitxt/catalog/cards/commonvoice/fr.json index 763b2a798a..a2a401f866 100644 --- a/src/unitxt/catalog/cards/commonvoice/fr.json +++ b/src/unitxt/catalog/cards/commonvoice/fr.json @@ -5,12 +5,17 @@ "path": "mozilla-foundation/common_voice_17_0", "revision": "refs/convert/parquet", "data_dir": "fr", - "splits": [ - "test" - ], "data_classification_policy": [ "public" ], + "splits": [ + "invalidated", + "other", + "test", + "train", + "validated", + "validation" + ], "streaming": true }, "preprocess_steps": [ diff --git a/src/unitxt/catalog/cards/commonvoice/pt.json b/src/unitxt/catalog/cards/commonvoice/pt.json index d1d2c4d438..8cc207a00f 100644 --- a/src/unitxt/catalog/cards/commonvoice/pt.json +++ b/src/unitxt/catalog/cards/commonvoice/pt.json @@ -5,12 +5,17 @@ "path": "mozilla-foundation/common_voice_17_0", "revision": "refs/convert/parquet", "data_dir": "pt", - "splits": [ - "test" - ], "data_classification_policy": [ "public" ], + "splits": [ + "invalidated", + "other", + "test", + "train", + "validated", + "validation" + ], "streaming": true }, "preprocess_steps": [ From d6e0f00283895a27edb6db0fc9a9664795838306 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Sun, 24 Aug 2025 20:46:57 +0300 Subject: [PATCH 34/36] fix Signed-off-by: elronbandel --- src/unitxt/inference.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index 4e9a61f242..7c15e49536 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -612,6 +612,12 @@ def get_logprobs( logprobs: List[List[Dict[str, Any]]] = [] + tokenizer = ( + self.processor.tokenizer + 
if hasattr(self.processor, "tokenizer") + else self.processor + ) + for sample_no, sample_scores in enumerate(transition_scores.detach().cpu()): sample_logprobs: List[Dict[str, Any]] = [] @@ -622,7 +628,7 @@ def get_logprobs( "logprob": float(score.cpu()), "top_tokens": [ { - "text": self.processor.tokenizer.decode(idx), + "text": tokenizer.decode(idx), "logprob": float( predictions.scores[n][sample_no][idx].cpu() ), From 3d45b482ca06c6661ddc98818462d68072127418 Mon Sep 17 00:00:00 2001 From: "AHARON SATT AHARONSA@il.ibm.com" Date: Sun, 31 Aug 2025 11:21:55 +0300 Subject: [PATCH 35/36] change to small ASR datasets --- examples/evaluate_speech_recognition.py | 16 +++- .../evaluate_speech_recognition_benchmark.py | 6 +- .../evaluate_speech_translation_benchmark.py | 2 +- .../evaluate_speech_translation_covost2.py | 2 +- .../evaluate_speech_translation_fleurs.py | 2 +- .../benchmarks/speech_recognition_english.py | 38 ++++++++ ....py => speech_recognition_multilingual.py} | 4 + prepare/cards/esb/ami.py | 50 ++++++++-- prepare/cards/esb/commonvoice.py | 74 +++++++++++++++ prepare/cards/esb/earnings22.py | 50 ++++++++-- prepare/cards/esb/gigaspeech.py | 50 ++++++++-- prepare/cards/esb/librispeech.py | 52 +++++++++-- prepare/cards/esb/spgispeech.py | 50 ++++++++-- prepare/cards/esb/tedlium.py | 50 ++++++++-- prepare/cards/esb/voxpopuli.py | 50 ++++++++-- .../benchmarks/speech_recognition.json | 24 +---- src/unitxt/catalog/cards/esb/ami.json | 70 ++++++++++++-- src/unitxt/catalog/cards/esb/commonvoice.json | 93 +++++++++++++++++++ src/unitxt/catalog/cards/esb/earnings22.json | 70 ++++++++++++-- src/unitxt/catalog/cards/esb/gigaspeech.json | 70 ++++++++++++-- src/unitxt/catalog/cards/esb/librispeech.json | 70 +++++++++++--- src/unitxt/catalog/cards/esb/spgispeech.json | 70 ++++++++++++-- src/unitxt/catalog/cards/esb/tedlium.json | 70 ++++++++++++-- src/unitxt/catalog/cards/esb/voxpopuli.json | 70 ++++++++++++-- src/unitxt/inference.py | 2 + 25 files changed, 949 insertions(+), 156 deletions(-) create mode 100644 prepare/benchmarks/speech_recognition_english.py rename prepare/benchmarks/{speech_recognition.py => speech_recognition_multilingual.py} (92%) create mode 100644 prepare/cards/esb/commonvoice.py create mode 100644 src/unitxt/catalog/cards/esb/commonvoice.json diff --git a/examples/evaluate_speech_recognition.py b/examples/evaluate_speech_recognition.py index ff1c58c516..1abaaa9362 100644 --- a/examples/evaluate_speech_recognition.py +++ b/examples/evaluate_speech_recognition.py @@ -9,7 +9,8 @@ ) from unitxt.system_prompts import TextualSystemPrompt -USE_RITS = True # whether to use RITS service or running the model locally +USE_RITS = False # whether to use RITS service +USE_WML = False # whether to use WML service test_dataset = load_dataset( # select (uncomment) only one of the following cards (datasets) @@ -36,21 +37,26 @@ if os.environ.get("SKIP_HEAVY_LOCAL", False): exit() -if not USE_RITS: +if not USE_RITS and not USE_WML: # locally running the model, it needs GPU to run properly model = HFGraniteSpeechInferenceEngine( model_name="ibm-granite/granite-speech-3.3-8b", # two options for Granite Speech 3.3: 2b and 8b revision="granite-speech-3.3.2-2b", - max_new_tokens=120, # 200 for 2b, 120 for 8b + max_new_tokens=200, ) if USE_RITS: # using the RITS remote service for inferencing model = CrossProviderInferenceEngine( model="granite-speech-3-3-8b", # in RITS only the 8b version of Granite Speech is available provider="rits", - # provider_specific_args={"rits": {"max_new_tokens": 120}}, 
- max_new_tokens=120, + # provider_specific_args={"rits": {"max_new_tokens": 200}}, + max_new_tokens=200, ) +if USE_WML: + # using the WML remote service for inferencing + # code to be completed + model = None + predictions = model(test_dataset) results = evaluate( diff --git a/examples/evaluate_speech_recognition_benchmark.py b/examples/evaluate_speech_recognition_benchmark.py index bfd18776db..1e4808daac 100644 --- a/examples/evaluate_speech_recognition_benchmark.py +++ b/examples/evaluate_speech_recognition_benchmark.py @@ -30,11 +30,13 @@ model = HFGraniteSpeechInferenceEngine( model_name="ibm-granite/granite-speech-3.3-2b", # two options for Granite Speech 3.3: 2b and 8b revision="granite-speech-3.3.2-2b", - max_new_tokens=200, # 200 for 2b, 120 for 8b + max_new_tokens=200, ) predictions = model(dataset) -results = evaluate(predictions=predictions, data=dataset) +results = evaluate( + predictions=predictions, data=dataset, calc_confidence_intervals=False +) print("Global scores:") print(results.global_scores.summary) diff --git a/examples/evaluate_speech_translation_benchmark.py b/examples/evaluate_speech_translation_benchmark.py index 373ad62f67..7c2deb4d68 100644 --- a/examples/evaluate_speech_translation_benchmark.py +++ b/examples/evaluate_speech_translation_benchmark.py @@ -23,7 +23,7 @@ model = HFGraniteSpeechInferenceEngine( model_name="ibm-granite/granite-speech-3.3-8b", # two options for Granite Speech 3.3: 2b and 8b - max_new_tokens=120, # 200 for 2b, 120 for 8b + max_new_tokens=200, ) predictions = model(dataset) diff --git a/examples/evaluate_speech_translation_covost2.py b/examples/evaluate_speech_translation_covost2.py index 15016fc788..775e0b3df1 100644 --- a/examples/evaluate_speech_translation_covost2.py +++ b/examples/evaluate_speech_translation_covost2.py @@ -40,7 +40,7 @@ model = HFGraniteSpeechInferenceEngine( model_name="ibm-granite/granite-speech-3.3-8b", # two options for Granite Speech 3.3: 2b and 8b - max_new_tokens=120, # 200 for 2b, 120 for 8b + max_new_tokens=200, ) predictions = model(test_dataset) diff --git a/examples/evaluate_speech_translation_fleurs.py b/examples/evaluate_speech_translation_fleurs.py index b278c552d3..574ec6c2eb 100644 --- a/examples/evaluate_speech_translation_fleurs.py +++ b/examples/evaluate_speech_translation_fleurs.py @@ -40,7 +40,7 @@ model = HFGraniteSpeechInferenceEngine( model_name="ibm-granite/granite-speech-3.3-8b", # two options for Granite Speech 3.3: 2b and 8b - max_new_tokens=120, # 200 for 2b, 120 for 8b + max_new_tokens=200, ) predictions = model(test_dataset) diff --git a/prepare/benchmarks/speech_recognition_english.py b/prepare/benchmarks/speech_recognition_english.py new file mode 100644 index 0000000000..114adf48f8 --- /dev/null +++ b/prepare/benchmarks/speech_recognition_english.py @@ -0,0 +1,38 @@ +from unitxt.benchmark import Benchmark +from unitxt.catalog import add_to_catalog +from unitxt.standard import DatasetRecipe + +benchmark = Benchmark( + subsets={ + "voxpopuli": DatasetRecipe( + card="cards.esb.voxpopuli", + format="formats.chat_api", + ), + "ami": DatasetRecipe( + card="cards.esb.ami", + format="formats.chat_api", + ), + "librispeech": DatasetRecipe( + card="cards.esb.librispeech", + format="formats.chat_api", + ), + "spgispeech": DatasetRecipe( + card="cards.esb.spgispeech", + format="formats.chat_api", + ), + "tedlium": DatasetRecipe( + card="cards.esb.tedlium", + format="formats.chat_api", + ), + "earnings22": DatasetRecipe( + card="cards.esb.earnings22", + format="formats.chat_api", + ), + 
"commonvoice": DatasetRecipe( + card="cards.esb.commonvoice", + format="formats.chat_api", + ), + }, +) + +add_to_catalog(benchmark, "benchmarks.speech_recognition", overwrite=True) diff --git a/prepare/benchmarks/speech_recognition.py b/prepare/benchmarks/speech_recognition_multilingual.py similarity index 92% rename from prepare/benchmarks/speech_recognition.py rename to prepare/benchmarks/speech_recognition_multilingual.py index 3ca5afd270..477f22886b 100644 --- a/prepare/benchmarks/speech_recognition.py +++ b/prepare/benchmarks/speech_recognition_multilingual.py @@ -28,6 +28,10 @@ card="cards.esb.earnings22", format="formats.chat_api", ), + "commonvoice": DatasetRecipe( + card="cards.esb.commonvoice", + format="formats.chat_api", + ), "commonvoice_en": DatasetRecipe( card="cards.commonvoice.en", format="formats.chat_api", diff --git a/prepare/cards/esb/ami.py b/prepare/cards/esb/ami.py index 565bcea6be..012d1f77bc 100644 --- a/prepare/cards/esb/ami.py +++ b/prepare/cards/esb/ami.py @@ -1,19 +1,55 @@ from unitxt.audio_operators import ToAudio from unitxt.card import TaskCard from unitxt.catalog import add_to_catalog -from unitxt.loaders import LoadHF +from unitxt.loaders import LoadHF, MultipleSourceLoader +from unitxt.operator import SourceSequentialOperator +from unitxt.operators import Rename +from unitxt.splitters import RenameSplits, SplitRandomMix from unitxt.test_utils.card import test_card +# the datasets in 'esb/diagnostic-dataset' contain two splits: 'clean' and 'other' +# the script below downloads the two splits separately per dataset, then combines them into a single split named 'test' card = TaskCard( - loader=LoadHF( - path="hf-audio/esb-datasets-test-only-sorted", - name="ami", - splits=["test"], - data_classification_policy=["public"], - streaming=True, + loader=MultipleSourceLoader( + sources=[ + SourceSequentialOperator( + steps=[ + LoadHF( + path="esb/diagnostic-dataset", + name="ami", + splits=["clean"], + data_classification_policy=["public"], + streaming=True, + ), + RenameSplits( + { + "clean": "test1", + } + ), + ] + ), + SourceSequentialOperator( + steps=[ + LoadHF( + path="esb/diagnostic-dataset", + name="ami", + splits=["other"], + data_classification_policy=["public"], + streaming=True, + ), + RenameSplits( + { + "other": "test2", + } + ), + ] + ), + ] ), preprocess_steps=[ + SplitRandomMix({"test": "test1+test2"}), ToAudio(field="audio"), + Rename(field="norm_transcript", to_field="text"), ], task="tasks.speech_recognition", templates=[ diff --git a/prepare/cards/esb/commonvoice.py b/prepare/cards/esb/commonvoice.py new file mode 100644 index 0000000000..38f4f0fbfc --- /dev/null +++ b/prepare/cards/esb/commonvoice.py @@ -0,0 +1,74 @@ +from unitxt.audio_operators import ToAudio +from unitxt.card import TaskCard +from unitxt.catalog import add_to_catalog +from unitxt.loaders import LoadHF, MultipleSourceLoader +from unitxt.operator import SourceSequentialOperator +from unitxt.operators import Rename +from unitxt.splitters import RenameSplits, SplitRandomMix +from unitxt.test_utils.card import test_card + +# the datasets in 'esb/diagnostic-dataset' contain two splits: 'clean' and 'other' +# the script below downloads the two splits separately per dataset, then combines them into a single split named 'test' +card = TaskCard( + loader=MultipleSourceLoader( + sources=[ + SourceSequentialOperator( + steps=[ + LoadHF( + path="esb/diagnostic-dataset", + name="common_voice", + splits=["clean"], + data_classification_policy=["public"], + streaming=True, + ), + 
RenameSplits( + { + "clean": "test1", + } + ), + ] + ), + SourceSequentialOperator( + steps=[ + LoadHF( + path="esb/diagnostic-dataset", + name="common_voice", + splits=["other"], + data_classification_policy=["public"], + streaming=True, + ), + RenameSplits( + { + "other": "test2", + } + ), + ] + ), + ] + ), + preprocess_steps=[ + SplitRandomMix({"test": "test1+test2"}), + ToAudio(field="audio"), + Rename(field="norm_transcript", to_field="text"), + ], + task="tasks.speech_recognition", + templates=[ + "templates.speech_recognition.default", + ], + __tags__={ + "license": "cc0-1.0", + "language": "en", + "task_categories": ["automatic-speech-recognition"], + "size_categories": ["1K Date: Sun, 31 Aug 2025 11:44:44 +0300 Subject: [PATCH 36/36] change to small ASR datasets --- .../speech_recognition_multilingual.py | 34 +------------------ .../speech_recognition_multilingual.json | 25 ++++++++++++++ 2 files changed, 26 insertions(+), 33 deletions(-) create mode 100644 src/unitxt/catalog/benchmarks/speech_recognition_multilingual.json diff --git a/prepare/benchmarks/speech_recognition_multilingual.py b/prepare/benchmarks/speech_recognition_multilingual.py index 477f22886b..e0857e6839 100644 --- a/prepare/benchmarks/speech_recognition_multilingual.py +++ b/prepare/benchmarks/speech_recognition_multilingual.py @@ -4,38 +4,6 @@ benchmark = Benchmark( subsets={ - "voxpopuli": DatasetRecipe( - card="cards.esb.voxpopuli", - format="formats.chat_api", - ), - "ami": DatasetRecipe( - card="cards.esb.ami", - format="formats.chat_api", - ), - "librispeech": DatasetRecipe( - card="cards.esb.librispeech", - format="formats.chat_api", - ), - "spgispeech": DatasetRecipe( - card="cards.esb.spgispeech", - format="formats.chat_api", - ), - "tedlium": DatasetRecipe( - card="cards.esb.tedlium", - format="formats.chat_api", - ), - "earnings22": DatasetRecipe( - card="cards.esb.earnings22", - format="formats.chat_api", - ), - "commonvoice": DatasetRecipe( - card="cards.esb.commonvoice", - format="formats.chat_api", - ), - "commonvoice_en": DatasetRecipe( - card="cards.commonvoice.en", - format="formats.chat_api", - ), "commonvoice_de": DatasetRecipe( card="cards.commonvoice.de", format="formats.chat_api", @@ -55,4 +23,4 @@ }, ) -add_to_catalog(benchmark, "benchmarks.speech_recognition", overwrite=True) +add_to_catalog(benchmark, "benchmarks.speech_recognition_multilingual", overwrite=True) diff --git a/src/unitxt/catalog/benchmarks/speech_recognition_multilingual.json b/src/unitxt/catalog/benchmarks/speech_recognition_multilingual.json new file mode 100644 index 0000000000..22d01d9d6a --- /dev/null +++ b/src/unitxt/catalog/benchmarks/speech_recognition_multilingual.json @@ -0,0 +1,25 @@ +{ + "__type__": "benchmark", + "subsets": { + "commonvoice_de": { + "__type__": "dataset_recipe", + "card": "cards.commonvoice.de", + "format": "formats.chat_api" + }, + "commonvoice_es": { + "__type__": "dataset_recipe", + "card": "cards.commonvoice.es", + "format": "formats.chat_api" + }, + "commonvoice_fr": { + "__type__": "dataset_recipe", + "card": "cards.commonvoice.fr", + "format": "formats.chat_api" + }, + "commonvoice_pt": { + "__type__": "dataset_recipe", + "card": "cards.commonvoice.pt", + "format": "formats.chat_api" + } + } +}