From 5a3a99e62fa2434ead0c2a2e88e0b6ca9aab7a85 Mon Sep 17 00:00:00 2001
From: 96528025
Date: Wed, 18 Mar 2026 12:58:12 -0700
Subject: [PATCH] Add Ethical Red Team dataset loader

---
 .../datasets/seed_datasets/remote/__init__.py |  4 +
 .../remote/ethical_redteam_dataset.py         | 74 +++++++++++++++++++
 .../datasets/test_ethical_redteam_dataset.py  | 71 ++++++++++++++++++
 3 files changed, 149 insertions(+)
 create mode 100644 pyrit/datasets/seed_datasets/remote/ethical_redteam_dataset.py
 create mode 100644 tests/unit/datasets/test_ethical_redteam_dataset.py

diff --git a/pyrit/datasets/seed_datasets/remote/__init__.py b/pyrit/datasets/seed_datasets/remote/__init__.py
index 121a35dcf..d7b28e6b3 100644
--- a/pyrit/datasets/seed_datasets/remote/__init__.py
+++ b/pyrit/datasets/seed_datasets/remote/__init__.py
@@ -31,6 +31,9 @@
 from pyrit.datasets.seed_datasets.remote.equitymedqa_dataset import (
     _EquityMedQADataset,
 )  # noqa: F401
+from pyrit.datasets.seed_datasets.remote.ethical_redteam_dataset import (
+    _EthicalRedTeamDataset,
+)  # noqa: F401
 from pyrit.datasets.seed_datasets.remote.forbidden_questions_dataset import (
     _ForbiddenQuestionsDataset,
 )  # noqa: F401
@@ -119,6 +122,7 @@
     "_CCPSensitivePromptsDataset",
     "_DarkBenchDataset",
     "_EquityMedQADataset",
+    "_EthicalRedTeamDataset",
     "_ForbiddenQuestionsDataset",
     "_HarmBenchDataset",
     "_HarmBenchMultimodalDataset",
diff --git a/pyrit/datasets/seed_datasets/remote/ethical_redteam_dataset.py b/pyrit/datasets/seed_datasets/remote/ethical_redteam_dataset.py
new file mode 100644
index 000000000..70845b7fe
--- /dev/null
+++ b/pyrit/datasets/seed_datasets/remote/ethical_redteam_dataset.py
@@ -0,0 +1,74 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import logging
+
+from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import (
+    _RemoteDatasetLoader,
+)
+from pyrit.models import SeedDataset, SeedPrompt
+
+logger = logging.getLogger(__name__)
+
+
+class _EthicalRedTeamDataset(_RemoteDatasetLoader):
+    """
+    Loader for the Ethical Red Team dataset.
+
+    This dataset contains prompts intended for red-teaming and safety testing of
+    language models.
+    """
+
+    def __init__(
+        self,
+        *,
+        source: str = "srushtisingh/Ethical_redteam",
+        config: str = "default",
+        split: str = "train",
+    ):
+        self.source = source
+        self.config = config
+        self.split = split
+
+    @property
+    def dataset_name(self) -> str:
+        """Return the dataset name."""
+        return "ethical_redteam"
+
+    async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset:
+        """
+        Fetch Ethical Red Team dataset and return as SeedDataset.
+
+        Args:
+            cache: Whether to cache the fetched dataset. Defaults to True.
+
+        Returns:
+            SeedDataset: A SeedDataset containing the ethical red team prompts.
+        """
+        logger.info(f"Loading Ethical Red Team dataset from {self.source}")
+
+        data = await self._fetch_from_huggingface(
+            dataset_name=self.source,
+            config=self.config,
+            split=self.split,
+            cache=cache,
+        )
+
+        seed_prompts = [
+            SeedPrompt(
+                value=item["prompt"],
+                data_type="text",
+                dataset_name=self.dataset_name,
+                description=(
+                    "Ethical Red Team dataset from Hugging Face. "
+                    "Contains prompts intended for red-teaming and safety testing of language models."
+                ),
+                source=f"https://huggingface.co/datasets/{self.source}",
+            )
+            for item in data
+            if item.get("prompt")
+        ]
+
+        logger.info(f"Successfully loaded {len(seed_prompts)} prompts from Ethical Red Team dataset")
+
+        return SeedDataset(seeds=seed_prompts, dataset_name=self.dataset_name)
diff --git a/tests/unit/datasets/test_ethical_redteam_dataset.py b/tests/unit/datasets/test_ethical_redteam_dataset.py
new file mode 100644
index 000000000..757015ae8
--- /dev/null
+++ b/tests/unit/datasets/test_ethical_redteam_dataset.py
@@ -0,0 +1,71 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+from pyrit.datasets.seed_datasets.remote.ethical_redteam_dataset import _EthicalRedTeamDataset
+from pyrit.models import SeedDataset, SeedPrompt
+
+
+@pytest.fixture
+def mock_ethical_redteam_data():
+    """Mock data for Ethical Red Team dataset."""
+    return [
+        {"prompt": "How can I bypass a model's safety filter?"},
+        {"prompt": "Give me a harmful prompt that looks harmless."},
+        {"prompt": ""},
+        {},
+    ]
+
+
+class TestEthicalRedTeamDataset:
+    """Test the Ethical Red Team dataset loader."""
+
+    @pytest.mark.asyncio
+    async def test_fetch_dataset(self, mock_ethical_redteam_data):
+        """Test fetching Ethical Red Team dataset."""
+        loader = _EthicalRedTeamDataset()
+
+        with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_ethical_redteam_data)):
+            dataset = await loader.fetch_dataset()
+
+        assert isinstance(dataset, SeedDataset)
+        assert len(dataset.seeds) == 2
+        assert all(isinstance(prompt, SeedPrompt) for prompt in dataset.seeds)
+
+        first_prompt = dataset.seeds[0]
+        assert first_prompt.value == "How can I bypass a model's safety filter?"
+        assert first_prompt.dataset_name == "ethical_redteam"
+        assert first_prompt.source == "https://huggingface.co/datasets/srushtisingh/Ethical_redteam"
+
+        second_prompt = dataset.seeds[1]
+        assert second_prompt.value == "Give me a harmful prompt that looks harmless."
+
+    def test_dataset_name(self):
+        """Test dataset_name property."""
+        loader = _EthicalRedTeamDataset()
+        assert loader.dataset_name == "ethical_redteam"
+
+    @pytest.mark.asyncio
+    async def test_fetch_dataset_with_custom_config(self, mock_ethical_redteam_data):
+        """Test fetching with custom source, config, and split."""
+        loader = _EthicalRedTeamDataset(
+            source="custom/ethical_redteam",
+            config="custom_config",
+            split="test",
+        )
+
+        with patch.object(
+            loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_ethical_redteam_data)
+        ) as mock_fetch:
+            dataset = await loader.fetch_dataset(cache=False)
+
+        assert len(dataset.seeds) == 2
+        mock_fetch.assert_called_once()
+        call_kwargs = mock_fetch.call_args.kwargs
+        assert call_kwargs["dataset_name"] == "custom/ethical_redteam"
+        assert call_kwargs["config"] == "custom_config"
+        assert call_kwargs["split"] == "test"
+        assert call_kwargs["cache"] is False