diff --git a/pyrit/datasets/seed_datasets/remote/promptintel_dataset.py b/pyrit/datasets/seed_datasets/remote/promptintel_dataset.py
index 21910d594..ac6012122 100644
--- a/pyrit/datasets/seed_datasets/remote/promptintel_dataset.py
+++ b/pyrit/datasets/seed_datasets/remote/promptintel_dataset.py
@@ -72,8 +72,8 @@ def __init__(
         self,
         *,
         api_key: Optional[str] = None,
-        severity: Optional[PromptIntelSeverity] = None,
-        categories: Optional[list[PromptIntelCategory]] = None,
+        severity: Optional[PromptIntelSeverity | str] = None,
+        categories: Optional[list[PromptIntelCategory | str]] = None,
         search: Optional[str] = None,
         max_prompts: Optional[int] = None,
     ) -> None:
@@ -93,6 +93,8 @@ def __init__(
             ValueError: If an invalid severity or category is provided.
         """
         self._api_key = api_key
+        normalized_severity: Optional[PromptIntelSeverity] = None
+        normalized_categories: Optional[list[PromptIntelCategory]] = None
 
         if severity is not None:
             valid_severities = {s.value for s in PromptIntelSeverity}
@@ -101,20 +103,28 @@ def __init__(
                 raise ValueError(
                     f"Invalid severity: {sev_value}. Valid values: {[s.value for s in PromptIntelSeverity]}"
                 )
+            normalized_severity = (  # always the enum member, even when a string was passed
+                severity if isinstance(severity, PromptIntelSeverity) else PromptIntelSeverity(sev_value)
+            )
 
         if categories is not None:
             valid_categories = {c.value for c in PromptIntelCategory}
+            category_values = [cat.value if isinstance(cat, PromptIntelCategory) else cat for cat in categories]
             invalid_categories = {
-                cat.value if isinstance(cat, PromptIntelCategory) else cat for cat in categories
-            } - valid_categories
+                category_value for category_value in category_values if category_value not in valid_categories
+            }
             if invalid_categories:
                 raise ValueError(
                     f"Invalid categories: {', '.join(str(c) for c in invalid_categories)}. "
                     f"Valid values: {[c.value for c in PromptIntelCategory]}"
                 )
+            normalized_categories = [  # enum members in input order, even when strings were passed
+                category if isinstance(category, PromptIntelCategory) else PromptIntelCategory(category)
+                for category in categories
+            ]
 
-        self._severity = severity
-        self._categories = categories
+        self._severity = normalized_severity
+        self._categories = normalized_categories
         self._search = search
         self._max_prompts = max_prompts
         self.source = "https://promptintel.novahunting.ai"
diff --git a/tests/unit/datasets/test_promptintel_dataset.py b/tests/unit/datasets/test_promptintel_dataset.py
index bea3377e3..d30e24e2a 100644
--- a/tests/unit/datasets/test_promptintel_dataset.py
+++ b/tests/unit/datasets/test_promptintel_dataset.py
@@ -347,6 +347,17 @@ async def test_severity_filter_passed_to_api(self, api_key, mock_promptintel_res
         call_kwargs = mock_get.call_args
         assert call_kwargs.kwargs["params"]["severity"] == "critical"
 
+    @pytest.mark.asyncio
+    async def test_string_severity_filter_passed_to_api(self, api_key, mock_promptintel_response):
+        loader = _PromptIntelDataset(api_key=api_key, severity="critical")  # plain string, not the enum
+        mock_resp = _make_mock_response(json_data=mock_promptintel_response)
+
+        with patch("requests.get", return_value=mock_resp) as mock_get:
+            await loader.fetch_dataset()
+
+        call_kwargs = mock_get.call_args
+        assert call_kwargs.kwargs["params"]["severity"] == "critical"
+
     @pytest.mark.asyncio
     async def test_category_filter_passed_to_api(self, api_key, mock_promptintel_response):
         loader = _PromptIntelDataset(api_key=api_key, categories=[PromptIntelCategory.MANIPULATION])
@@ -358,6 +369,17 @@ async def test_category_filter_passed_to_api(self, api_key, mock_promptintel_res
         call_kwargs = mock_get.call_args
         assert call_kwargs.kwargs["params"]["category"] == "manipulation"
 
+    @pytest.mark.asyncio
+    async def test_string_category_filter_passed_to_api(self, api_key, mock_promptintel_response):
+        loader = _PromptIntelDataset(api_key=api_key, categories=["manipulation"])  # plain string, not the enum
+        mock_resp = _make_mock_response(json_data=mock_promptintel_response)
+
+        with patch("requests.get", return_value=mock_resp) as mock_get:
+            await loader.fetch_dataset()
+
+        call_kwargs = mock_get.call_args
+        assert call_kwargs.kwargs["params"]["category"] == "manipulation"
+
     @pytest.mark.asyncio
     async def test_multiple_categories_make_separate_api_calls(self, api_key):
         manipulation_response = {