-
Notifications
You must be signed in to change notification settings - Fork 4
fix(wmt): use HuggingFace datasets instead of sacrebleu #137
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
AhmedHammam-AA
wants to merge
3
commits into
main
Choose a base branch
from
fix/wmt-huggingface-datasets
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,146 @@ | ||
| """ | ||
| Tests for WMT benchmark tasks using HuggingFace datasets. | ||
|
|
||
| Validates that WMT tasks load data deterministically from HuggingFace datasets | ||
| instead of sacrebleu file-based loading. | ||
| """ | ||
|
|
||
| import pytest | ||
|
|
||
| from eval_framework.tasks.benchmarks.wmt import ( | ||
| WMT14, | ||
| WMT14_INSTRUCT, | ||
| WMT16, | ||
| WMT16_INSTRUCT, | ||
| WMT20, | ||
| WMT20_INSTRUCT, | ||
| ) | ||
|
|
||
|
|
||
| class TestWMTDatasetStructure: | ||
| """Test that WMT tasks load data with correct structure from HuggingFace.""" | ||
|
|
||
| @pytest.mark.parametrize( | ||
| "task_cls,subject", | ||
| [ | ||
| (WMT14, "fr-en"), | ||
| (WMT14, "en-fr"), | ||
| (WMT16, "de-en"), | ||
| (WMT16, "en-de"), | ||
| (WMT20, "de-en"), | ||
| (WMT20, "de-fr"), | ||
| ], | ||
| ) | ||
| def test_load_dataset_structure(self, task_cls: type, subject: str) -> None: | ||
| """Test that dataset loads with correct structure.""" | ||
| task = task_cls(num_fewshot=0) | ||
| task._load_dataset(subject) | ||
|
|
||
| assert "test" in task.dataset | ||
| assert len(task.dataset["test"]) > 0 | ||
|
|
||
| # Check item structure | ||
| item = task.dataset["test"][0] | ||
| assert "source" in item | ||
| assert "target" in item | ||
| assert "subject" in item | ||
| assert item["subject"] == subject | ||
| assert isinstance(item["source"], str) | ||
| assert isinstance(item["target"], str) | ||
| assert len(item["source"]) > 0 | ||
| assert len(item["target"]) > 0 | ||
|
|
||
| @pytest.mark.parametrize("task_cls", [WMT14, WMT16, WMT20]) | ||
| def test_deterministic_loading(self, task_cls: type) -> None: | ||
| """Test that loading is deterministic across multiple runs.""" | ||
| subject = task_cls.SUBJECTS[0] | ||
|
|
||
| # Load twice | ||
| task1 = task_cls(num_fewshot=0) | ||
| task1._load_dataset(subject) | ||
|
|
||
| task2 = task_cls(num_fewshot=0) | ||
| task2._load_dataset(subject) | ||
|
|
||
| # Verify identical ordering after shuffle | ||
| assert len(task1.dataset["test"]) == len(task2.dataset["test"]) | ||
| for i in range(min(10, len(task1.dataset["test"]))): | ||
| assert task1.dataset["test"][i]["source"] == task2.dataset["test"][i]["source"] | ||
| assert task1.dataset["test"][i]["target"] == task2.dataset["test"][i]["target"] | ||
|
|
||
|
|
||
| class TestWMTSampleGeneration: | ||
| """Test WMT sample generation.""" | ||
|
|
||
| @pytest.mark.parametrize("task_cls", [WMT14, WMT16, WMT20]) | ||
| def test_sample_generation(self, task_cls: type) -> None: | ||
| """Test that samples can be generated correctly.""" | ||
| task = task_cls(num_fewshot=0) | ||
| samples = list(task.iterate_samples(num_samples=3)) | ||
|
|
||
| assert len(samples) == 3 | ||
| for sample in samples: | ||
| assert sample.messages is not None | ||
| assert len(sample.messages) > 0 | ||
| assert sample.ground_truth is not None | ||
|
|
||
| @pytest.mark.parametrize("task_cls", [WMT14, WMT16, WMT20]) | ||
| def test_sample_with_fewshot(self, task_cls: type) -> None: | ||
| """Test that few-shot samples are generated correctly.""" | ||
| task = task_cls(num_fewshot=1) | ||
| samples = list(task.iterate_samples(num_samples=2)) | ||
|
|
||
| assert len(samples) == 2 | ||
| for sample in samples: | ||
| # With fewshot, we should have more messages | ||
| assert len(sample.messages) >= 2 | ||
|
|
||
|
|
||
| class TestWMTInstructVariants: | ||
| """Test WMT instruct variants.""" | ||
|
|
||
| @pytest.mark.parametrize( | ||
| "task_cls,subject", | ||
| [ | ||
| (WMT14_INSTRUCT, "fr-en"), | ||
| (WMT16_INSTRUCT, "de-en"), | ||
| (WMT20_INSTRUCT, "de-en"), | ||
| ], | ||
| ) | ||
| def test_instruct_sample_generation(self, task_cls: type, subject: str) -> None: | ||
| """Test that instruct variants generate samples correctly.""" | ||
| task = task_cls(num_fewshot=0) | ||
| samples = list(task.iterate_samples(num_samples=2)) | ||
|
|
||
| assert len(samples) == 2 | ||
| for sample in samples: | ||
| assert sample.messages is not None | ||
| # Check that the instruction format contains "translate" | ||
| first_message_content = sample.messages[0].content | ||
| assert "translate" in first_message_content.lower() | ||
|
|
||
|
|
||
| class TestWMTPostProcessing: | ||
| """Test WMT post-processing methods.""" | ||
|
|
||
| def test_post_process_with_stop_sequence(self) -> None: | ||
| """Test that stop sequences are handled correctly.""" | ||
| task = WMT16(num_fewshot=0) | ||
|
|
||
| # Test various stop sequences | ||
| text_with_stop = "Hello world.\nThis should be cut" | ||
| result = task.post_process_generated_completion(text_with_stop) | ||
| assert result == "Hello world" | ||
|
|
||
| text_with_phrase = "Hello world phrase: extra" | ||
| result = task.post_process_generated_completion(text_with_phrase) | ||
| assert result == "Hello world" | ||
|
|
||
| def test_instruct_post_process(self) -> None: | ||
| """Test instruct variant post-processing.""" | ||
| task = WMT16_INSTRUCT(num_fewshot=0) | ||
|
|
||
| # Test prefix removal | ||
| text_with_prefix = "This is the translation: Hello world" | ||
| result = task.post_process_generated_completion(text_with_prefix) | ||
| assert result == "Hello world" |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you explain the expected format? specifically with regards to the shuffling