From 14096a816660cf9feebb06dee1f2ae0efe546162 Mon Sep 17 00:00:00 2001 From: ahmedhammam Date: Fri, 19 Dec 2025 14:40:18 +0000 Subject: [PATCH 1/3] fix(wmt): use HuggingFace datasets instead of sacrebleu --- src/eval_framework/tasks/benchmarks/wmt.py | 67 ++++++-- src/eval_framework/tasks/task_names.py | 33 ---- .../tasks/test_all_formatters.py | 4 - tests/tests_eval_framework/tasks/test_wmt.py | 146 ++++++++++++++++++ 4 files changed, 197 insertions(+), 53 deletions(-) create mode 100644 tests/tests_eval_framework/tasks/test_wmt.py diff --git a/src/eval_framework/tasks/benchmarks/wmt.py b/src/eval_framework/tasks/benchmarks/wmt.py index c7a11734..e973710e 100644 --- a/src/eval_framework/tasks/benchmarks/wmt.py +++ b/src/eval_framework/tasks/benchmarks/wmt.py @@ -3,7 +3,6 @@ from typing import Any import pycountry -import sacrebleu from eval_framework.metrics.completion.bleu import LINEWISE_BLEU from eval_framework.metrics.completion.chrf import LINEWISE_CHRF @@ -12,7 +11,10 @@ class WMT(BaseTask[str], ABC): - """WMT dataset:""" + """WMT dataset: https://huggingface.co/datasets/wmt + + Uses HuggingFace datasets for deterministic data loading. + """ NAME = "WMT" DATASET_PATH = "" @@ -27,13 +29,46 @@ def __init__(self, num_fewshot: int = 0) -> None: self.stop_sequences: list[str] = [".\n", " phrase: ", "phrase:", "phrase: ", " phrase:", "\n\n"] def _load_dataset(self, subject: str | None) -> None: - src_file, ref_file, _, _, _ = sacrebleu.download_test_set(test_set=self.DATASET_PATH, langpair=subject) - src_data, ref_data = [[line.rstrip() for line in sacrebleu.smart_open(file)] for file in (src_file, ref_file)] - - data_list = [{"source": src, "target": ref, "subject": subject} for src, ref in zip(src_data, ref_data)] + """Load WMT dataset from HuggingFace. + + The HuggingFace WMT datasets use the structure: + { + "translation": { + "src_lang": "Source text", + "tgt_lang": "Target text" + } + } + + Args: + subject: Language pair in format "src-tgt" (e.g., "de-en", "en-fr") + """ + # Load from HuggingFace - subject is the language pair like "de-en" + hf_dataset = self._load_hf_dataset(path=self.DATASET_PATH, name=subject) + + # Parse the language pair to get source and target language codes + src_lang, tgt_lang = subject.split("-") + + self.dataset = {} self.rnd = random.Random(RANDOM_SEED) - self.rnd.shuffle(data_list) - self.dataset = {"test": data_list} + + for split, data in hf_dataset.items(): + if split not in [self.SAMPLE_SPLIT, self.FEWSHOT_SPLIT]: + continue + + # Transform HuggingFace format to expected format + data_list = [] + for item in data: + translation = item["translation"] + data_list.append({ + "source": translation[src_lang], + "target": translation[tgt_lang], + "subject": subject, + }) + + if split == self.SAMPLE_SPLIT: + self.rnd.shuffle(data_list) + + self.dataset[split] = data_list def _code_to_language(self, code: str) -> str: # key is alpha_2 or alpha_3 depending on the code length @@ -69,8 +104,8 @@ def post_process_generated_completion(self, completion_text: str, sample: Sample class WMT14(WMT): NAME = "WMT14" - DATASET_PATH = "wmt14" - SUBJECTS = ["en-fr", "fr-en"] + DATASET_PATH = "wmt/wmt14" + SUBJECTS = ["fr-en", "en-fr"] LANGUAGE = { "en-fr": (Language["ENG"], Language["FRA"]), "fr-en": (Language["FRA"], Language["ENG"]), @@ -79,7 +114,7 @@ class WMT14(WMT): class WMT16(WMT): NAME = "WMT16" - DATASET_PATH = "wmt16" + DATASET_PATH = "wmt/wmt16" SUBJECTS = ["de-en", "en-de"] LANGUAGE = { "de-en": (Language["DEU"], Language["ENG"]), @@ -89,7 +124,7 @@ class WMT16(WMT): class WMT20(WMT): NAME = "WMT20" - DATASET_PATH = "wmt20" + DATASET_PATH = "wmt/wmt20" SUBJECTS = ["de-en", "de-fr", "en-de", "fr-de"] LANGUAGE = { "de-en": (Language["DEU"], Language["ENG"]), @@ -130,8 +165,8 @@ def post_process_generated_completion(self, completion_text: str, sample: Sample class WMT14_INSTRUCT(WMT_INSTRUCT): NAME = "WMT14 Instruct" - DATASET_PATH = "wmt14" - SUBJECTS = ["en-fr", "fr-en"] + DATASET_PATH = "wmt/wmt14" + SUBJECTS = ["fr-en", "en-fr"] LANGUAGE = { "en-fr": (Language["ENG"], Language["FRA"]), "fr-en": (Language["FRA"], Language["ENG"]), @@ -140,7 +175,7 @@ class WMT14_INSTRUCT(WMT_INSTRUCT): class WMT16_INSTRUCT(WMT_INSTRUCT): NAME = "WMT16 Instruct" - DATASET_PATH = "wmt16" + DATASET_PATH = "wmt/wmt16" SUBJECTS = ["de-en", "en-de"] LANGUAGE = { "de-en": (Language["DEU"], Language["ENG"]), @@ -150,7 +185,7 @@ class WMT16_INSTRUCT(WMT_INSTRUCT): class WMT20_INSTRUCT(WMT_INSTRUCT): NAME = "WMT20 Instruct" - DATASET_PATH = "wmt20" + DATASET_PATH = "wmt/wmt20" SUBJECTS = ["de-en", "de-fr", "en-de", "fr-de"] LANGUAGE = { "de-en": (Language["DEU"], Language["ENG"]), diff --git a/src/eval_framework/tasks/task_names.py b/src/eval_framework/tasks/task_names.py index 749fdb86..01ce2d19 100644 --- a/src/eval_framework/tasks/task_names.py +++ b/src/eval_framework/tasks/task_names.py @@ -221,9 +221,6 @@ def make_sure_all_hf_datasets_are_in_cache(only_datasets: set[str] | None = None time.sleep(random.randint(1, 5)) logger.info(f"Processed {task_name}") - # Sacrebleu uses its own cache (SACREBLEU env var), separate from HF datasets. - # We cache them together to ensure all evaluation data is available. - _ensure_sacrebleu_datasets_cached() def update_changed_datasets_only(verbose: bool = True) -> tuple[bool, set[str]]: @@ -239,9 +236,6 @@ def update_changed_datasets_only(verbose: bool = True) -> tuple[bool, set[str]]: all_up_to_date, datasets_to_update = get_datasets_needing_update() if all_up_to_date: - # Even when HF datasets are current, ensure sacrebleu is cached - # (it has its own cache and isn't tracked by dataset_commits.json) - _ensure_sacrebleu_datasets_cached() print("Nothing to update!") return False, set() @@ -293,32 +287,5 @@ def save_hf_dataset_commits() -> None: print(f"Saved {len(commits)} dataset commits to {cache_file}") -def _ensure_sacrebleu_datasets_cached() -> None: - """Pre-download sacrebleu WMT datasets to ensure they're cached. - - Sacrebleu uses its own cache (controlled by SACREBLEU env var). - This ensures WMT test sets are downloaded and cached alongside HF datasets. - """ - import sacrebleu - - # WMT datasets used by the framework (from wmt.py) - WMT_DATASETS = { - "wmt14": ["en-fr", "fr-en"], - "wmt16": ["de-en", "en-de"], - "wmt20": ["de-en", "de-fr", "en-de", "fr-de"], - } - - print("Ensuring sacrebleu WMT datasets are cached...") - for test_set, langpairs in WMT_DATASETS.items(): - for langpair in langpairs: - try: - sacrebleu.download_test_set(test_set=test_set, langpair=langpair) - print(f" {test_set}/{langpair}: OK") - except Exception as e: - print(f" {test_set}/{langpair}: FAILED ({e})") - - print("Sacrebleu datasets cached!") - - if __name__ == "__main__": print(list(registered_tasks_iter())) diff --git a/tests/tests_eval_framework/tasks/test_all_formatters.py b/tests/tests_eval_framework/tasks/test_all_formatters.py index 33073dcd..9732ebd4 100644 --- a/tests/tests_eval_framework/tasks/test_all_formatters.py +++ b/tests/tests_eval_framework/tasks/test_all_formatters.py @@ -201,10 +201,6 @@ def test_all_tasks_formatter(task_name: str, formatter_cls: type[BaseFormatter]) Raises: AssertionError: If the hash of the formatter output does not match the expected value """ - # Skip WMT tasks - sacrebleu file loading has non-determinism - if "WMT" in task_name: - pytest.skip(f"Skipping {task_name}: WMT tasks use sacrebleu with non-deterministic file loading") - task_class = get_task(task_name) args = SPECIAL_ARGS.get(task_class.__name__, {"num_fewshot": 1}) diff --git a/tests/tests_eval_framework/tasks/test_wmt.py b/tests/tests_eval_framework/tasks/test_wmt.py new file mode 100644 index 00000000..1d55d3ce --- /dev/null +++ b/tests/tests_eval_framework/tasks/test_wmt.py @@ -0,0 +1,146 @@ +""" +Tests for WMT benchmark tasks using HuggingFace datasets. + +Validates that WMT tasks load data deterministically from HuggingFace datasets +instead of sacrebleu file-based loading. +""" + +import pytest + +from eval_framework.tasks.benchmarks.wmt import ( + WMT14, + WMT14_INSTRUCT, + WMT16, + WMT16_INSTRUCT, + WMT20, + WMT20_INSTRUCT, +) + + +class TestWMTDatasetStructure: + """Test that WMT tasks load data with correct structure from HuggingFace.""" + + @pytest.mark.parametrize( + "task_cls,subject", + [ + (WMT14, "fr-en"), + (WMT14, "en-fr"), + (WMT16, "de-en"), + (WMT16, "en-de"), + (WMT20, "de-en"), + (WMT20, "de-fr"), + ], + ) + def test_load_dataset_structure(self, task_cls: type, subject: str) -> None: + """Test that dataset loads with correct structure.""" + task = task_cls(num_fewshot=0) + task._load_dataset(subject) + + assert "test" in task.dataset + assert len(task.dataset["test"]) > 0 + + # Check item structure + item = task.dataset["test"][0] + assert "source" in item + assert "target" in item + assert "subject" in item + assert item["subject"] == subject + assert isinstance(item["source"], str) + assert isinstance(item["target"], str) + assert len(item["source"]) > 0 + assert len(item["target"]) > 0 + + @pytest.mark.parametrize("task_cls", [WMT14, WMT16, WMT20]) + def test_deterministic_loading(self, task_cls: type) -> None: + """Test that loading is deterministic across multiple runs.""" + subject = task_cls.SUBJECTS[0] + + # Load twice + task1 = task_cls(num_fewshot=0) + task1._load_dataset(subject) + + task2 = task_cls(num_fewshot=0) + task2._load_dataset(subject) + + # Verify identical ordering after shuffle + assert len(task1.dataset["test"]) == len(task2.dataset["test"]) + for i in range(min(10, len(task1.dataset["test"]))): + assert task1.dataset["test"][i]["source"] == task2.dataset["test"][i]["source"] + assert task1.dataset["test"][i]["target"] == task2.dataset["test"][i]["target"] + + +class TestWMTSampleGeneration: + """Test WMT sample generation.""" + + @pytest.mark.parametrize("task_cls", [WMT14, WMT16, WMT20]) + def test_sample_generation(self, task_cls: type) -> None: + """Test that samples can be generated correctly.""" + task = task_cls(num_fewshot=0) + samples = list(task.iterate_samples(num_samples=3)) + + assert len(samples) == 3 + for sample in samples: + assert sample.messages is not None + assert len(sample.messages) > 0 + assert sample.ground_truth is not None + + @pytest.mark.parametrize("task_cls", [WMT14, WMT16, WMT20]) + def test_sample_with_fewshot(self, task_cls: type) -> None: + """Test that few-shot samples are generated correctly.""" + task = task_cls(num_fewshot=1) + samples = list(task.iterate_samples(num_samples=2)) + + assert len(samples) == 2 + for sample in samples: + # With fewshot, we should have more messages + assert len(sample.messages) >= 2 + + +class TestWMTInstructVariants: + """Test WMT instruct variants.""" + + @pytest.mark.parametrize( + "task_cls,subject", + [ + (WMT14_INSTRUCT, "fr-en"), + (WMT16_INSTRUCT, "de-en"), + (WMT20_INSTRUCT, "de-en"), + ], + ) + def test_instruct_sample_generation(self, task_cls: type, subject: str) -> None: + """Test that instruct variants generate samples correctly.""" + task = task_cls(num_fewshot=0) + samples = list(task.iterate_samples(num_samples=2)) + + assert len(samples) == 2 + for sample in samples: + assert sample.messages is not None + # Check that the instruction format contains "translate" + first_message_content = sample.messages[0].content + assert "translate" in first_message_content.lower() + + +class TestWMTPostProcessing: + """Test WMT post-processing methods.""" + + def test_post_process_with_stop_sequence(self) -> None: + """Test that stop sequences are handled correctly.""" + task = WMT16(num_fewshot=0) + + # Test various stop sequences + text_with_stop = "Hello world.\nThis should be cut" + result = task.post_process_generated_completion(text_with_stop) + assert result == "Hello world" + + text_with_phrase = "Hello world phrase: extra" + result = task.post_process_generated_completion(text_with_phrase) + assert result == "Hello world" + + def test_instruct_post_process(self) -> None: + """Test instruct variant post-processing.""" + task = WMT16_INSTRUCT(num_fewshot=0) + + # Test prefix removal + text_with_prefix = "This is the translation: Hello world" + result = task.post_process_generated_completion(text_with_prefix) + assert result == "Hello world" From d09a25c8eb5dd1b099c3d6b828f956016dceacc4 Mon Sep 17 00:00:00 2001 From: ahmedhammam Date: Fri, 19 Dec 2025 14:43:16 +0000 Subject: [PATCH 2/3] fix lint --- src/eval_framework/tasks/benchmarks/wmt.py | 12 +++++++----- src/eval_framework/tasks/task_names.py | 1 - tests/tests_eval_framework/tasks/test_wmt.py | 4 ++-- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/eval_framework/tasks/benchmarks/wmt.py b/src/eval_framework/tasks/benchmarks/wmt.py index e973710e..dbcd0415 100644 --- a/src/eval_framework/tasks/benchmarks/wmt.py +++ b/src/eval_framework/tasks/benchmarks/wmt.py @@ -59,11 +59,13 @@ def _load_dataset(self, subject: str | None) -> None: data_list = [] for item in data: translation = item["translation"] - data_list.append({ - "source": translation[src_lang], - "target": translation[tgt_lang], - "subject": subject, - }) + data_list.append( + { + "source": translation[src_lang], + "target": translation[tgt_lang], + "subject": subject, + } + ) if split == self.SAMPLE_SPLIT: self.rnd.shuffle(data_list) diff --git a/src/eval_framework/tasks/task_names.py b/src/eval_framework/tasks/task_names.py index 01ce2d19..a2e0ee82 100644 --- a/src/eval_framework/tasks/task_names.py +++ b/src/eval_framework/tasks/task_names.py @@ -222,7 +222,6 @@ def make_sure_all_hf_datasets_are_in_cache(only_datasets: set[str] | None = None logger.info(f"Processed {task_name}") - def update_changed_datasets_only(verbose: bool = True) -> tuple[bool, set[str]]: """ Check for updates and download only changed datasets. diff --git a/tests/tests_eval_framework/tasks/test_wmt.py b/tests/tests_eval_framework/tasks/test_wmt.py index 1d55d3ce..cb788cd5 100644 --- a/tests/tests_eval_framework/tasks/test_wmt.py +++ b/tests/tests_eval_framework/tasks/test_wmt.py @@ -126,7 +126,7 @@ class TestWMTPostProcessing: def test_post_process_with_stop_sequence(self) -> None: """Test that stop sequences are handled correctly.""" task = WMT16(num_fewshot=0) - + # Test various stop sequences text_with_stop = "Hello world.\nThis should be cut" result = task.post_process_generated_completion(text_with_stop) @@ -139,7 +139,7 @@ def test_post_process_with_stop_sequence(self) -> None: def test_instruct_post_process(self) -> None: """Test instruct variant post-processing.""" task = WMT16_INSTRUCT(num_fewshot=0) - + # Test prefix removal text_with_prefix = "This is the translation: Hello world" result = task.post_process_generated_completion(text_with_prefix) From 54b56735a18d5b3d6a46b933e34cc0d9e4b5cdea Mon Sep 17 00:00:00 2001 From: ahmedhammam Date: Fri, 19 Dec 2025 15:27:10 +0000 Subject: [PATCH 3/3] fix mypy --- src/eval_framework/tasks/benchmarks/wmt.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/eval_framework/tasks/benchmarks/wmt.py b/src/eval_framework/tasks/benchmarks/wmt.py index dbcd0415..e27b61a4 100644 --- a/src/eval_framework/tasks/benchmarks/wmt.py +++ b/src/eval_framework/tasks/benchmarks/wmt.py @@ -46,6 +46,8 @@ def _load_dataset(self, subject: str | None) -> None: hf_dataset = self._load_hf_dataset(path=self.DATASET_PATH, name=subject) # Parse the language pair to get source and target language codes + if subject is None: + raise ValueError("subject is required for WMT dataset") src_lang, tgt_lang = subject.split("-") self.dataset = {}