Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 55 additions & 16 deletions src/eval_framework/tasks/benchmarks/wmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Any

import pycountry
import sacrebleu

from eval_framework.metrics.completion.bleu import LINEWISE_BLEU
from eval_framework.metrics.completion.chrf import LINEWISE_CHRF
Expand All @@ -12,7 +11,10 @@


class WMT(BaseTask[str], ABC):
"""WMT dataset:"""
"""WMT dataset: https://huggingface.co/datasets/wmt

Uses HuggingFace datasets for deterministic data loading.
"""

NAME = "WMT"
DATASET_PATH = ""
Expand All @@ -27,13 +29,50 @@ def __init__(self, num_fewshot: int = 0) -> None:
self.stop_sequences: list[str] = [".\n", " phrase: ", "phrase:", "phrase: ", " phrase:", "\n\n"]

def _load_dataset(self, subject: str | None) -> None:
src_file, ref_file, _, _, _ = sacrebleu.download_test_set(test_set=self.DATASET_PATH, langpair=subject)
src_data, ref_data = [[line.rstrip() for line in sacrebleu.smart_open(file)] for file in (src_file, ref_file)]

data_list = [{"source": src, "target": ref, "subject": subject} for src, ref in zip(src_data, ref_data)]
"""Load WMT dataset from HuggingFace.

The HuggingFace WMT datasets use the structure:
{
"translation": {
"src_lang": "Source text",
"tgt_lang": "Target text"
}
}

Args:
subject: Language pair in format "src-tgt" (e.g., "de-en", "en-fr")
"""
# Load from HuggingFace - subject is the language pair like "de-en"
hf_dataset = self._load_hf_dataset(path=self.DATASET_PATH, name=subject)

# Parse the language pair to get source and target language codes
if subject is None:
raise ValueError("subject is required for WMT dataset")
src_lang, tgt_lang = subject.split("-")

self.dataset = {}
self.rnd = random.Random(RANDOM_SEED)
self.rnd.shuffle(data_list)
self.dataset = {"test": data_list}

for split, data in hf_dataset.items():
if split not in [self.SAMPLE_SPLIT, self.FEWSHOT_SPLIT]:
continue

# Transform HuggingFace format to expected format
Copy link
Collaborator

@jordisassoon jordisassoon Dec 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain the expected format, specifically with regard to the shuffling?

data_list = []
for item in data:
translation = item["translation"]
data_list.append(
{
"source": translation[src_lang],
"target": translation[tgt_lang],
"subject": subject,
}
)

if split == self.SAMPLE_SPLIT:
self.rnd.shuffle(data_list)

self.dataset[split] = data_list

