From 1e0c50bbf2dd71bbde229b109a40586f7e7ecbca Mon Sep 17 00:00:00 2001 From: biefan <70761325+biefan@users.noreply.github.com> Date: Tue, 17 Mar 2026 06:44:22 +0000 Subject: [PATCH] Reject empty WMDP category values --- .../executors/question_answer/wmdp_dataset.py | 2 +- tests/unit/datasets/test_wmdp_dataset.py | 24 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) create mode 100644 tests/unit/datasets/test_wmdp_dataset.py diff --git a/pyrit/datasets/executors/question_answer/wmdp_dataset.py b/pyrit/datasets/executors/question_answer/wmdp_dataset.py index f77e0872f..1270c9b6c 100644 --- a/pyrit/datasets/executors/question_answer/wmdp_dataset.py +++ b/pyrit/datasets/executors/question_answer/wmdp_dataset.py @@ -31,7 +31,7 @@ def fetch_wmdp_dataset(category: Optional[str] = None) -> QuestionAnsweringDatas """ # Determine which subset of data to load data_categories = None - if not category: # if category is not specified, read in all 3 subsets of data + if category is None: # if category is not specified, read in all 3 subsets of data data_categories = ["wmdp-cyber", "wmdp-bio", "wmdp-chem"] elif category not in ["cyber", "bio", "chem"]: raise ValueError(f"Invalid Parameter: {category}. Expected 'cyber', 'bio', or 'chem'") diff --git a/tests/unit/datasets/test_wmdp_dataset.py b/tests/unit/datasets/test_wmdp_dataset.py new file mode 100644 index 000000000..d00c1a4fd --- /dev/null +++ b/tests/unit/datasets/test_wmdp_dataset.py @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+
+from unittest.mock import patch
+
+import pytest
+
+from pyrit.datasets.executors.question_answer.wmdp_dataset import fetch_wmdp_dataset
+
+
+class _EmptySplit:  # stand-in object whose only behavior is reporting a length of 0
+    def __len__(self) -> int:
+        return 0
+
+
+def test_fetch_wmdp_dataset_rejects_empty_category():  # category="" must raise, not load data
+    with patch(  # swap out load_dataset as referenced from the wmdp_dataset module
+        "pyrit.datasets.executors.question_answer.wmdp_dataset.load_dataset",
+        return_value={"test": _EmptySplit()},  # what the mock would return if it were called
+    ) as mock_load_dataset:
+        with pytest.raises(ValueError, match="Invalid Parameter"):  # message must match this pattern
+            fetch_wmdp_dataset(category="")
+
+        mock_load_dataset.assert_not_called()  # the ValueError must come before any load_dataset call