Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyrit/datasets/executors/question_answer/wmdp_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def fetch_wmdp_dataset(category: Optional[str] = None) -> QuestionAnsweringDatas
"""
# Determine which subset of data to load
data_categories = None
if not category: # if category is not specified, read in all 3 subsets of data
if category is None: # if category is not specified, read in all 3 subsets of data
data_categories = ["wmdp-cyber", "wmdp-bio", "wmdp-chem"]
elif category not in ["cyber", "bio", "chem"]:
raise ValueError(f"Invalid Parameter: {category}. Expected 'cyber', 'bio', or 'chem'")
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/datasets/test_wmdp_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from unittest.mock import patch

import pytest

from pyrit.datasets.executors.question_answer.wmdp_dataset import fetch_wmdp_dataset


class _EmptySplit:
def __len__(self) -> int:
return 0


def test_fetch_wmdp_dataset_rejects_empty_category():
with patch(
"pyrit.datasets.executors.question_answer.wmdp_dataset.load_dataset",
return_value={"test": _EmptySplit()},
) as mock_load_dataset:
with pytest.raises(ValueError, match="Invalid Parameter"):
fetch_wmdp_dataset(category="")

mock_load_dataset.assert_not_called()
Loading