def _code_to_language(self, code: str) -> str:
# key is alpha_2 or alpha_3 depending on the code length
Expand Down Expand Up @@ -69,8 +108,8 @@ def post_process_generated_completion(self, completion_text: str, sample: Sample

class WMT14(WMT):
NAME = "WMT14"
DATASET_PATH = "wmt14"
SUBJECTS = ["en-fr", "fr-en"]
DATASET_PATH = "wmt/wmt14"
SUBJECTS = ["fr-en", "en-fr"]
LANGUAGE = {
"en-fr": (Language["ENG"], Language["FRA"]),
"fr-en": (Language["FRA"], Language["ENG"]),
Expand All @@ -79,7 +118,7 @@ class WMT14(WMT):

class WMT16(WMT):
NAME = "WMT16"
DATASET_PATH = "wmt16"
DATASET_PATH = "wmt/wmt16"
SUBJECTS = ["de-en", "en-de"]
LANGUAGE = {
"de-en": (Language["DEU"], Language["ENG"]),
Expand All @@ -89,7 +128,7 @@ class WMT16(WMT):

class WMT20(WMT):
NAME = "WMT20"
DATASET_PATH = "wmt20"
DATASET_PATH = "wmt/wmt20"
SUBJECTS = ["de-en", "de-fr", "en-de", "fr-de"]
LANGUAGE = {
"de-en": (Language["DEU"], Language["ENG"]),
Expand Down Expand Up @@ -130,8 +169,8 @@ def post_process_generated_completion(self, completion_text: str, sample: Sample

class WMT14_INSTRUCT(WMT_INSTRUCT):
NAME = "WMT14 Instruct"
DATASET_PATH = "wmt14"
SUBJECTS = ["en-fr", "fr-en"]
DATASET_PATH = "wmt/wmt14"
SUBJECTS = ["fr-en", "en-fr"]
LANGUAGE = {
"en-fr": (Language["ENG"], Language["FRA"]),
"fr-en": (Language["FRA"], Language["ENG"]),
Expand All @@ -140,7 +179,7 @@ class WMT14_INSTRUCT(WMT_INSTRUCT):

class WMT16_INSTRUCT(WMT_INSTRUCT):
NAME = "WMT16 Instruct"
DATASET_PATH = "wmt16"
DATASET_PATH = "wmt/wmt16"
SUBJECTS = ["de-en", "en-de"]
LANGUAGE = {
"de-en": (Language["DEU"], Language["ENG"]),
Expand All @@ -150,7 +189,7 @@ class WMT16_INSTRUCT(WMT_INSTRUCT):

class WMT20_INSTRUCT(WMT_INSTRUCT):
NAME = "WMT20 Instruct"
DATASET_PATH = "wmt20"
DATASET_PATH = "wmt/wmt20"
SUBJECTS = ["de-en", "de-fr", "en-de", "fr-de"]
LANGUAGE = {
"de-en": (Language["DEU"], Language["ENG"]),
Expand Down
34 changes: 0 additions & 34 deletions src/eval_framework/tasks/task_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,6 @@ def make_sure_all_hf_datasets_are_in_cache(only_datasets: set[str] | None = None
time.sleep(random.randint(1, 5))
logger.info(f"Processed {task_name}")

# Sacrebleu uses its own cache (SACREBLEU env var), separate from HF datasets.
# We cache them together to ensure all evaluation data is available.
_ensure_sacrebleu_datasets_cached()


def update_changed_datasets_only(verbose: bool = True) -> tuple[bool, set[str]]:
"""
Expand All @@ -239,9 +235,6 @@ def update_changed_datasets_only(verbose: bool = True) -> tuple[bool, set[str]]:
all_up_to_date, datasets_to_update = get_datasets_needing_update()

if all_up_to_date:
# Even when HF datasets are current, ensure sacrebleu is cached
# (it has its own cache and isn't tracked by dataset_commits.json)
_ensure_sacrebleu_datasets_cached()
print("Nothing to update!")
return False, set()

Expand Down Expand Up @@ -293,32 +286,5 @@ def save_hf_dataset_commits() -> None:
print(f"Saved {len(commits)} dataset commits to {cache_file}")


def _ensure_sacrebleu_datasets_cached() -> None:
"""Pre-download sacrebleu WMT datasets to ensure they're cached.

Sacrebleu uses its own cache (controlled by SACREBLEU env var).
This ensures WMT test sets are downloaded and cached alongside HF datasets.
"""
import sacrebleu

# WMT datasets used by the framework (from wmt.py)
WMT_DATASETS = {
"wmt14": ["en-fr", "fr-en"],
"wmt16": ["de-en", "en-de"],
"wmt20": ["de-en", "de-fr", "en-de", "fr-de"],
}

print("Ensuring sacrebleu WMT datasets are cached...")
for test_set, langpairs in WMT_DATASETS.items():
for langpair in langpairs:
try:
sacrebleu.download_test_set(test_set=test_set, langpair=langpair)
print(f" {test_set}/{langpair}: OK")
except Exception as e:
print(f" {test_set}/{langpair}: FAILED ({e})")

print("Sacrebleu datasets cached!")


if __name__ == "__main__":
print(list(registered_tasks_iter()))
4 changes: 0 additions & 4 deletions tests/tests_eval_framework/tasks/test_all_formatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,6 @@ def test_all_tasks_formatter(task_name: str, formatter_cls: type[BaseFormatter])
Raises:
AssertionError: If the hash of the formatter output does not match the expected value
"""
# Skip WMT tasks - sacrebleu file loading has non-determinism
if "WMT" in task_name:
pytest.skip(f"Skipping {task_name}: WMT tasks use sacrebleu with non-deterministic file loading")

task_class = get_task(task_name)
args = SPECIAL_ARGS.get(task_class.__name__, {"num_fewshot": 1})

Expand Down
146 changes: 146 additions & 0 deletions tests/tests_eval_framework/tasks/test_wmt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""
Tests for WMT benchmark tasks using HuggingFace datasets.

Validates that WMT tasks load data deterministically from HuggingFace datasets
instead of sacrebleu file-based loading.
"""

import pytest

from eval_framework.tasks.benchmarks.wmt import (
WMT14,
WMT14_INSTRUCT,
WMT16,
WMT16_INSTRUCT,
WMT20,
WMT20_INSTRUCT,
)


class TestWMTDatasetStructure:
    """Validate the structure of WMT data loaded from HuggingFace."""

    @pytest.mark.parametrize(
        "task_cls,subject",
        [
            (WMT14, "fr-en"),
            (WMT14, "en-fr"),
            (WMT16, "de-en"),
            (WMT16, "en-de"),
            (WMT20, "de-en"),
            (WMT20, "de-fr"),
        ],
    )
    def test_load_dataset_structure(self, task_cls: type, subject: str) -> None:
        """The loaded test split must be non-empty, with source/target/subject fields."""
        task = task_cls(num_fewshot=0)
        task._load_dataset(subject)

        assert "test" in task.dataset
        test_split = task.dataset["test"]
        assert len(test_split) > 0

        # Every example carries the raw text pair plus its language-pair tag.
        first = test_split[0]
        for key in ("source", "target", "subject"):
            assert key in first
        assert first["subject"] == subject
        for text_key in ("source", "target"):
            value = first[text_key]
            assert isinstance(value, str)
            assert value

    @pytest.mark.parametrize("task_cls", [WMT14, WMT16, WMT20])
    def test_deterministic_loading(self, task_cls: type) -> None:
        """Two independent loads of the same subject must yield identical ordering."""
        subject = task_cls.SUBJECTS[0]

        first_run = task_cls(num_fewshot=0)
        first_run._load_dataset(subject)

        second_run = task_cls(num_fewshot=0)
        second_run._load_dataset(subject)

        # The seeded shuffle must produce the same order on every run.
        left, right = first_run.dataset["test"], second_run.dataset["test"]
        assert len(left) == len(right)
        for a, b in zip(left[:10], right[:10]):
            assert a["source"] == b["source"]
            assert a["target"] == b["target"]


class TestWMTSampleGeneration:
    """Exercise sample generation for the completion-style WMT tasks."""

    @pytest.mark.parametrize("task_cls", [WMT14, WMT16, WMT20])
    def test_sample_generation(self, task_cls: type) -> None:
        """Three zero-shot samples should come back fully populated."""
        task = task_cls(num_fewshot=0)
        produced = list(task.iterate_samples(num_samples=3))

        assert len(produced) == 3
        for sample in produced:
            # Each sample needs a prompt and a reference translation.
            assert sample.messages is not None
            assert len(sample.messages) > 0
            assert sample.ground_truth is not None

    @pytest.mark.parametrize("task_cls", [WMT14, WMT16, WMT20])
    def test_sample_with_fewshot(self, task_cls: type) -> None:
        """One-shot prompting should contribute additional messages per sample."""
        task = task_cls(num_fewshot=1)
        produced = list(task.iterate_samples(num_samples=2))

        assert len(produced) == 2
        for sample in produced:
            # A few-shot example adds at least one extra exchange to the prompt.
            assert len(sample.messages) >= 2


class TestWMTInstructVariants:
    """Test WMT instruct variants."""

    @pytest.mark.parametrize(
        "task_cls,subject",
        [
            (WMT14_INSTRUCT, "fr-en"),
            (WMT16_INSTRUCT, "de-en"),
            (WMT20_INSTRUCT, "de-en"),
        ],
    )
    def test_instruct_sample_generation(self, task_cls: type, subject: str) -> None:
        """Test that instruct variants generate samples correctly.

        Fix: the parametrized ``subject`` argument was previously unused, so the
        language-pair dimension of the parametrization did nothing. It is now
        checked against the task's declared SUBJECTS, so each parametrized pair
        is verified to be one the task actually supports.
        """
        # Guard: the parametrized language pair must be declared by the task.
        assert subject in task_cls.SUBJECTS

        task = task_cls(num_fewshot=0)
        samples = list(task.iterate_samples(num_samples=2))

        assert len(samples) == 2
        for sample in samples:
            assert sample.messages is not None
            # The instruction prompt should explicitly ask for a translation.
            first_message_content = sample.messages[0].content
            assert "translate" in first_message_content.lower()


class TestWMTPostProcessing:
    """Test WMT post-processing methods.

    NOTE(review): the diff context for ``wmt.py`` shows the signature
    ``post_process_generated_completion(self, completion_text: str, sample: Sample``
    — i.e. the method appears to accept a ``sample`` argument that the calls
    below omit. Confirm the signature; if ``sample`` is required (no default),
    these tests will fail with a TypeError rather than test the behavior.
    """

    def test_post_process_with_stop_sequence(self) -> None:
        """Test that stop sequences are handled correctly."""
        # WMT's __init__ defines stop_sequences including ".\n" and the
        # "phrase:" family; both inputs below should be truncated at the
        # first stop sequence encountered.
        task = WMT16(num_fewshot=0)

        # Test various stop sequences
        text_with_stop = "Hello world.\nThis should be cut"
        result = task.post_process_generated_completion(text_with_stop)
        assert result == "Hello world"

        text_with_phrase = "Hello world phrase: extra"
        result = task.post_process_generated_completion(text_with_phrase)
        assert result == "Hello world"

    def test_instruct_post_process(self) -> None:
        """Test instruct variant post-processing."""
        task = WMT16_INSTRUCT(num_fewshot=0)

        # Test prefix removal
        # The instruct variant is expected to strip a leading
        # "...translation:"-style preamble from the model output.
        text_with_prefix = "This is the translation: Hello world"
        result = task.post_process_generated_completion(text_with_prefix)
        assert result == "Hello world"
Loading