From 857f596e178b99308c366f39ead31f0810edc077 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Tue, 10 Mar 2026 23:47:44 +0000 Subject: [PATCH 01/29] scaffolding --- .../remote/aegis_ai_content_safety_dataset.py | 13 +++++- .../seed_datasets/seed_dataset_provider.py | 41 +++++++++++++++---- pyrit/datasets/seed_datasets/seed_metadata.py | 33 +++++++++++++++ 3 files changed, 78 insertions(+), 9 deletions(-) create mode 100644 pyrit/datasets/seed_datasets/seed_metadata.py diff --git a/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py b/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py index 4b9004f772..ecc952f35a 100644 --- a/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py @@ -11,6 +11,8 @@ ) from pyrit.models import SeedDataset, SeedPrompt +from pyrit.datasets.seed_datasets.seed_metadata import SeedMetadata + logger = logging.getLogger(__name__) @@ -107,7 +109,8 @@ def __init__( # Validate harm categories if provided if harm_categories: - invalid_categories = {cat for cat in harm_categories if cat not in self.HARM_CATEGORIES} + invalid_categories = { + cat for cat in harm_categories if cat not in self.HARM_CATEGORIES} if invalid_categories: raise ValueError( f"Invalid harm categories: {invalid_categories}. 
Valid categories are: {self.HARM_CATEGORIES}" @@ -157,7 +160,8 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: prompt_harm_categories = [] if violated_categories: # The violated_categories field contains comma-separated category names - categories = [cat.strip() for cat in violated_categories.split(",") if cat.strip()] + categories = [ + cat.strip() for cat in violated_categories.split(",") if cat.strip()] prompt_harm_categories = categories # Filter by harm_categories if specified @@ -186,3 +190,8 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: ) return SeedDataset(seeds=seed_prompts, dataset_name=self.dataset_name) + + def metadata_factory(self) -> SeedMetadata: + return SeedMetadata( + size= + ) diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index 56b61b3996..cb0e7ed11a 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -10,6 +10,7 @@ from tqdm import tqdm from pyrit.models.seeds import SeedDataset +from pyrit.datasets.seed_datasets.seed_metadata import SeedMetadata logger = logging.getLogger(__name__) @@ -51,6 +52,12 @@ def dataset_name(self) -> str: str: The dataset name (e.g., "HarmBench", "JailbreakBench JBB-Behaviors") """ + @abstractmethod + def metadata_factory(self) -> SeedMetadata: + """ + Build metadata from tags and derived fields (e.g. dataset size). + """ + @abstractmethod async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: """ @@ -78,10 +85,13 @@ def get_all_providers(cls) -> dict[str, type["SeedDatasetProvider"]]: return cls._registry.copy() @classmethod - def get_all_dataset_names(cls) -> list[str]: + def get_all_dataset_names(cls, filters: Optional[dict[str, str]] = None) -> list[str]: """ Get the names of all registered datasets. + Args: + filters (Optional[Dict[str, str]]): List of filters to apply. 
+ Returns: List[str]: List of dataset names from all registered providers. @@ -97,9 +107,21 @@ def get_all_dataset_names(cls) -> list[str]: try: # Instantiate to get dataset name provider = provider_class() + + # Injection point for filtering. TODO + + # 1 Remove invalid filters by checking ground truth in seed_metadata + + # 2 Remove invalid filter values by invoking helpers (e.g. size: <100 is fine, size: foobar is not) + + # 3 Only execute the following line if the filter key is valid and so is the value, AND the dataset meets the condition + + # Problem: We don't know size at this point because we're just collecting the name. Size and source are tricky for remote datasets + # since we can't check them statically dataset_names.add(provider.dataset_name) except Exception as e: - raise ValueError(f"Could not get dataset name from {provider_class.__name__}: {e}") from e + raise ValueError( + f"Could not get dataset name from {provider_class.__name__}: {e}") from e return sorted(dataset_names) @classmethod @@ -142,9 +164,11 @@ async def fetch_datasets_async( # Validate dataset names if specified if dataset_names is not None: available_names = cls.get_all_dataset_names() - invalid_names = [name for name in dataset_names if name not in available_names] + invalid_names = [ + name for name in dataset_names if name not in available_names] if invalid_names: - raise ValueError(f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") + raise ValueError( + f"Dataset(s) not found: {invalid_names}. 
Available datasets: {available_names}") async def fetch_single_dataset( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -170,7 +194,8 @@ async def fetch_single_dataset( # Progress tracking total_count = len(cls._registry) - pbar = tqdm(total=total_count, desc="Loading datasets - this can take a few minutes", unit="dataset") + pbar = tqdm(total=total_count, + desc="Loading datasets - this can take a few minutes", unit="dataset") async def fetch_with_semaphore( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -208,10 +233,12 @@ async def fetch_with_semaphore( logger.info(f"Merging multiple sources for {dataset_name}.") existing_dataset = datasets[dataset_name] - combined_seeds = list(existing_dataset.seeds) + list(dataset.seeds) + combined_seeds = list( + existing_dataset.seeds) + list(dataset.seeds) existing_dataset.seeds = combined_seeds else: datasets[dataset_name] = dataset - logger.info(f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") + logger.info( + f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") return list(datasets.values()) diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py new file mode 100644 index 0000000000..8ac0c99fd5 --- /dev/null +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from enum import Enum +from dataclasses import dataclass + + +class DatasetLoadingRank(Enum): + """Represents the general difficulty of loading in a dataset.""" + DEFAULT = "default" + EXTENDED = "extended" + SLOW = "slow" + + +class DatasetModalities(Enum): + TEXT = "text" + IMAGE = "image" + VIDEO = "video" + AUDIO = "audio" + + +class DatasetSourceType(Enum): + GENERIC_URL = "generic_url" + LOCAL = "local" + HUGGING_FACE = "hugging_face" + + +@dataclass +class DatasetMetadata: + size: int + modalities: list[DatasetModalities] + source: DatasetSourceType + loading_rank: DatasetLoadingRank From 15b58e8a47fd9d673b6e58c0e7f5b01e24d52e9d Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 11 Mar 2026 19:52:05 +0000 Subject: [PATCH 02/29] more scaffolding --- pyrit/datasets/__init__.py | 3 ++ .../seed_datasets/seed_dataset_provider.py | 43 +++++++++++++------ pyrit/datasets/seed_datasets/seed_metadata.py | 32 +++++++++++++- .../test_seed_dataset_provider_integration.py | 15 +++++-- .../datasets/test_seed_dataset_metadata.py | 32 ++++++++++++++ 5 files changed, 109 insertions(+), 16 deletions(-) create mode 100644 tests/unit/datasets/test_seed_dataset_metadata.py diff --git a/pyrit/datasets/__init__.py b/pyrit/datasets/__init__.py index 5eb89b6f44..c8d8592625 100644 --- a/pyrit/datasets/__init__.py +++ b/pyrit/datasets/__init__.py @@ -8,8 +8,11 @@ from pyrit.datasets.jailbreak.text_jailbreak import TextJailBreak from pyrit.datasets.seed_datasets import local, remote # noqa: F401 from pyrit.datasets.seed_datasets.seed_dataset_provider import SeedDatasetProvider +from pyrit.datasets.seed_datasets.seed_metadata import DatasetMetadata, DatasetFilters __all__ = [ + "DatasetMetadata", + "DatasetFilters", "SeedDatasetProvider", "TextJailBreak", ] diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index cb0e7ed11a..7e11bf5f4c 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ 
b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -26,6 +26,10 @@ class SeedDatasetProvider(ABC): Subclasses must implement: - fetch_dataset(): Fetch and return the dataset as a SeedDataset - dataset_name property: Human-readable name for the dataset + + All subclasses also have a _metadata property that is optional to make + dataset addition easier, but failing to complete it makes downstream + analysis more difficult. """ _registry: dict[str, type["SeedDatasetProvider"]] = {} @@ -41,6 +45,10 @@ def __init_subclass__(cls, **kwargs: Any) -> None: if not inspect.isabstract(cls) and getattr(cls, "should_register", True): SeedDatasetProvider._registry[cls.__name__] = cls logger.debug(f"Registered dataset provider: {cls.__name__}") + # Providing metadata is optional + if getattr(cls, "_metadata", False): + logger.debug( + f"Dataset provider {cls.__name__} provided metadata.") @property @abstractmethod @@ -52,12 +60,6 @@ def dataset_name(self) -> str: str: The dataset name (e.g., "HarmBench", "JailbreakBench JBB-Behaviors") """ - @abstractmethod - def metadata_factory(self) -> SeedMetadata: - """ - Build metadata from tags and derived fields (e.g. dataset size). 
- """ - @abstractmethod async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: """ @@ -103,21 +105,38 @@ def get_all_dataset_names(cls, filters: Optional[dict[str, str]] = None) -> list >>> print(f"Available datasets: {', '.join(names)}") """ dataset_names = set() + # 1 Remove invalid filters by checking ground truth in seed_metadata + if filters: + valid_filters = [f.value for f in SeedMetadata.DatasetFilters] + # Prefer doing this to a list or set comprehension so we can raise ValueError on + # specific unsupported filters + for filter, _ in filters.items(): + if filter not in valid_filters: + raise ValueError( + f"Tried to pass invalid filter `{filter}` to SeedDatasetProvider.get_all_dataset_names!") + for provider_class in cls._registry.values(): try: # Instantiate to get dataset name provider = provider_class() - # Injection point for filtering. TODO + if filters: + # 1 Check if it has metadata + # should this be none or false + if getattr(provider, "_metadata", False): + # Skip a dataset without metadata if we have filters enabled + continue + + # 2 Remove invalid filter values by invoking helpers (e.g. size: <100 is fine, size: foobar is not) - # 1 Remove invalid filters by checking ground truth in seed_metadata + # 3 Only execute the following line if the filter key is valid and so is the value, AND the dataset meets the condition - # 2 Remove invalid filter values by invoking helpers (e.g. size: <100 is fine, size: foobar is not) + # Problem: We don't know size at this point because we're just collecting the name. 
Size and source are tricky for remote datasets + # since we can't check them statically - # 3 Only execute the following line if the filter key is valid and so is the value, AND the dataset meets the condition + # Solution: If filter is dynamic, then just download or load into central memory early to retrieve it + # and present a warning to the user that this is occuring - # Problem: We don't know size at this point because we're just collecting the name. Size and source are tricky for remote datasets - # since we can't check them statically dataset_names.add(provider.dataset_name) except Exception as e: raise ValueError( diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index 8ac0c99fd5..8ad37940d4 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -4,6 +4,20 @@ from enum import Enum from dataclasses import dataclass +""" +TODO Finish docstring + +Contains metadata objects for datasets (i.e. subclasses of SeedDatasetProvider). + +We have one DatasetMetadata dataclass that is our ground truth. As we instantiate datasets +using the subclass call in SeedDatasetProvider, we create DatasetMetadata and assign it to +a private variable there. + +Some fields are dynamic (e.g. loading statistics, timestamp, dataset size) and are left as +NoneType until the SeedDatasetProvider actually downloads/parses the dataset and puts it in +CentralMemory. +""" + class DatasetLoadingRank(Enum): """Represents the general difficulty of loading in a dataset.""" @@ -27,7 +41,23 @@ class DatasetSourceType(Enum): @dataclass class DatasetMetadata: + # TODO: separate dynamic fields from static fields and mark dynamic fields as None size: int modalities: list[DatasetModalities] source: DatasetSourceType - loading_rank: DatasetLoadingRank + rank: DatasetLoadingRank + + +class DatasetFilters(Enum): + # TODO: This is a bad way of extracting the fields from DatasetMetadata. 
+ # A metaclass or even just calling getattr might be better. + SIZE = "size" + MODALITIES = "modalities" + SOURCE = "source" + RANK = "rank" + +# TODO These stubs should be moved somewhere, maybe as static methods to the metadata dataclass? + + +def _validate_filter_value(v): + """Check if the filter value given is valid.""" diff --git a/tests/integration/datasets/test_seed_dataset_provider_integration.py b/tests/integration/datasets/test_seed_dataset_provider_integration.py index a3ede4beab..ceacc2a860 100644 --- a/tests/integration/datasets/test_seed_dataset_provider_integration.py +++ b/tests/integration/datasets/test_seed_dataset_provider_integration.py @@ -37,10 +37,12 @@ async def test_fetch_dataset_integration(self, name, provider_cls): try: # Use max_examples for slow providers that fetch many remote images - provider = provider_cls(max_examples=6) if provider_cls == _VLSUMultimodalDataset else provider_cls() + provider = provider_cls( + max_examples=6) if provider_cls == _VLSUMultimodalDataset else provider_cls() dataset = await provider.fetch_dataset(cache=False) - assert isinstance(dataset, SeedDataset), f"{name} did not return a SeedDataset" + assert isinstance( + dataset, SeedDataset), f"{name} did not return a SeedDataset" assert len(dataset.seeds) > 0, f"{name} returned an empty dataset" assert dataset.dataset_name, f"{name} has no dataset_name" @@ -51,7 +53,14 @@ async def test_fetch_dataset_integration(self, name, provider_cls): f"Seed dataset_name mismatch in {name}: {seed.dataset_name} != {dataset.dataset_name}" ) - logger.info(f"Successfully verified {name} with {len(dataset.seeds)} seeds") + logger.info( + f"Successfully verified {name} with {len(dataset.seeds)} seeds") except Exception as e: pytest.fail(f"Failed to fetch dataset from {name}: {str(e)}") + + @pytest.mark.asyncio + @pytest.mark.parameterize("name,provider_cls", get_dataset_providers()) + async def test_fetch_dataset_with_filtering(self, name, provider_cls): + # TODO + pass diff 
--git a/tests/unit/datasets/test_seed_dataset_metadata.py b/tests/unit/datasets/test_seed_dataset_metadata.py new file mode 100644 index 0000000000..7f38572311 --- /dev/null +++ b/tests/unit/datasets/test_seed_dataset_metadata.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +TODO + +Tests for SeedDatasetMetadata +""" + + +class TestMetadataParsing: + def test_invalid_filter_key(self): + pass + + def test_invalid_filter_value(self): + pass + + +class TestMetadataLifecycle: + def test_static_values_populated(self): + pass + + def test_dynamic_values_populated(self): + pass + + +class TestMetadataPerformance: + def test_quick_retrieval_for_static_values(self): + pass + + def test_acceptable_retrieval_for_dynamic_values(self): + pass From fc43c8c6f7e198ce9669f7cd8dd7047795791977 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 12 Mar 2026 00:32:21 +0000 Subject: [PATCH 03/29] . --- pyrit/datasets/seed_datasets/seed_metadata.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index 8ad37940d4..b5a4070f89 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -61,3 +61,13 @@ class DatasetFilters(Enum): def _validate_filter_value(v): """Check if the filter value given is valid.""" + + +def _metadata_builder(): + """ + Force build metadata for all datasets. + Download/load into local memory. + Add a timestamp. + Add all derived attributes. + Make sure every dataset subclass has it. 
+ """ From 9f357e64178f3311332136fceb93b3c568d43296 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 12 Mar 2026 20:48:39 +0000 Subject: [PATCH 04/29] data types --- pyrit/datasets/__init__.py | 17 ++- .../remote/aegis_ai_content_safety_dataset.py | 13 +-- .../seed_datasets/remote/harmbench_dataset.py | 16 +++ .../seed_datasets/seed_dataset_provider.py | 108 +++++++++++------- pyrit/datasets/seed_datasets/seed_metadata.py | 83 +++++++++----- 5 files changed, 153 insertions(+), 84 deletions(-) diff --git a/pyrit/datasets/__init__.py b/pyrit/datasets/__init__.py index c8d8592625..8f4b543238 100644 --- a/pyrit/datasets/__init__.py +++ b/pyrit/datasets/__init__.py @@ -8,11 +8,22 @@ from pyrit.datasets.jailbreak.text_jailbreak import TextJailBreak from pyrit.datasets.seed_datasets import local, remote # noqa: F401 from pyrit.datasets.seed_datasets.seed_dataset_provider import SeedDatasetProvider -from pyrit.datasets.seed_datasets.seed_metadata import DatasetMetadata, DatasetFilters +from pyrit.datasets.seed_datasets.seed_metadata import ( + SeedDatasetFilter, + SeedDatasetLoadingRank, + SeedDatasetMetadata, + SeedDatasetModality, + SeedDatasetSize, + SeedDatasetSourceType, +) __all__ = [ - "DatasetMetadata", - "DatasetFilters", + "SeedDatasetFilter", + "SeedDatasetMetadata", + "SeedDatasetLoadingRank", + "SeedDatasetModality", + "SeedDatasetSize", + "SeedDatasetSourceType", "SeedDatasetProvider", "TextJailBreak", ] diff --git a/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py b/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py index ecc952f35a..4b9004f772 100644 --- a/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py @@ -11,8 +11,6 @@ ) from pyrit.models import SeedDataset, SeedPrompt -from pyrit.datasets.seed_datasets.seed_metadata import SeedMetadata - logger = logging.getLogger(__name__) @@ -109,8 +107,7 @@ def 
__init__( # Validate harm categories if provided if harm_categories: - invalid_categories = { - cat for cat in harm_categories if cat not in self.HARM_CATEGORIES} + invalid_categories = {cat for cat in harm_categories if cat not in self.HARM_CATEGORIES} if invalid_categories: raise ValueError( f"Invalid harm categories: {invalid_categories}. Valid categories are: {self.HARM_CATEGORIES}" @@ -160,8 +157,7 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: prompt_harm_categories = [] if violated_categories: # The violated_categories field contains comma-separated category names - categories = [ - cat.strip() for cat in violated_categories.split(",") if cat.strip()] + categories = [cat.strip() for cat in violated_categories.split(",") if cat.strip()] prompt_harm_categories = categories # Filter by harm_categories if specified @@ -190,8 +186,3 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: ) return SeedDataset(seeds=seed_prompts, dataset_name=self.dataset_name) - - def metadata_factory(self) -> SeedMetadata: - return SeedMetadata( - size= - ) diff --git a/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py b/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py index fc6d46e54d..a31ca1cf58 100644 --- a/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py @@ -6,6 +6,13 @@ from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, ) +from pyrit.datasets.seed_datasets.seed_metadata import ( + SeedDatasetLoadingRank, + SeedDatasetMetadata, + SeedDatasetModality, + SeedDatasetSize, + SeedDatasetSourceType, +) from pyrit.models import SeedDataset, SeedObjective @@ -19,6 +26,15 @@ class _HarmBenchDataset(_RemoteDatasetLoader): Reference: https://github.com/centerforaisafety/HarmBench """ + _metadata = SeedDatasetMetadata( + tags={"default, safety"}, + size=SeedDatasetSize.LARGE, + modalities=[SeedDatasetModality.TEXT], + 
source=SeedDatasetSourceType.GENERIC_URL, + rank=SeedDatasetLoadingRank.DEFAULT, + harm_categories=["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"], + ) + def __init__( self, *, diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index 7e11bf5f4c..cb7ea6fed7 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -9,8 +9,8 @@ from tqdm import tqdm +from pyrit.datasets.seed_datasets.seed_metadata import SeedDatasetFilter, SeedDatasetMetadata from pyrit.models.seeds import SeedDataset -from pyrit.datasets.seed_datasets.seed_metadata import SeedMetadata logger = logging.getLogger(__name__) @@ -45,10 +45,9 @@ def __init_subclass__(cls, **kwargs: Any) -> None: if not inspect.isabstract(cls) and getattr(cls, "should_register", True): SeedDatasetProvider._registry[cls.__name__] = cls logger.debug(f"Registered dataset provider: {cls.__name__}") - # Providing metadata is optional - if getattr(cls, "_metadata", False): - logger.debug( - f"Dataset provider {cls.__name__} provided metadata.") + # Providing metadata is optional. + if getattr(cls, "_metadata", True): + logger.debug(f"Dataset provider {cls.__name__} provided metadata.") @property @abstractmethod @@ -87,12 +86,12 @@ def get_all_providers(cls) -> dict[str, type["SeedDatasetProvider"]]: return cls._registry.copy() @classmethod - def get_all_dataset_names(cls, filters: Optional[dict[str, str]] = None) -> list[str]: + def get_all_dataset_names(cls, filters: Optional[SeedDatasetFilter] = None) -> list[str]: """ Get the names of all registered datasets. Args: - filters (Optional[Dict[str, str]]): List of filters to apply. + filters (Optional[SeedDatasetFilter]): List of filters to apply. Returns: List[str]: List of dataset names from all registered providers. 
@@ -105,44 +104,72 @@ def get_all_dataset_names(cls, filters: Optional[dict[str, str]] = None) -> list >>> print(f"Available datasets: {', '.join(names)}") """ dataset_names = set() - # 1 Remove invalid filters by checking ground truth in seed_metadata - if filters: - valid_filters = [f.value for f in SeedMetadata.DatasetFilters] - # Prefer doing this to a list or set comprehension so we can raise ValueError on - # specific unsupported filters - for filter, _ in filters.items(): - if filter not in valid_filters: - raise ValueError( - f"Tried to pass invalid filter `{filter}` to SeedDatasetProvider.get_all_dataset_names!") - for provider_class in cls._registry.values(): try: # Instantiate to get dataset name provider = provider_class() - if filters: - # 1 Check if it has metadata - # should this be none or false - if getattr(provider, "_metadata", False): - # Skip a dataset without metadata if we have filters enabled - continue - - # 2 Remove invalid filter values by invoking helpers (e.g. size: <100 is fine, size: foobar is not) - - # 3 Only execute the following line if the filter key is valid and so is the value, AND the dataset meets the condition + # Extract metadata, default to False if not found + metadata = getattr(provider, "_metadata", False) + if filters and not metadata: + continue - # Problem: We don't know size at this point because we're just collecting the name. 
Size and source are tricky for remote datasets - # since we can't check them statically + # Type safety for metadata object given getattr return type + if isinstance(metadata, bool): + raise ValueError - # Solution: If filter is dynamic, then just download or load into central memory early to retrieve it - # and present a warning to the user that this is occuring + # Filters detected but no match -> don't add this dataset + if filters and not cls._match_filter(metadata=metadata, filters=filters): + continue + # This triggers when filters match (and filters exist) dataset_names.add(provider.dataset_name) except Exception as e: - raise ValueError( - f"Could not get dataset name from {provider_class.__name__}: {e}") from e + raise ValueError(f"Could not get dataset name from {provider_class.__name__}: {e}") from e return sorted(dataset_names) + @classmethod + def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter) -> bool: + """ + + Match the filter(s) with the metadata provided by the SeedDatasetProvider subclass. + By default, filters across dimensions (e.g. size, harm categories) are treated as AND + requirements. Filters within a dimension (e.g. SeedDatasetSize.SMALL, + SeedDatasetSize.LARGE) are treated as OR requirements. + + Args: + metadata (SeedDatasetMetadata): The metadata object extracted from the SeedDatasetProvider + subclass. + filters (SeedDatasetFilter): The filter object provided by the user to get_all_dataset_names. + + Returns: + bool: Whether or not the filters match or not. + """ + # Tags + if metadata.tags and "all" in metadata.tags: + # This is the only condition that returns true, because we want the "all" + # tag to override everything else in the filter. + return True + + # These lines all disable SIM103 because metadata and filters tags can be optional, so + # directly checking for membership breaks type checking. 
+ if metadata.tags and filters.tags and not (filters.tags & metadata.tags): # noqa: SIM103 + return False + + # Size + if metadata.size and filters.sizes and metadata.size not in filters.sizes: # noqa: SIM103 + return False + + # Harm Categories + + # Source Type + + # Modalities + + # Rank + + return True + @classmethod async def fetch_datasets_async( cls, @@ -183,11 +210,9 @@ async def fetch_datasets_async( # Validate dataset names if specified if dataset_names is not None: available_names = cls.get_all_dataset_names() - invalid_names = [ - name for name in dataset_names if name not in available_names] + invalid_names = [name for name in dataset_names if name not in available_names] if invalid_names: - raise ValueError( - f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") + raise ValueError(f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") async def fetch_single_dataset( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -213,8 +238,7 @@ async def fetch_single_dataset( # Progress tracking total_count = len(cls._registry) - pbar = tqdm(total=total_count, - desc="Loading datasets - this can take a few minutes", unit="dataset") + pbar = tqdm(total=total_count, desc="Loading datasets - this can take a few minutes", unit="dataset") async def fetch_with_semaphore( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -252,12 +276,10 @@ async def fetch_with_semaphore( logger.info(f"Merging multiple sources for {dataset_name}.") existing_dataset = datasets[dataset_name] - combined_seeds = list( - existing_dataset.seeds) + list(dataset.seeds) + combined_seeds = list(existing_dataset.seeds) + list(dataset.seeds) existing_dataset.seeds = combined_seeds else: datasets[dataset_name] = dataset - logger.info( - f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") + logger.info(f"Successfully fetched {len(datasets)} unique datasets from 
{len(cls._registry)} providers") return list(datasets.values()) diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index b5a4070f89..6b87f2fceb 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -1,8 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from enum import Enum from dataclasses import dataclass +from enum import Enum +from typing import Optional """ TODO Finish docstring @@ -19,55 +20,83 @@ """ -class DatasetLoadingRank(Enum): +class SeedDatasetSize(Enum): + """Ordinal size (by bucket) of the dataset.""" + + TINY = "tiny" # < 10 + SMALL = "small" # >= 10, < 100 + MEDIUM = "medium" # >= 100, < 500 + LARGE = "large" # >= 500, < 5000 + HUGE = "huge" # >= 5000 + + +class SeedDatasetLoadingRank(Enum): """Represents the general difficulty of loading in a dataset.""" + DEFAULT = "default" EXTENDED = "extended" SLOW = "slow" -class DatasetModalities(Enum): +class SeedDatasetModality(Enum): + """ + ... + """ + TEXT = "text" IMAGE = "image" VIDEO = "video" AUDIO = "audio" -class DatasetSourceType(Enum): +class SeedDatasetSourceType(Enum): + """ + ... + """ + GENERIC_URL = "generic_url" LOCAL = "local" HUGGING_FACE = "hugging_face" @dataclass -class DatasetMetadata: - # TODO: separate dynamic fields from static fields and mark dynamic fields as None - size: int - modalities: list[DatasetModalities] - source: DatasetSourceType - rank: DatasetLoadingRank - +class SeedDatasetFilter: + """ + ... + """ -class DatasetFilters(Enum): - # TODO: This is a bad way of extracting the fields from DatasetMetadata. - # A metaclass or even just calling getattr might be better. 
- SIZE = "size" - MODALITIES = "modalities" - SOURCE = "source" - RANK = "rank" + tags: Optional[set[str]] + sizes: Optional[list[SeedDatasetSize]] + modalities: Optional[list[SeedDatasetModality]] + sources: Optional[list[SeedDatasetSourceType]] + ranks: Optional[list[SeedDatasetLoadingRank]] + harm_categories: Optional[list[str]] -# TODO These stubs should be moved somewhere, maybe as static methods to the metadata dataclass? +@dataclass(frozen=True) +class SeedDatasetMetadata: + """ + ... + """ -def _validate_filter_value(v): - """Check if the filter value given is valid.""" + tags: Optional[set[str]] + size: Optional[SeedDatasetSize] + modalities: Optional[list[SeedDatasetModality]] + source: Optional[SeedDatasetSourceType] + rank: Optional[SeedDatasetLoadingRank] + harm_categories: Optional[list[str]] -def _metadata_builder(): +class SeedDatasetMetadataUtilities: """ - Force build metadata for all datasets. - Download/load into local memory. - Add a timestamp. - Add all derived attributes. - Make sure every dataset subclass has it. + Collected utilities for managing and updating SeedDatasetMetadata. """ + + @staticmethod + def populate_metadata() -> None: + """ + WARNING: Because this function updates the metadata for each SeedDatasetProvider, + it changes the provider's corresopnding source file. Run with caution! + + Update the metadata per SeedDatasetProvider. 
+ """ From 34f8953f14f5ae9ef0e24faea0ae1db6d5e28ece Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Fri, 13 Mar 2026 19:30:35 +0000 Subject: [PATCH 05/29] redesign --- .../local/local_dataset_loader.py | 45 ++++++++-- .../seed_datasets/remote/harmbench_dataset.py | 21 ++--- .../remote/jbb_behaviors_dataset.py | 2 +- .../remote/remote_dataset_loader.py | 48 ++++++++--- .../seed_datasets/seed_dataset_provider.py | 44 ++++++---- pyrit/datasets/seed_datasets/seed_metadata.py | 47 +++++++---- .../test_seed_dataset_provider_integration.py | 2 +- .../datasets/test_seed_dataset_metadata.py | 83 ++++++++++++++++--- .../datasets/test_seed_dataset_provider.py | 75 +++++++++++++++-- 9 files changed, 283 insertions(+), 84 deletions(-) diff --git a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py index 270fba1568..a54062c779 100644 --- a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py @@ -2,12 +2,14 @@ # Licensed under the MIT license. 
import logging +import yaml from collections.abc import Callable from pathlib import Path from typing import Any from pyrit.datasets.seed_datasets.seed_dataset_provider import SeedDatasetProvider from pyrit.models import SeedDataset +from pyrit.datasets.seed_datasets.seed_metadata import SeedDatasetMetadata logger = logging.getLogger(__name__) @@ -36,7 +38,8 @@ def __init__(self, *, file_path: Path): dataset = SeedDataset.from_yaml_file(file_path) # Use the dataset_name from the YAML if available, otherwise use filename self._dataset_name = ( - getattr(dataset, "dataset_name", None) or getattr(dataset, "name", None) or file_path.stem + getattr(dataset, "dataset_name", None) or getattr( + dataset, "name", None) or file_path.stem ) except Exception as e: logger.warning(f"Could not pre-load dataset from {file_path}: {e}") @@ -67,9 +70,32 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: dataset.dataset_name = self.dataset_name return dataset except Exception as e: - logger.error(f"Failed to load local dataset from {self.file_path}: {e}") + logger.error( + f"Failed to load local dataset from {self.file_path}: {e}") raise + def _parse_metadata(self) -> SeedDatasetMetadata | None: + """ + Extract metadata from class attributes and format into SeedDatasetMetadata schema. + + Raises: + Exception: If the dataset cannot be loaded. 
+ """ + valid_fields = [f.name for f in fields(SeedDatasetMetadata)] + try: + with open(self.file_path, 'r') as f: + dataset = yaml.safe_load(f) + except Exception as e: + logger.error( + f"Failed to load local datset from {self.file_path}: {e}" + ) + raise + self_metadata = {k: v for k, v in dataset if k in valid_fields} + if not self_metadata: + return None + return SeedDatasetMetadata(**self_metadata) + + def _register_local_datasets() -> None: """ @@ -93,21 +119,26 @@ def _register_local_datasets() -> None: def make_init(path: Path) -> Callable[[Any], None]: def __init__(self: Any) -> None: # noqa: N807 - super(self.__class__, self).__init__(file_path=path) + super(self.__class__, self).__init__( + file_path=path) return __init__ type( class_name, (_LocalDatasetLoader,), - {"__init__": make_init(yaml_file), "should_register": True, "__module__": __name__}, + {"__init__": make_init( + yaml_file), "should_register": True, "__module__": __name__}, ) - logger.debug(f"Registered local dataset loader: {class_name} for {yaml_file.name}") + logger.debug( + f"Registered local dataset loader: {class_name} for {yaml_file.name}") except Exception as e: - logger.warning(f"Failed to register local dataset {yaml_file}: {e}") + logger.warning( + f"Failed to register local dataset {yaml_file}: {e}") else: - logger.warning(f"Seed datasets directory not found: {seed_datasets_path}") + logger.warning( + f"Seed datasets directory not found: {seed_datasets_path}") # Execute registration diff --git a/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py b/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py index a31ca1cf58..4759d4b7ee 100644 --- a/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py @@ -26,14 +26,13 @@ class _HarmBenchDataset(_RemoteDatasetLoader): Reference: https://github.com/centerforaisafety/HarmBench """ - _metadata = SeedDatasetMetadata( - tags={"default, safety"}, - 
size=SeedDatasetSize.LARGE, - modalities=[SeedDatasetModality.TEXT], - source=SeedDatasetSourceType.GENERIC_URL, - rank=SeedDatasetLoadingRank.DEFAULT, - harm_categories=["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"], - ) + # Metadata + harm_categories: list[str] = ["cybercrime", "illegal", + "harmful", "chemical_biological", "harassment"] + modalities: list[SeedDatasetModality] = [SeedDatasetModality.TEXT] + size: SeedDatasetSize = SeedDatasetSize.LARGE # 504 seeds + # "default" means included in curated set + tags: set[str] = {"default", "safety"} def __init__( self, @@ -88,7 +87,8 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: # Check for missing keys in the example missing_keys = required_keys - example.keys() if missing_keys: - raise ValueError(f"Missing keys in example: {', '.join(missing_keys)}") + raise ValueError( + f"Missing keys in example: {', '.join(missing_keys)}") # Extract data category = example["SemanticCategory"] @@ -104,7 +104,8 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: "biological, illegal activities, etc." ), source="https://github.com/centerforaisafety/HarmBench", - authors=["Mantas Mazeika", "Long Phan", "Xuwang Yin", "Andy Zou", "Zifan Wang", "Norman Mu"], + authors=["Mantas Mazeika", "Long Phan", "Xuwang Yin", + "Andy Zou", "Zifan Wang", "Norman Mu"], ) seeds.append(seed_prompt) diff --git a/pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py b/pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py index a622a4a018..b2b45c2a33 100644 --- a/pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py @@ -23,7 +23,7 @@ class _JBBBehaviorsDataset(_RemoteDatasetLoader): and may contain offensive content. Users should check with their legal department before using these prompts against production LLMs. 
""" - + def __init__( self, *, diff --git a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py index 5cd9212846..2f234f451b 100644 --- a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py @@ -8,6 +8,7 @@ import tempfile from abc import ABC from collections.abc import Callable +from dataclasses import fields from pathlib import Path from typing import Any, Literal, Optional, TextIO, cast @@ -19,6 +20,7 @@ from pyrit.common.path import DB_DATA_PATH from pyrit.common.text_helper import read_txt, write_txt from pyrit.datasets.seed_datasets.seed_dataset_provider import SeedDatasetProvider +from pyrit.datasets.seed_datasets.seed_metadata import SeedDatasetMetadata logger = logging.getLogger(__name__) @@ -74,7 +76,8 @@ def _validate_file_type(self, file_type: str) -> None: """ if file_type not in FILE_TYPE_HANDLERS: valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) - raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") + raise ValueError( + f"Invalid file_type. Expected one of: {valid_types}.") def _read_cache(self, *, cache_file: Path, file_type: str) -> list[dict[str, str]]: """ @@ -131,15 +134,19 @@ def _fetch_from_public_url(self, *, source: str, file_type: str) -> list[dict[st if file_type in FILE_TYPE_HANDLERS: if file_type == "json": return cast( - "list[dict[str, str]]", FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO(response.text)) + "list[dict[str, str]]", FILE_TYPE_HANDLERS[file_type]["read"]( + io.StringIO(response.text)) ) return cast( "list[dict[str, str]]", - FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO("\n".join(response.text.splitlines()))), + FILE_TYPE_HANDLERS[file_type]["read"]( + io.StringIO("\n".join(response.text.splitlines()))), ) valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) - raise ValueError(f"Invalid file_type. 
Expected one of: {valid_types}.") - raise Exception(f"Failed to fetch examples from public URL. Status code: {response.status_code}") + raise ValueError( + f"Invalid file_type. Expected one of: {valid_types}.") + raise Exception( + f"Failed to fetch examples from public URL. Status code: {response.status_code}") def _fetch_from_file(self, *, source: str, file_type: str) -> list[dict[str, str]]: """ @@ -159,7 +166,8 @@ def _fetch_from_file(self, *, source: str, file_type: str) -> list[dict[str, str if file_type in FILE_TYPE_HANDLERS: return cast("list[dict[str, str]]", FILE_TYPE_HANDLERS[file_type]["read"](file)) valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) - raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") + raise ValueError( + f"Invalid file_type. Expected one of: {valid_types}.") def _fetch_from_url( self, @@ -191,21 +199,26 @@ def _fetch_from_url( file_type = source.split(".")[-1] if file_type not in FILE_TYPE_HANDLERS: valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) - raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") + raise ValueError( + f"Invalid file_type. 
Expected one of: {valid_types}.") data_home = DB_DATA_PATH / "seed-prompt-entries" - cache_file = data_home / self._get_cache_file_name(source=source, file_type=file_type) + cache_file = data_home / \ + self._get_cache_file_name(source=source, file_type=file_type) if cache and cache_file.exists(): return self._read_cache(cache_file=cache_file, file_type=file_type) if source_type == "public_url": - examples = self._fetch_from_public_url(source=source, file_type=file_type) + examples = self._fetch_from_public_url( + source=source, file_type=file_type) elif source_type == "file": - examples = self._fetch_from_file(source=source, file_type=file_type) + examples = self._fetch_from_file( + source=source, file_type=file_type) if cache: - self._write_cache(cache_file=cache_file, examples=examples, file_type=file_type) + self._write_cache(cache_file=cache_file, + examples=examples, file_type=file_type) else: with tempfile.NamedTemporaryFile( delete=False, mode="w", suffix=f".{file_type}", encoding="utf-8" @@ -283,5 +296,16 @@ def _load_dataset_sync() -> Any: # Run the synchronous load_dataset in a thread pool to avoid blocking the event loop return await asyncio.to_thread(_load_dataset_sync) except Exception as e: - logger.error(f"Failed to load HuggingFace dataset {dataset_name}: {e}") + logger.error( + f"Failed to load HuggingFace dataset {dataset_name}: {e}") raise + + def _parse_metadata(self) -> SeedDatasetMetadata | None: + """ + Extract metadata from class attributes and format into SeedDatasetMetadata schema. 
+ """ + valid_fields = [f.name for f in fields(SeedDatasetMetadata)] + self_metadata = {k: v for k, v in self.__dict__.items() if k in valid_fields} + if not self_metadata: + return None + return SeedDatasetMetadata(**self_metadata) diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index cb7ea6fed7..ae4e33eb31 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -45,9 +45,6 @@ def __init_subclass__(cls, **kwargs: Any) -> None: if not inspect.isabstract(cls) and getattr(cls, "should_register", True): SeedDatasetProvider._registry[cls.__name__] = cls logger.debug(f"Registered dataset provider: {cls.__name__}") - # Providing metadata is optional. - if getattr(cls, "_metadata", True): - logger.debug(f"Dataset provider {cls.__name__} provided metadata.") @property @abstractmethod @@ -109,23 +106,20 @@ def get_all_dataset_names(cls, filters: Optional[SeedDatasetFilter] = None) -> l # Instantiate to get dataset name provider = provider_class() - # Extract metadata, default to False if not found - metadata = getattr(provider, "_metadata", False) - if filters and not metadata: + # Parser ensures a standard metadata format + metadata: SeedDatasetMetadata = cls._parse_metadata() + if filters and not metadata and "all" not in filters.tags: + # Datasets without metadata are skipped unless we want "all" continue - # Type safety for metadata object given getattr return type - if isinstance(metadata, bool): - raise ValueError - # Filters detected but no match -> don't add this dataset if filters and not cls._match_filter(metadata=metadata, filters=filters): continue - # This triggers when filters match (and filters exist) dataset_names.add(provider.dataset_name) except Exception as e: - raise ValueError(f"Could not get dataset name from {provider_class.__name__}: {e}") from e + raise ValueError( + f"Could not get dataset name from 
{provider_class.__name__}: {e}") from e return sorted(dataset_names) @classmethod @@ -153,6 +147,7 @@ def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter # These lines all disable SIM103 because metadata and filters tags can be optional, so # directly checking for membership breaks type checking. + if metadata.tags and filters.tags and not (filters.tags & metadata.tags): # noqa: SIM103 return False @@ -161,12 +156,22 @@ def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter return False # Harm Categories + if metadata.harm_categories and filters.harm_categories and \ + not set(metadata.harm_categories) & set(filters.harm_categories): # noqa: SIM103 + return False # Source Type + if metadata.source and filters.sources and metadata.source not in filters.sources: # noqa: SIM103 + return False # Modalities + if metadata.modalities and filters.modalities and \ + not set(metadata.modalities) & set(filters.modalities): # noqa: SIM103 + return False # Rank + if metadata.rank and filters.ranks and metadata.rank not in filters.ranks: # noqa: SIM103 + return False return True @@ -210,9 +215,11 @@ async def fetch_datasets_async( # Validate dataset names if specified if dataset_names is not None: available_names = cls.get_all_dataset_names() - invalid_names = [name for name in dataset_names if name not in available_names] + invalid_names = [ + name for name in dataset_names if name not in available_names] if invalid_names: - raise ValueError(f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") + raise ValueError( + f"Dataset(s) not found: {invalid_names}. 
Available datasets: {available_names}") async def fetch_single_dataset( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -238,7 +245,8 @@ async def fetch_single_dataset( # Progress tracking total_count = len(cls._registry) - pbar = tqdm(total=total_count, desc="Loading datasets - this can take a few minutes", unit="dataset") + pbar = tqdm(total=total_count, + desc="Loading datasets - this can take a few minutes", unit="dataset") async def fetch_with_semaphore( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -276,10 +284,12 @@ async def fetch_with_semaphore( logger.info(f"Merging multiple sources for {dataset_name}.") existing_dataset = datasets[dataset_name] - combined_seeds = list(existing_dataset.seeds) + list(dataset.seeds) + combined_seeds = list( + existing_dataset.seeds) + list(dataset.seeds) existing_dataset.seeds = combined_seeds else: datasets[dataset_name] = dataset - logger.info(f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") + logger.info( + f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") return list(datasets.values()) diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index 6b87f2fceb..6037d811b0 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -5,18 +5,12 @@ from enum import Enum from typing import Optional -""" -TODO Finish docstring +from pyrit.common.path import DATASETS_PATH +""" Contains metadata objects for datasets (i.e. subclasses of SeedDatasetProvider). -We have one DatasetMetadata dataclass that is our ground truth. As we instantiate datasets -using the subclass call in SeedDatasetProvider, we create DatasetMetadata and assign it to -a private variable there. - -Some fields are dynamic (e.g. 
loading statistics, timestamp, dataset size) and are left as -NoneType until the SeedDatasetProvider actually downloads/parses the dataset and puts it in -CentralMemory. +The ground truth is SeedDatasetMetadata. This is """ @@ -31,7 +25,9 @@ class SeedDatasetSize(Enum): class SeedDatasetLoadingRank(Enum): - """Represents the general difficulty of loading in a dataset.""" + """ + Represents the general difficulty of loading in a dataset. + """ DEFAULT = "default" EXTENDED = "extended" @@ -40,7 +36,7 @@ class SeedDatasetLoadingRank(Enum): class SeedDatasetModality(Enum): """ - ... + Type of data contained in the dataset. """ TEXT = "text" @@ -51,18 +47,18 @@ class SeedDatasetModality(Enum): class SeedDatasetSourceType(Enum): """ - ... + Where the dataset is pulled from. """ - GENERIC_URL = "generic_url" + REMOTE = "remote" LOCAL = "local" - HUGGING_FACE = "hugging_face" @dataclass class SeedDatasetFilter: """ - ... + Filter object for datasets. Passed to `get_all_dataset_names` in + SeedDatasetProvider. """ tags: Optional[set[str]] @@ -76,7 +72,8 @@ class SeedDatasetFilter: @dataclass(frozen=True) class SeedDatasetMetadata: """ - ... + Metadata object for datasets. Holds the same fields as the filter + object. """ tags: Optional[set[str]] @@ -89,7 +86,7 @@ class SeedDatasetMetadata: class SeedDatasetMetadataUtilities: """ - Collected utilities for managing and updating SeedDatasetMetadata. + Collected utilities for managing and updating metadata. """ @staticmethod @@ -98,5 +95,19 @@ def populate_metadata() -> None: WARNING: Because this function updates the metadata for each SeedDatasetProvider, it changes the provider's corresopnding source file. Run with caution! - Update the metadata per SeedDatasetProvider. + Updates the metadata per SeedDatasetProvider. 
""" + + # 1 Gather all dataset files + + # 2 For each file, download and store in the database (in-memory) + + # 3 Count the number of entries exactly and identify its threshold + + # 4 If harm categories are found in source, add them + + # 5 Inspect type of prompts to identify modalities present + + # 6 Inspect source file to find where it pulled from + + # 7 Leave rank optional for now diff --git a/tests/integration/datasets/test_seed_dataset_provider_integration.py b/tests/integration/datasets/test_seed_dataset_provider_integration.py index ceacc2a860..491f97e92a 100644 --- a/tests/integration/datasets/test_seed_dataset_provider_integration.py +++ b/tests/integration/datasets/test_seed_dataset_provider_integration.py @@ -61,6 +61,6 @@ async def test_fetch_dataset_integration(self, name, provider_cls): @pytest.mark.asyncio @pytest.mark.parameterize("name,provider_cls", get_dataset_providers()) - async def test_fetch_dataset_with_filtering(self, name, provider_cls): + async def test_fetch_dataset_integration_with_filtering(self, name, provider_cls): # TODO pass diff --git a/tests/unit/datasets/test_seed_dataset_metadata.py b/tests/unit/datasets/test_seed_dataset_metadata.py index 7f38572311..5487a1c848 100644 --- a/tests/unit/datasets/test_seed_dataset_metadata.py +++ b/tests/unit/datasets/test_seed_dataset_metadata.py @@ -2,31 +2,90 @@ # Licensed under the MIT license. """ -TODO - -Tests for SeedDatasetMetadata +Tests for metadata components related to SeedDatasetProvider. """ -class TestMetadataParsing: - def test_invalid_filter_key(self): +class TestMetadataLifecycle: + """ + Test that the metadata object can be created with different + subsets of values. 
+ """ + + def test_has_no_values(self): pass - def test_invalid_filter_value(self): + def test_has_some_values(self): pass + def test_has_all_values(self): + pass -class TestMetadataLifecycle: - def test_static_values_populated(self): + +class TestFilterLifecycle: + """ + Test that the metadata object can be created with different + subsets of values. + """ + + def test_has_no_values(self): + pass + + def test_has_some_values(self): pass - def test_dynamic_values_populated(self): + def test_has_all_values(self): pass -class TestMetadataPerformance: - def test_quick_retrieval_for_static_values(self): +class TestMetadataProperties: + """ + Test that the metadata fields populate correctly. + """ + + def test_size_value(self): + pass + + def test_loading_rank_value(self): + pass + + def test_source_value(self): + pass + + def test_modality_value(self): + pass + + def test_tags_value(self): pass - def test_acceptable_retrieval_for_dynamic_values(self): + def test_harm_categories_value(self): + pass + + +class TestFilterProperties: + """ + Test that the filter fields popualte correctly. 
+ """ + + def test_sizes_values(self): + pass + + def test_loading_ranks_values(self): + pass + + def test_sources_values(self): + pass + + def test_modalities_values(self): + pass + + def test_tags_values(self): + pass + + def test_harm_categories_values(self): + pass + + +class TestMetadataUtilities: + def test_population_works(self, tmp_path): pass diff --git a/tests/unit/datasets/test_seed_dataset_provider.py b/tests/unit/datasets/test_seed_dataset_provider.py index d61e2291a2..0dbf0e13b7 100644 --- a/tests/unit/datasets/test_seed_dataset_provider.py +++ b/tests/unit/datasets/test_seed_dataset_provider.py @@ -78,13 +78,15 @@ async def test_fetch_datasets_async(self): mock_provider1 = MagicMock() mock_provider1.return_value.dataset_name = "d1" mock_provider1.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") + return_value=SeedDataset( + seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) mock_provider2 = MagicMock() mock_provider2.return_value.dataset_name = "d2" mock_provider2.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset(seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") + return_value=SeedDataset( + seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") ) with patch.dict(SeedDatasetProvider._registry, {"P1": mock_provider1, "P2": mock_provider2}, clear=True): @@ -97,12 +99,14 @@ async def test_fetch_datasets_async_with_filter(self): mock_provider1 = MagicMock() mock_provider1.return_value.dataset_name = "d1" mock_provider1.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") + return_value=SeedDataset( + seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) mock_provider2 = MagicMock() mock_provider2.return_value.dataset_name = "d2" - mock_provider2.return_value.fetch_dataset = 
AsyncMock(side_effect=Exception("Should not be called")) + mock_provider2.return_value.fetch_dataset = AsyncMock( + side_effect=Exception("Should not be called")) with patch.dict(SeedDatasetProvider._registry, {"P1": mock_provider1, "P2": mock_provider2}, clear=True): datasets = await SeedDatasetProvider.fetch_datasets_async(dataset_names=["d1"]) @@ -115,13 +119,15 @@ async def test_fetch_datasets_async_invalid_dataset_name(self): mock_provider1 = MagicMock() mock_provider1.return_value.dataset_name = "d1" mock_provider1.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") + return_value=SeedDataset( + seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) mock_provider2 = MagicMock() mock_provider2.return_value.dataset_name = "d2" mock_provider2.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset(seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") + return_value=SeedDataset( + seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") ) with patch.dict(SeedDatasetProvider._registry, {"P1": mock_provider1, "P2": mock_provider2}, clear=True): @@ -236,3 +242,60 @@ async def test_fetch_dataset_with_custom_config(self, mock_darkbench_data): assert call_kwargs["dataset_name"] == "custom/darkbench" assert call_kwargs["config"] == "custom_config" assert call_kwargs["split"] == "test" + + +class TestMetadataParsingRemote: + def test_all_tag(self): + pass + + def test_tags(self): + pass + + def test_sizes(self): + pass + + def test_modalities(self): + pass + + def test_sources(self): + pass + + def test_ranks(self): + pass + + def test_harm_categories(self): + pass + + def test_empty_fitler(self): + pass + + def test_no_metadata(self): + pass + +class TestMetadataParsingLocal: + def test_all_tag(self): + pass + + def test_tags(self): + pass + + def test_sizes(self): + pass + + def test_modalities(self): + pass + + def 
test_sources(self): + pass + + def test_ranks(self): + pass + + def test_harm_categories(self): + pass + + def test_empty_fitler(self): + pass + + def test_no_metadata(self): + pass From 8dcbd5ff44541bf8ff0276f3df79886c1a21405c Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Fri, 13 Mar 2026 20:02:20 +0000 Subject: [PATCH 06/29] review --- .../local/local_dataset_loader.py | 93 +++-- .../seed_datasets/remote/harmbench_dataset.py | 14 +- .../remote/remote_dataset_loader.py | 48 ++- .../seed_datasets/seed_dataset_provider.py | 57 +-- pyrit/datasets/seed_datasets/seed_metadata.py | 63 +-- .../test_seed_dataset_provider_integration.py | 283 +++++++++++++- .../datasets/test_seed_dataset_metadata.py | 126 ++++-- .../datasets/test_seed_dataset_provider.py | 360 +++++++++++++++--- 8 files changed, 830 insertions(+), 214 deletions(-) diff --git a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py index a54062c779..2c6dd4b778 100644 --- a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py @@ -2,14 +2,22 @@ # Licensed under the MIT license. 
import logging -import yaml from collections.abc import Callable +from dataclasses import fields from pathlib import Path -from typing import Any +from typing import Any, Optional + +import yaml from pyrit.datasets.seed_datasets.seed_dataset_provider import SeedDatasetProvider +from pyrit.datasets.seed_datasets.seed_metadata import ( + SeedDatasetLoadingRank, + SeedDatasetMetadata, + SeedDatasetModality, + SeedDatasetSize, + SeedDatasetSourceType, +) from pyrit.models import SeedDataset -from pyrit.datasets.seed_datasets.seed_metadata import SeedDatasetMetadata logger = logging.getLogger(__name__) @@ -38,8 +46,7 @@ def __init__(self, *, file_path: Path): dataset = SeedDataset.from_yaml_file(file_path) # Use the dataset_name from the YAML if available, otherwise use filename self._dataset_name = ( - getattr(dataset, "dataset_name", None) or getattr( - dataset, "name", None) or file_path.stem + getattr(dataset, "dataset_name", None) or getattr(dataset, "name", None) or file_path.stem ) except Exception as e: logger.warning(f"Could not pre-load dataset from {file_path}: {e}") @@ -70,31 +77,66 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: dataset.dataset_name = self.dataset_name return dataset except Exception as e: - logger.error( - f"Failed to load local dataset from {self.file_path}: {e}") + logger.error(f"Failed to load local dataset from {self.file_path}: {e}") raise - def _parse_metadata(self) -> SeedDatasetMetadata | None: + def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: """ - Extract metadata from class attributes and format into SeedDatasetMetadata schema. - + Extract metadata from a local YAML file and coerce raw values into typed schema fields. + + YAML produces raw Python primitives (str, list) that must be converted to the + enum and set types expected by SeedDatasetMetadata before _match_filter can work. + + Returns: + Optional[SeedDatasetMetadata]: Parsed metadata if available, otherwise None. 
+ Raises: - Exception: If the dataset cannot be loaded. + Exception: If the dataset file cannot be read. """ valid_fields = [f.name for f in fields(SeedDatasetMetadata)] try: - with open(self.file_path, 'r') as f: + with open(self.file_path, encoding="utf-8") as f: dataset = yaml.safe_load(f) except Exception as e: - logger.error( - f"Failed to load local datset from {self.file_path}: {e}" - ) + logger.error(f"Failed to load local dataset from {self.file_path}: {e}") raise - self_metadata = {k: v for k, v in dataset if k in valid_fields} - if not self_metadata: + + if not isinstance(dataset, dict): return None - return SeedDatasetMetadata(**self_metadata) + raw = {k: v for k, v in dataset.items() if k in valid_fields} + if not raw: + return None + + coerced = self._coerce_metadata_values(raw_metadata=raw) + return SeedDatasetMetadata(**coerced) + + @staticmethod + def _coerce_metadata_values(*, raw_metadata: dict[str, Any]) -> dict[str, Any]: + """ + Convert YAML primitive values into the enum/set types expected by SeedDatasetMetadata. + + Args: + raw_metadata (dict[str, Any]): Dictionary of field names to raw YAML-parsed values. + + Returns: + dict[str, Any]: Dictionary with values coerced to the correct types. 
+ """ + coerced: dict[str, Any] = {} + for key, value in raw_metadata.items(): + if key == "tags" and isinstance(value, list): + coerced[key] = set(value) + elif key == "size" and isinstance(value, str): + coerced[key] = SeedDatasetSize(value) + elif key == "source" and isinstance(value, str): + coerced[key] = SeedDatasetSourceType(value) + elif key == "rank" and isinstance(value, str): + coerced[key] = SeedDatasetLoadingRank(value) + elif key == "modalities" and isinstance(value, list): + coerced[key] = [SeedDatasetModality(v) for v in value] + else: + coerced[key] = value + return coerced def _register_local_datasets() -> None: @@ -119,26 +161,21 @@ def _register_local_datasets() -> None: def make_init(path: Path) -> Callable[[Any], None]: def __init__(self: Any) -> None: # noqa: N807 - super(self.__class__, self).__init__( - file_path=path) + super(self.__class__, self).__init__(file_path=path) return __init__ type( class_name, (_LocalDatasetLoader,), - {"__init__": make_init( - yaml_file), "should_register": True, "__module__": __name__}, + {"__init__": make_init(yaml_file), "should_register": True, "__module__": __name__}, ) - logger.debug( - f"Registered local dataset loader: {class_name} for {yaml_file.name}") + logger.debug(f"Registered local dataset loader: {class_name} for {yaml_file.name}") except Exception as e: - logger.warning( - f"Failed to register local dataset {yaml_file}: {e}") + logger.warning(f"Failed to register local dataset {yaml_file}: {e}") else: - logger.warning( - f"Seed datasets directory not found: {seed_datasets_path}") + logger.warning(f"Seed datasets directory not found: {seed_datasets_path}") # Execute registration diff --git a/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py b/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py index 4759d4b7ee..5d30fac2c4 100644 --- a/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py @@ -7,11 +7,8 @@ 
_RemoteDatasetLoader, ) from pyrit.datasets.seed_datasets.seed_metadata import ( - SeedDatasetLoadingRank, - SeedDatasetMetadata, SeedDatasetModality, SeedDatasetSize, - SeedDatasetSourceType, ) from pyrit.models import SeedDataset, SeedObjective @@ -27,10 +24,9 @@ class _HarmBenchDataset(_RemoteDatasetLoader): """ # Metadata - harm_categories: list[str] = ["cybercrime", "illegal", - "harmful", "chemical_biological", "harassment"] + harm_categories: list[str] = ["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"] modalities: list[SeedDatasetModality] = [SeedDatasetModality.TEXT] - size: SeedDatasetSize = SeedDatasetSize.LARGE # 504 seeds + size: SeedDatasetSize = SeedDatasetSize.LARGE # 504 seeds # "default" means included in curated set tags: set[str] = {"default", "safety"} @@ -87,8 +83,7 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: # Check for missing keys in the example missing_keys = required_keys - example.keys() if missing_keys: - raise ValueError( - f"Missing keys in example: {', '.join(missing_keys)}") + raise ValueError(f"Missing keys in example: {', '.join(missing_keys)}") # Extract data category = example["SemanticCategory"] @@ -104,8 +99,7 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: "biological, illegal activities, etc." 
), source="https://github.com/centerforaisafety/HarmBench", - authors=["Mantas Mazeika", "Long Phan", "Xuwang Yin", - "Andy Zou", "Zifan Wang", "Norman Mu"], + authors=["Mantas Mazeika", "Long Phan", "Xuwang Yin", "Andy Zou", "Zifan Wang", "Norman Mu"], ) seeds.append(seed_prompt) diff --git a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py index 2f234f451b..9587a743f0 100644 --- a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py @@ -76,8 +76,7 @@ def _validate_file_type(self, file_type: str) -> None: """ if file_type not in FILE_TYPE_HANDLERS: valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) - raise ValueError( - f"Invalid file_type. Expected one of: {valid_types}.") + raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") def _read_cache(self, *, cache_file: Path, file_type: str) -> list[dict[str, str]]: """ @@ -134,19 +133,15 @@ def _fetch_from_public_url(self, *, source: str, file_type: str) -> list[dict[st if file_type in FILE_TYPE_HANDLERS: if file_type == "json": return cast( - "list[dict[str, str]]", FILE_TYPE_HANDLERS[file_type]["read"]( - io.StringIO(response.text)) + "list[dict[str, str]]", FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO(response.text)) ) return cast( "list[dict[str, str]]", - FILE_TYPE_HANDLERS[file_type]["read"]( - io.StringIO("\n".join(response.text.splitlines()))), + FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO("\n".join(response.text.splitlines()))), ) valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) - raise ValueError( - f"Invalid file_type. Expected one of: {valid_types}.") - raise Exception( - f"Failed to fetch examples from public URL. Status code: {response.status_code}") + raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") + raise Exception(f"Failed to fetch examples from public URL. 
Status code: {response.status_code}") def _fetch_from_file(self, *, source: str, file_type: str) -> list[dict[str, str]]: """ @@ -166,8 +161,7 @@ def _fetch_from_file(self, *, source: str, file_type: str) -> list[dict[str, str if file_type in FILE_TYPE_HANDLERS: return cast("list[dict[str, str]]", FILE_TYPE_HANDLERS[file_type]["read"](file)) valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) - raise ValueError( - f"Invalid file_type. Expected one of: {valid_types}.") + raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") def _fetch_from_url( self, @@ -199,26 +193,21 @@ def _fetch_from_url( file_type = source.split(".")[-1] if file_type not in FILE_TYPE_HANDLERS: valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) - raise ValueError( - f"Invalid file_type. Expected one of: {valid_types}.") + raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") data_home = DB_DATA_PATH / "seed-prompt-entries" - cache_file = data_home / \ - self._get_cache_file_name(source=source, file_type=file_type) + cache_file = data_home / self._get_cache_file_name(source=source, file_type=file_type) if cache and cache_file.exists(): return self._read_cache(cache_file=cache_file, file_type=file_type) if source_type == "public_url": - examples = self._fetch_from_public_url( - source=source, file_type=file_type) + examples = self._fetch_from_public_url(source=source, file_type=file_type) elif source_type == "file": - examples = self._fetch_from_file( - source=source, file_type=file_type) + examples = self._fetch_from_file(source=source, file_type=file_type) if cache: - self._write_cache(cache_file=cache_file, - examples=examples, file_type=file_type) + self._write_cache(cache_file=cache_file, examples=examples, file_type=file_type) else: with tempfile.NamedTemporaryFile( delete=False, mode="w", suffix=f".{file_type}", encoding="utf-8" @@ -296,16 +285,23 @@ def _load_dataset_sync() -> Any: # Run the synchronous load_dataset in a thread pool to avoid blocking the 
event loop return await asyncio.to_thread(_load_dataset_sync) except Exception as e: - logger.error( - f"Failed to load HuggingFace dataset {dataset_name}: {e}") + logger.error(f"Failed to load HuggingFace dataset {dataset_name}: {e}") raise - def _parse_metadata(self) -> SeedDatasetMetadata | None: + def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: """ Extract metadata from class attributes and format into SeedDatasetMetadata schema. + + Returns: + Optional[SeedDatasetMetadata]: Parsed metadata if available, otherwise None. """ valid_fields = [f.name for f in fields(SeedDatasetMetadata)] - self_metadata = {k: v for k, v in self.__dict__.items() if k in valid_fields} + + provider_class = type(self) + self_metadata = { + key: getattr(provider_class, key) for key in valid_fields if getattr(provider_class, key, None) is not None + } + if not self_metadata: return None return SeedDatasetMetadata(**self_metadata) diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index ae4e33eb31..ae45fcb200 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -72,6 +72,19 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: Exception: If the dataset cannot be fetched or processed. """ + def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: + """ + Parse provider-specific metadata into the shared schema. + + Subclasses can override this to source metadata from class attributes, + prompt files, or any other backing format. The default implementation + returns None, which means metadata is not available for this provider. + + Returns: + Optional[SeedDatasetMetadata]: Parsed metadata for this provider, or None. 
+ """ + return None + @classmethod def get_all_providers(cls) -> dict[str, type["SeedDatasetProvider"]]: """ @@ -107,19 +120,24 @@ def get_all_dataset_names(cls, filters: Optional[SeedDatasetFilter] = None) -> l provider = provider_class() # Parser ensures a standard metadata format - metadata: SeedDatasetMetadata = cls._parse_metadata() - if filters and not metadata and "all" not in filters.tags: + metadata = provider._parse_metadata() + + # "all" bypasses metadata filtering and returns every dataset. + if filters and filters.tags and "all" in filters.tags: + dataset_names.add(provider.dataset_name) + continue + + if filters and not metadata: # Datasets without metadata are skipped unless we want "all" continue # Filters detected but no match -> don't add this dataset - if filters and not cls._match_filter(metadata=metadata, filters=filters): + if filters and metadata and not cls._match_filter(metadata=metadata, filters=filters): continue dataset_names.add(provider.dataset_name) except Exception as e: - raise ValueError( - f"Could not get dataset name from {provider_class.__name__}: {e}") from e + raise ValueError(f"Could not get dataset name from {provider_class.__name__}: {e}") from e return sorted(dataset_names) @classmethod @@ -140,9 +158,7 @@ def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter bool: Whether or not the filters match or not. """ # Tags - if metadata.tags and "all" in metadata.tags: - # This is the only condition that returns true, because we want the "all" - # tag to override everything else in the filter. 
+ if filters.tags and "all" in filters.tags: return True # These lines all disable SIM103 because metadata and filters tags can be optional, so @@ -156,8 +172,11 @@ def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter return False # Harm Categories - if metadata.harm_categories and filters.harm_categories and \ - not set(metadata.harm_categories) & set(filters.harm_categories): # noqa: SIM103 + if ( + metadata.harm_categories + and filters.harm_categories + and not set(metadata.harm_categories) & set(filters.harm_categories) + ): # noqa: SIM103 return False # Source Type @@ -165,8 +184,7 @@ def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter return False # Modalities - if metadata.modalities and filters.modalities and \ - not set(metadata.modalities) & set(filters.modalities): # noqa: SIM103 + if metadata.modalities and filters.modalities and not set(metadata.modalities) & set(filters.modalities): # noqa: SIM103 return False # Rank @@ -215,11 +233,9 @@ async def fetch_datasets_async( # Validate dataset names if specified if dataset_names is not None: available_names = cls.get_all_dataset_names() - invalid_names = [ - name for name in dataset_names if name not in available_names] + invalid_names = [name for name in dataset_names if name not in available_names] if invalid_names: - raise ValueError( - f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") + raise ValueError(f"Dataset(s) not found: {invalid_names}. 
Available datasets: {available_names}") async def fetch_single_dataset( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -245,8 +261,7 @@ async def fetch_single_dataset( # Progress tracking total_count = len(cls._registry) - pbar = tqdm(total=total_count, - desc="Loading datasets - this can take a few minutes", unit="dataset") + pbar = tqdm(total=total_count, desc="Loading datasets - this can take a few minutes", unit="dataset") async def fetch_with_semaphore( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -284,12 +299,10 @@ async def fetch_with_semaphore( logger.info(f"Merging multiple sources for {dataset_name}.") existing_dataset = datasets[dataset_name] - combined_seeds = list( - existing_dataset.seeds) + list(dataset.seeds) + combined_seeds = list(existing_dataset.seeds) + list(dataset.seeds) existing_dataset.seeds = combined_seeds else: datasets[dataset_name] = dataset - logger.info( - f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") + logger.info(f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") return list(datasets.values()) diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index 6037d811b0..f5926eacb0 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -5,12 +5,16 @@ from enum import Enum from typing import Optional -from pyrit.common.path import DATASETS_PATH - """ Contains metadata objects for datasets (i.e. subclasses of SeedDatasetProvider). -The ground truth is SeedDatasetMetadata. 
This is +SeedDatasetMetadata is the internal schema used to normalize metadata fields +from different sources: +- Remote providers that declare metadata as class attributes +- Local prompt files that store metadata at the top level + +SeedDatasetFilter is the user-facing filter schema consumed by +SeedDatasetProvider.get_all_dataset_names(). """ @@ -61,12 +65,12 @@ class SeedDatasetFilter: SeedDatasetProvider. """ - tags: Optional[set[str]] - sizes: Optional[list[SeedDatasetSize]] - modalities: Optional[list[SeedDatasetModality]] - sources: Optional[list[SeedDatasetSourceType]] - ranks: Optional[list[SeedDatasetLoadingRank]] - harm_categories: Optional[list[str]] + tags: Optional[set[str]] = None + sizes: Optional[list[SeedDatasetSize]] = None + modalities: Optional[list[SeedDatasetModality]] = None + sources: Optional[list[SeedDatasetSourceType]] = None + ranks: Optional[list[SeedDatasetLoadingRank]] = None + harm_categories: Optional[list[str]] = None @dataclass(frozen=True) @@ -76,38 +80,9 @@ class SeedDatasetMetadata: object. """ - tags: Optional[set[str]] - size: Optional[SeedDatasetSize] - modalities: Optional[list[SeedDatasetModality]] - source: Optional[SeedDatasetSourceType] - rank: Optional[SeedDatasetLoadingRank] - harm_categories: Optional[list[str]] - - -class SeedDatasetMetadataUtilities: - """ - Collected utilities for managing and updating metadata. - """ - - @staticmethod - def populate_metadata() -> None: - """ - WARNING: Because this function updates the metadata for each SeedDatasetProvider, - it changes the provider's corresopnding source file. Run with caution! - - Updates the metadata per SeedDatasetProvider. 
- """ - - # 1 Gather all dataset files - - # 2 For each file, download and store in the database (in-memory) - - # 3 Count the number of entries exactly and identify its threshold - - # 4 If harm categories are found in source, add them - - # 5 Inspect type of prompts to identify modalities present - - # 6 Inspect source file to find where it pulled from - - # 7 Leave rank optional for now + tags: Optional[set[str]] = None + size: Optional[SeedDatasetSize] = None + modalities: Optional[list[SeedDatasetModality]] = None + source: Optional[SeedDatasetSourceType] = None + rank: Optional[SeedDatasetLoadingRank] = None + harm_categories: Optional[list[str]] = None diff --git a/tests/integration/datasets/test_seed_dataset_provider_integration.py b/tests/integration/datasets/test_seed_dataset_provider_integration.py index 491f97e92a..22d4261ae1 100644 --- a/tests/integration/datasets/test_seed_dataset_provider_integration.py +++ b/tests/integration/datasets/test_seed_dataset_provider_integration.py @@ -2,12 +2,21 @@ # Licensed under the MIT license. 
import logging +import textwrap +from pathlib import Path +from unittest.mock import patch import pytest from pyrit.datasets import SeedDatasetProvider +from pyrit.datasets.seed_datasets.local.local_dataset_loader import _LocalDatasetLoader from pyrit.datasets.seed_datasets.remote import _VLSUMultimodalDataset -from pyrit.models import SeedDataset +from pyrit.datasets.seed_datasets.seed_metadata import ( + SeedDatasetFilter, + SeedDatasetModality, + SeedDatasetSize, +) +from pyrit.models import SeedDataset, SeedPrompt logger = logging.getLogger(__name__) @@ -37,12 +46,10 @@ async def test_fetch_dataset_integration(self, name, provider_cls): try: # Use max_examples for slow providers that fetch many remote images - provider = provider_cls( - max_examples=6) if provider_cls == _VLSUMultimodalDataset else provider_cls() + provider = provider_cls(max_examples=6) if provider_cls == _VLSUMultimodalDataset else provider_cls() dataset = await provider.fetch_dataset(cache=False) - assert isinstance( - dataset, SeedDataset), f"{name} did not return a SeedDataset" + assert isinstance(dataset, SeedDataset), f"{name} did not return a SeedDataset" assert len(dataset.seeds) > 0, f"{name} returned an empty dataset" assert dataset.dataset_name, f"{name} has no dataset_name" @@ -53,14 +60,266 @@ async def test_fetch_dataset_integration(self, name, provider_cls): f"Seed dataset_name mismatch in {name}: {seed.dataset_name} != {dataset.dataset_name}" ) - logger.info( - f"Successfully verified {name} with {len(dataset.seeds)} seeds") + logger.info(f"Successfully verified {name} with {len(dataset.seeds)} seeds") except Exception as e: pytest.fail(f"Failed to fetch dataset from {name}: {str(e)}") - @pytest.mark.asyncio - @pytest.mark.parameterize("name,provider_cls", get_dataset_providers()) - async def test_fetch_dataset_integration_with_filtering(self, name, provider_cls): - # TODO - pass + +class TestRemoteFilteringIntegration: + """ + Integration test for remote dataset filtering. 
+ + Uses a mocked remote provider with class-level metadata attributes to + validate the full flow: metadata population, filter matching, and + get_all_dataset_names output. + """ + + def _make_remote_provider_cls( + self, + *, + name: str, + tags: set, + size: SeedDatasetSize, + modalities: list, + harm_categories: list, + ) -> type: + """Build a minimal concrete SeedDatasetProvider with class-level metadata.""" + from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import _RemoteDatasetLoader + + captured_name = name + + async def _fetch_dataset(self, *, cache=True): + return SeedDataset( + seeds=[SeedPrompt(value="x", data_type="text")], + dataset_name=captured_name, + ) + + attrs = { + "tags": tags, + "size": size, + "modalities": modalities, + "harm_categories": harm_categories, + "should_register": False, + "__module__": __name__, + # Concrete implementations satisfy ABC requirements + "dataset_name": property(lambda self: captured_name), + "fetch_dataset": _fetch_dataset, + "_fetch_from_url": lambda self, **kw: [], + } + + return type(f"_Mock_{name}", (_RemoteDatasetLoader,), attrs) + + def test_filter_matches_correct_remote_provider(self): + """Filter by size returns only providers that match.""" + large_cls = self._make_remote_provider_cls( + name="large_ds", + tags={"default"}, + size=SeedDatasetSize.LARGE, + modalities=[SeedDatasetModality.TEXT], + harm_categories=["violence"], + ) + small_cls = self._make_remote_provider_cls( + name="small_ds", + tags={"default"}, + size=SeedDatasetSize.SMALL, + modalities=[SeedDatasetModality.TEXT], + harm_categories=["cybercrime"], + ) + + with patch.dict( + SeedDatasetProvider._registry, + {"Large": large_cls, "Small": small_cls}, + clear=True, + ): + names = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE]), + ) + assert names == ["large_ds"] + + def test_filter_all_tag_returns_everything(self): + """tags={'all'} bypasses filtering and returns every 
provider.""" + cls1 = self._make_remote_provider_cls( + name="ds_a", + tags={"safety"}, + size=SeedDatasetSize.TINY, + modalities=[SeedDatasetModality.TEXT], + harm_categories=[], + ) + cls2 = self._make_remote_provider_cls( + name="ds_b", + tags={"custom"}, + size=SeedDatasetSize.HUGE, + modalities=[SeedDatasetModality.IMAGE], + harm_categories=["violence"], + ) + + with patch.dict( + SeedDatasetProvider._registry, + {"A": cls1, "B": cls2}, + clear=True, + ): + names = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(tags={"all"}), + ) + assert sorted(names) == ["ds_a", "ds_b"] + + def test_multi_axis_filter(self): + """Multiple filter axes are ANDed together.""" + cls1 = self._make_remote_provider_cls( + name="text_large", + tags={"default"}, + size=SeedDatasetSize.LARGE, + modalities=[SeedDatasetModality.TEXT], + harm_categories=["violence"], + ) + cls2 = self._make_remote_provider_cls( + name="image_large", + tags={"default"}, + size=SeedDatasetSize.LARGE, + modalities=[SeedDatasetModality.IMAGE], + harm_categories=["violence"], + ) + + with patch.dict( + SeedDatasetProvider._registry, + {"TL": cls1, "IL": cls2}, + clear=True, + ): + names = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter( + sizes=[SeedDatasetSize.LARGE], + modalities=[SeedDatasetModality.TEXT], + ), + ) + assert names == ["text_large"] + + +class TestLocalFilteringIntegration: + """ + Integration test for local dataset filtering. + + Creates real YAML prompt files on disk, registers them as local providers, + and validates the full flow through get_all_dataset_names with filters. 
+ """ + + @staticmethod + def _make_local_cls(yaml_path: Path) -> type: + """Build a dynamic local provider class for a YAML file.""" + + def make_init(path: Path): + def init_fn(self): + _LocalDatasetLoader.__init__(self, file_path=path) + + return init_fn + + return type( + f"LocalTest_{yaml_path.stem}", + (_LocalDatasetLoader,), + {"__init__": make_init(yaml_path), "should_register": False, "__module__": __name__}, + ) + + def test_local_filter_by_size(self, tmp_path): + """Local YAML with size metadata is correctly coerced and filtered.""" + large_yaml = tmp_path / "large_ds.prompt" + large_yaml.write_text( + textwrap.dedent("""\ + dataset_name: large_local + size: large + harm_categories: + - violence + seeds: + - value: test + data_type: text + """) + ) + small_yaml = tmp_path / "small_ds.prompt" + small_yaml.write_text( + textwrap.dedent("""\ + dataset_name: small_local + size: small + harm_categories: + - cybercrime + seeds: + - value: test + data_type: text + """) + ) + + large_cls = self._make_local_cls(large_yaml) + small_cls = self._make_local_cls(small_yaml) + + with patch.dict( + SeedDatasetProvider._registry, + {"Large": large_cls, "Small": small_cls}, + clear=True, + ): + names = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE]), + ) + # dataset_name falls back to file stem when SeedDataset.from_yaml_file + # rejects extra keys like "size" during __init__ pre-loading + assert names == ["large_ds"] + + def test_local_filter_by_tags(self, tmp_path): + """Local YAML tags (list) are coerced to set for intersection.""" + yaml_path = tmp_path / "tagged.prompt" + yaml_path.write_text( + textwrap.dedent("""\ + dataset_name: tagged_local + tags: + - safety + - default + harm_categories: + - violence + seeds: + - value: test + data_type: text + """) + ) + cls = self._make_local_cls(yaml_path) + + with patch.dict( + SeedDatasetProvider._registry, + {"Tagged": cls}, + clear=True, + ): + # dataset_name falls 
back to file stem ("tagged") when + # SeedDataset.from_yaml_file rejects extra keys like "tags" + matched = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(tags={"safety"}), + ) + assert matched == ["tagged"] + + not_matched = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(tags={"unrelated"}), + ) + assert not_matched == [] + + def test_local_no_metadata_skipped(self, tmp_path): + """Local YAML without metadata fields is skipped when filters are provided.""" + yaml_path = tmp_path / "bare.prompt" + yaml_path.write_text( + textwrap.dedent("""\ + dataset_name: bare_local + seeds: + - value: test + data_type: text + """) + ) + cls = self._make_local_cls(yaml_path) + + with patch.dict( + SeedDatasetProvider._registry, + {"Bare": cls}, + clear=True, + ): + # Without filters, the dataset is included + all_names = SeedDatasetProvider.get_all_dataset_names() + assert "bare_local" in all_names + + # With filters, it's skipped (no metadata to match against) + filtered = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(tags={"safety"}), + ) + assert filtered == [] diff --git a/tests/unit/datasets/test_seed_dataset_metadata.py b/tests/unit/datasets/test_seed_dataset_metadata.py index 5487a1c848..a99fe2450d 100644 --- a/tests/unit/datasets/test_seed_dataset_metadata.py +++ b/tests/unit/datasets/test_seed_dataset_metadata.py @@ -5,6 +5,15 @@ Tests for metadata components related to SeedDatasetProvider. 
""" +from pyrit.datasets.seed_datasets.seed_metadata import ( + SeedDatasetFilter, + SeedDatasetLoadingRank, + SeedDatasetMetadata, + SeedDatasetModality, + SeedDatasetSize, + SeedDatasetSourceType, +) + class TestMetadataLifecycle: """ @@ -13,29 +22,76 @@ class TestMetadataLifecycle: """ def test_has_no_values(self): - pass + metadata = SeedDatasetMetadata() + assert metadata.tags is None + assert metadata.size is None + assert metadata.modalities is None + assert metadata.source is None + assert metadata.rank is None + assert metadata.harm_categories is None def test_has_some_values(self): - pass + metadata = SeedDatasetMetadata(tags={"safety"}, size=SeedDatasetSize.LARGE) + assert metadata.tags == {"safety"} + assert metadata.size == SeedDatasetSize.LARGE + assert metadata.modalities is None + assert metadata.source is None + assert metadata.rank is None + assert metadata.harm_categories is None def test_has_all_values(self): - pass + metadata = SeedDatasetMetadata( + tags={"default", "safety"}, + size=SeedDatasetSize.MEDIUM, + modalities=[SeedDatasetModality.TEXT, SeedDatasetModality.IMAGE], + source=SeedDatasetSourceType.REMOTE, + rank=SeedDatasetLoadingRank.DEFAULT, + harm_categories=["violence", "illegal"], + ) + assert metadata.tags == {"default", "safety"} + assert metadata.size == SeedDatasetSize.MEDIUM + assert len(metadata.modalities) == 2 + assert metadata.source == SeedDatasetSourceType.REMOTE + assert metadata.rank == SeedDatasetLoadingRank.DEFAULT + assert metadata.harm_categories == ["violence", "illegal"] class TestFilterLifecycle: """ - Test that the metadata object can be created with different + Test that the filter object can be created with different subsets of values. 
""" def test_has_no_values(self): - pass + f = SeedDatasetFilter() + assert f.tags is None + assert f.sizes is None + assert f.modalities is None + assert f.sources is None + assert f.ranks is None + assert f.harm_categories is None def test_has_some_values(self): - pass + f = SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE]) + assert f.sizes == [SeedDatasetSize.LARGE] + assert f.tags is None + assert f.modalities is None def test_has_all_values(self): - pass + f = SeedDatasetFilter( + tags={"default"}, + sizes=[SeedDatasetSize.SMALL, SeedDatasetSize.MEDIUM], + modalities=[SeedDatasetModality.TEXT], + sources=[SeedDatasetSourceType.REMOTE], + ranks=[SeedDatasetLoadingRank.DEFAULT], + harm_categories=["violence"], + ) + assert f.tags == {"default"} + assert len(f.sizes) == 2 + assert f.modalities == [SeedDatasetModality.TEXT] + assert f.sources == [SeedDatasetSourceType.REMOTE] + assert f.ranks == [SeedDatasetLoadingRank.DEFAULT] + assert f.harm_categories == ["violence"] class TestMetadataProperties: @@ -44,48 +100,68 @@ class TestMetadataProperties: """ def test_size_value(self): - pass + for size in SeedDatasetSize: + metadata = SeedDatasetMetadata(size=size) + assert metadata.size == size def test_loading_rank_value(self): - pass + for rank in SeedDatasetLoadingRank: + metadata = SeedDatasetMetadata(rank=rank) + assert metadata.rank == rank def test_source_value(self): - pass + for source in SeedDatasetSourceType: + metadata = SeedDatasetMetadata(source=source) + assert metadata.source == source def test_modality_value(self): - pass + for modality in SeedDatasetModality: + metadata = SeedDatasetMetadata(modalities=[modality]) + assert modality in metadata.modalities def test_tags_value(self): - pass + metadata = SeedDatasetMetadata(tags={"safety", "default", "custom"}) + assert "safety" in metadata.tags + assert "default" in metadata.tags + assert "custom" in metadata.tags def test_harm_categories_value(self): - pass + metadata = 
SeedDatasetMetadata(harm_categories=["violence", "cybercrime"]) + assert "violence" in metadata.harm_categories + assert "cybercrime" in metadata.harm_categories class TestFilterProperties: """ - Test that the filter fields popualte correctly. + Test that the filter fields populate correctly. """ def test_sizes_values(self): - pass + f = SeedDatasetFilter(sizes=[SeedDatasetSize.SMALL, SeedDatasetSize.LARGE]) + assert SeedDatasetSize.SMALL in f.sizes + assert SeedDatasetSize.LARGE in f.sizes def test_loading_ranks_values(self): - pass + f = SeedDatasetFilter(ranks=[SeedDatasetLoadingRank.DEFAULT, SeedDatasetLoadingRank.SLOW]) + assert SeedDatasetLoadingRank.DEFAULT in f.ranks + assert SeedDatasetLoadingRank.SLOW in f.ranks def test_sources_values(self): - pass + f = SeedDatasetFilter(sources=[SeedDatasetSourceType.LOCAL, SeedDatasetSourceType.REMOTE]) + assert SeedDatasetSourceType.LOCAL in f.sources + assert SeedDatasetSourceType.REMOTE in f.sources def test_modalities_values(self): - pass + f = SeedDatasetFilter(modalities=[SeedDatasetModality.TEXT, SeedDatasetModality.IMAGE]) + assert SeedDatasetModality.TEXT in f.modalities + assert SeedDatasetModality.IMAGE in f.modalities def test_tags_values(self): - pass + f = SeedDatasetFilter(tags={"safety", "default"}) + assert "safety" in f.tags + assert "default" in f.tags def test_harm_categories_values(self): - pass - - -class TestMetadataUtilities: - def test_population_works(self, tmp_path): - pass + f = SeedDatasetFilter(harm_categories=["violence", "cybercrime"]) + assert "violence" in f.harm_categories + assert "cybercrime" in f.harm_categories diff --git a/tests/unit/datasets/test_seed_dataset_provider.py b/tests/unit/datasets/test_seed_dataset_provider.py index 0dbf0e13b7..b24d8b56b7 100644 --- a/tests/unit/datasets/test_seed_dataset_provider.py +++ b/tests/unit/datasets/test_seed_dataset_provider.py @@ -1,13 +1,23 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
+import textwrap from unittest.mock import AsyncMock, MagicMock, patch import pytest from pyrit.datasets import SeedDatasetProvider +from pyrit.datasets.seed_datasets.local.local_dataset_loader import _LocalDatasetLoader from pyrit.datasets.seed_datasets.remote.darkbench_dataset import _DarkBenchDataset from pyrit.datasets.seed_datasets.remote.harmbench_dataset import _HarmBenchDataset +from pyrit.datasets.seed_datasets.seed_metadata import ( + SeedDatasetFilter, + SeedDatasetLoadingRank, + SeedDatasetMetadata, + SeedDatasetModality, + SeedDatasetSize, + SeedDatasetSourceType, +) from pyrit.models import SeedDataset, SeedObjective, SeedPrompt @@ -78,15 +88,13 @@ async def test_fetch_datasets_async(self): mock_provider1 = MagicMock() mock_provider1.return_value.dataset_name = "d1" mock_provider1.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset( - seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") + return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) mock_provider2 = MagicMock() mock_provider2.return_value.dataset_name = "d2" mock_provider2.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset( - seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") + return_value=SeedDataset(seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") ) with patch.dict(SeedDatasetProvider._registry, {"P1": mock_provider1, "P2": mock_provider2}, clear=True): @@ -99,14 +107,12 @@ async def test_fetch_datasets_async_with_filter(self): mock_provider1 = MagicMock() mock_provider1.return_value.dataset_name = "d1" mock_provider1.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset( - seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") + return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) mock_provider2 = MagicMock() mock_provider2.return_value.dataset_name = "d2" - 
mock_provider2.return_value.fetch_dataset = AsyncMock( - side_effect=Exception("Should not be called")) + mock_provider2.return_value.fetch_dataset = AsyncMock(side_effect=Exception("Should not be called")) with patch.dict(SeedDatasetProvider._registry, {"P1": mock_provider1, "P2": mock_provider2}, clear=True): datasets = await SeedDatasetProvider.fetch_datasets_async(dataset_names=["d1"]) @@ -119,15 +125,13 @@ async def test_fetch_datasets_async_invalid_dataset_name(self): mock_provider1 = MagicMock() mock_provider1.return_value.dataset_name = "d1" mock_provider1.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset( - seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") + return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) mock_provider2 = MagicMock() mock_provider2.return_value.dataset_name = "d2" mock_provider2.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset( - seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") + return_value=SeedDataset(seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") ) with patch.dict(SeedDatasetProvider._registry, {"P1": mock_provider1, "P2": mock_provider2}, clear=True): @@ -245,57 +249,319 @@ async def test_fetch_dataset_with_custom_config(self, mock_darkbench_data): class TestMetadataParsingRemote: + """Test metadata parsing and filter matching for remote providers.""" + + def test_parse_metadata_from_class_attrs(self): + """Test _parse_metadata correctly extracts class-level metadata attributes.""" + loader = _HarmBenchDataset() + metadata = loader._parse_metadata() + assert metadata is not None + assert metadata.tags == {"default", "safety"} + assert metadata.size == SeedDatasetSize.LARGE + assert metadata.modalities == [SeedDatasetModality.TEXT] + assert metadata.harm_categories == ["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"] + # source and rank are not declared as class 
attributes on HarmBench + assert metadata.source is None + assert metadata.rank is None + def test_all_tag(self): - pass + """Filter with tags={'all'} matches any metadata.""" + metadata = SeedDatasetMetadata(tags={"safety"}) + filters = SeedDatasetFilter(tags={"all"}) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) def test_tags(self): - pass + """Tag filter uses set intersection.""" + metadata = SeedDatasetMetadata(tags={"safety", "default"}) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=SeedDatasetFilter(tags={"safety"})) + assert not SeedDatasetProvider._match_filter(metadata=metadata, filters=SeedDatasetFilter(tags={"unrelated"})) def test_sizes(self): - pass + """Size filter checks membership in the sizes list.""" + metadata = SeedDatasetMetadata(size=SeedDatasetSize.LARGE) + assert SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE, SeedDatasetSize.HUGE]), + ) + assert not SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(sizes=[SeedDatasetSize.SMALL]), + ) def test_modalities(self): - pass + """Modality filter uses set intersection.""" + metadata = SeedDatasetMetadata(modalities=[SeedDatasetModality.TEXT, SeedDatasetModality.IMAGE]) + assert SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(modalities=[SeedDatasetModality.TEXT]), + ) + assert not SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(modalities=[SeedDatasetModality.AUDIO]), + ) def test_sources(self): - pass + """Source filter checks membership.""" + metadata = SeedDatasetMetadata(source=SeedDatasetSourceType.REMOTE) + assert SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(sources=[SeedDatasetSourceType.REMOTE]), + ) + assert not SeedDatasetProvider._match_filter( + metadata=metadata, + 
filters=SeedDatasetFilter(sources=[SeedDatasetSourceType.LOCAL]), + ) def test_ranks(self): - pass + """Rank filter checks membership.""" + metadata = SeedDatasetMetadata(rank=SeedDatasetLoadingRank.DEFAULT) + assert SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(ranks=[SeedDatasetLoadingRank.DEFAULT]), + ) + assert not SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(ranks=[SeedDatasetLoadingRank.SLOW]), + ) def test_harm_categories(self): - pass + """Harm category filter uses set intersection.""" + metadata = SeedDatasetMetadata(harm_categories=["violence", "cybercrime"]) + assert SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(harm_categories=["violence"]), + ) + assert not SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(harm_categories=["unrelated"]), + ) - def test_empty_fitler(self): - pass + def test_empty_filter(self): + """Empty filter (all None) matches any metadata.""" + metadata = SeedDatasetMetadata(tags={"safety"}, size=SeedDatasetSize.LARGE) + filters = SeedDatasetFilter() + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) def test_no_metadata(self): - pass - -class TestMetadataParsingLocal: - def test_all_tag(self): - pass - - def test_tags(self): - pass - - def test_sizes(self): - pass - - def test_modalities(self): - pass - - def test_sources(self): - pass - - def test_ranks(self): - pass + """Provider without metadata is skipped when filters are applied.""" + mock_provider_cls = MagicMock() + mock_provider_instance = mock_provider_cls.return_value + mock_provider_instance.dataset_name = "no_metadata" + mock_provider_instance._parse_metadata.return_value = None - def test_harm_categories(self): - pass + with patch.dict(SeedDatasetProvider._registry, {"NoProv": mock_provider_cls}, clear=True): + names = 
SeedDatasetProvider.get_all_dataset_names(filters=SeedDatasetFilter(tags={"safety"})) + assert names == [] - def test_empty_fitler(self): - pass - def test_no_metadata(self): - pass +class TestMetadataParsingLocal: + """Test metadata parsing and filter matching for local YAML providers.""" + + def _make_loader(self, yaml_path): + """Create a _LocalDatasetLoader bypassing SeedDataset pre-loading.""" + loader = _LocalDatasetLoader.__new__(_LocalDatasetLoader) + loader.file_path = yaml_path + loader._dataset_name = yaml_path.stem + return loader + + def _write_yaml(self, tmp_path, name, content): + """Write a .prompt YAML file and return its path.""" + path = tmp_path / f"{name}.prompt" + path.write_text(content) + return path + + def test_parse_metadata_extracts_fields(self, tmp_path): + """Test _parse_metadata correctly extracts metadata fields from YAML.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + harm_categories: + - violence + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + assert metadata.harm_categories == ["violence"] + + def test_all_tag(self, tmp_path): + """Filter with tags={'all'} matches regardless of metadata types.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + tags: + - safety + harm_categories: + - violence + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(tags={"all"}) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_tags(self, tmp_path): + """YAML produces tags as list; set intersection in _match_filter expects a set.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + tags: + 
- safety + - default + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(tags={"safety"}) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_sizes(self, tmp_path): + """YAML produces size as string; _match_filter compares against enum values.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + size: large + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE]) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_modalities(self, tmp_path): + """YAML produces modalities as list of strings; _match_filter uses enum values.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + modalities: + - text + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(modalities=[SeedDatasetModality.TEXT]) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_sources(self, tmp_path): + """YAML produces source as string; _match_filter compares against enum values.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + source: remote + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(sources=[SeedDatasetSourceType.REMOTE]) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def 
test_ranks(self, tmp_path): + """YAML produces rank as string; _match_filter compares against enum values.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + rank: default + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(ranks=[SeedDatasetLoadingRank.DEFAULT]) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_harm_categories(self, tmp_path): + """Both YAML and filter use list[str], so intersection works correctly.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + harm_categories: + - violence + - cybercrime + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(harm_categories=["violence"]) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_empty_filter(self, tmp_path): + """Empty filter matches everything.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + harm_categories: + - violence + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter() + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_no_metadata(self, tmp_path): + """YAML without any metadata fields returns None from _parse_metadata.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is None 
From 32b6752beeb71377b7f73c0a17cb741ea4e6e760 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Fri, 13 Mar 2026 20:34:15 +0000 Subject: [PATCH 07/29] tests --- .../local/local_dataset_loader.py | 6 +- .../seed_datasets/seed_dataset_provider.py | 2 +- pyrit/datasets/seed_datasets/seed_metadata.py | 4 +- .../test_seed_dataset_provider_integration.py | 269 ++++++++++++++++++ .../datasets/test_seed_dataset_metadata.py | 26 +- .../datasets/test_seed_dataset_provider.py | 95 ++++++- 6 files changed, 378 insertions(+), 24 deletions(-) diff --git a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py index 2c6dd4b778..1ef745f628 100644 --- a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py @@ -128,12 +128,16 @@ def _coerce_metadata_values(*, raw_metadata: dict[str, Any]) -> dict[str, Any]: coerced[key] = set(value) elif key == "size" and isinstance(value, str): coerced[key] = SeedDatasetSize(value) - elif key == "source" and isinstance(value, str): + elif key == "source_type" and isinstance(value, str): coerced[key] = SeedDatasetSourceType(value) elif key == "rank" and isinstance(value, str): coerced[key] = SeedDatasetLoadingRank(value) elif key == "modalities" and isinstance(value, list): coerced[key] = [SeedDatasetModality(v) for v in value] + elif key == "harm_categories" and isinstance(value, str): + coerced[key] = [value] + elif key == "tags" and isinstance(value, str): + coerced[key] = {value} else: coerced[key] = value return coerced diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index ae45fcb200..4d65b932c2 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -180,7 +180,7 @@ def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter return False # 
Source Type - if metadata.source and filters.sources and metadata.source not in filters.sources: # noqa: SIM103 + if metadata.source_type and filters.source_types and metadata.source_type not in filters.source_types: # noqa: SIM103 return False # Modalities diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index f5926eacb0..01c97b8934 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -68,7 +68,7 @@ class SeedDatasetFilter: tags: Optional[set[str]] = None sizes: Optional[list[SeedDatasetSize]] = None modalities: Optional[list[SeedDatasetModality]] = None - sources: Optional[list[SeedDatasetSourceType]] = None + source_types: Optional[list[SeedDatasetSourceType]] = None ranks: Optional[list[SeedDatasetLoadingRank]] = None harm_categories: Optional[list[str]] = None @@ -83,6 +83,6 @@ class SeedDatasetMetadata: tags: Optional[set[str]] = None size: Optional[SeedDatasetSize] = None modalities: Optional[list[SeedDatasetModality]] = None - source: Optional[SeedDatasetSourceType] = None + source_type: Optional[SeedDatasetSourceType] = None rank: Optional[SeedDatasetLoadingRank] = None harm_categories: Optional[list[str]] = None diff --git a/tests/integration/datasets/test_seed_dataset_provider_integration.py b/tests/integration/datasets/test_seed_dataset_provider_integration.py index 22d4261ae1..f2da3a292a 100644 --- a/tests/integration/datasets/test_seed_dataset_provider_integration.py +++ b/tests/integration/datasets/test_seed_dataset_provider_integration.py @@ -323,3 +323,272 @@ def test_local_no_metadata_skipped(self, tmp_path): filters=SeedDatasetFilter(tags={"safety"}), ) assert filtered == [] + + +class TestEndToEndLocalDatasetWorkflow: + """ + End-to-end integration test that simulates a user workflow: + 1. Write a temporary local dataset to disk + 2. Register it into the provider registry + 3. Discover it via filtered get_all_dataset_names + 4. 
Fetch the actual dataset by name + 5. Verify metadata is accessible and correct + """ + + @staticmethod + def _make_local_cls(yaml_path: Path) -> type: + """Build a dynamic local provider class for a YAML file.""" + + def make_init(path: Path): + def init_fn(self): + _LocalDatasetLoader.__init__(self, file_path=path) + + return init_fn + + return type( + f"LocalTest_{yaml_path.stem}", + (_LocalDatasetLoader,), + {"__init__": make_init(yaml_path), "should_register": False, "__module__": __name__}, + ) + + @pytest.mark.asyncio + async def test_user_discovers_and_fetches_filtered_dataset(self, tmp_path): + """ + Simulate a user who wants small text datasets about cybercrime: + 1. Filter returns the matching dataset name + 2. Fetch by that name succeeds and returns correct seeds + 3. Metadata on the provider is accessible and consistent + """ + # --- Step 0: Write two datasets to disk --- + matching_yaml = tmp_path / "cyber_small.prompt" + matching_yaml.write_text( + textwrap.dedent("""\ + dataset_name: cyber_prompts + harm_categories: + - cybercrime + seeds: + - value: Explain phishing techniques + data_type: text + - value: Describe SQL injection + data_type: text + """) + ) + unrelated_yaml = tmp_path / "violence_large.prompt" + unrelated_yaml.write_text( + textwrap.dedent("""\ + dataset_name: violence_prompts + harm_categories: + - violence + seeds: + - value: Describe a violent scenario + data_type: text + """) + ) + + matching_cls = self._make_local_cls(matching_yaml) + unrelated_cls = self._make_local_cls(unrelated_yaml) + + with patch.dict( + SeedDatasetProvider._registry, + {"Cyber": matching_cls, "Violence": unrelated_cls}, + clear=True, + ): + # --- Step 1: User filters by harm_categories --- + names = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(harm_categories=["cybercrime"]), + ) + assert len(names) == 1 + dataset_name = names[0] + + # --- Step 2: User fetches the dataset by name --- + datasets = await 
SeedDatasetProvider.fetch_datasets_async( + dataset_names=[dataset_name], + ) + assert len(datasets) == 1 + dataset = datasets[0] + assert len(dataset.seeds) == 2 + assert dataset.seeds[0].value == "Explain phishing techniques" + assert dataset.seeds[1].value == "Describe SQL injection" + + # --- Step 3: User inspects metadata --- + provider = matching_cls() + metadata = provider._parse_metadata() + assert metadata is not None + assert metadata.harm_categories == ["cybercrime"] + + @pytest.mark.asyncio + async def test_user_fetches_unfiltered(self, tmp_path): + """ + Without filters, get_all_dataset_names returns everything, + and fetch_datasets_async retrieves all of them. + """ + ds1 = tmp_path / "ds_one.prompt" + ds1.write_text( + textwrap.dedent("""\ + dataset_name: dataset_one + seeds: + - value: prompt one + data_type: text + """) + ) + ds2 = tmp_path / "ds_two.prompt" + ds2.write_text( + textwrap.dedent("""\ + dataset_name: dataset_two + seeds: + - value: prompt two + data_type: text + """) + ) + + cls1 = self._make_local_cls(ds1) + cls2 = self._make_local_cls(ds2) + + with patch.dict( + SeedDatasetProvider._registry, + {"One": cls1, "Two": cls2}, + clear=True, + ): + names = SeedDatasetProvider.get_all_dataset_names() + assert len(names) == 2 + + datasets = await SeedDatasetProvider.fetch_datasets_async() + assert len(datasets) == 2 + fetched_names = sorted(d.dataset_name for d in datasets) + assert fetched_names == ["dataset_one", "dataset_two"] + + +class TestAllTagBypassIntegration: + """ + Integration tests for the tags={'all'} bypass pattern. + + The 'all' tag is a special escape hatch that returns every registered + dataset regardless of metadata presence or other filter axes. 
+ """ + + @staticmethod + def _make_local_cls(yaml_path: Path) -> type: + """Build a dynamic local provider class for a YAML file.""" + + def make_init(path: Path): + def init_fn(self): + _LocalDatasetLoader.__init__(self, file_path=path) + + return init_fn + + return type( + f"LocalTest_{yaml_path.stem}", + (_LocalDatasetLoader,), + {"__init__": make_init(yaml_path), "should_register": False, "__module__": __name__}, + ) + + def test_all_tag_includes_datasets_without_metadata(self, tmp_path): + """ + A dataset whose YAML has no metadata fields at all is normally + skipped when filters are present. tags={'all'} overrides that. + """ + bare_yaml = tmp_path / "bare.prompt" + bare_yaml.write_text( + textwrap.dedent("""\ + dataset_name: bare_dataset + seeds: + - value: bare prompt + data_type: text + """) + ) + cls = self._make_local_cls(bare_yaml) + + with patch.dict( + SeedDatasetProvider._registry, + {"Bare": cls}, + clear=True, + ): + # Normal filter skips it + filtered = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(tags={"safety"}), + ) + assert filtered == [] + + # 'all' includes it + all_names = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(tags={"all"}), + ) + assert "bare_dataset" in all_names + + def test_all_tag_ignores_other_filter_axes(self, tmp_path): + """ + tags={'all'} returns everything even when other filter axes + would exclude datasets. 
+ """ + small_yaml = tmp_path / "small.prompt" + small_yaml.write_text( + textwrap.dedent("""\ + dataset_name: small_dataset + size: small + harm_categories: + - cybercrime + seeds: + - value: small prompt + data_type: text + """) + ) + cls = self._make_local_cls(small_yaml) + + with patch.dict( + SeedDatasetProvider._registry, + {"Small": cls}, + clear=True, + ): + # Size filter alone would exclude it + size_filtered = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE]), + ) + assert size_filtered == [] + + # 'all' tag overrides the size filter + all_names = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(tags={"all"}, sizes=[SeedDatasetSize.LARGE]), + ) + assert "small" in all_names + + def test_all_tag_with_mixed_metadata_and_bare_datasets(self, tmp_path): + """ + With a mix of metadata-rich and metadata-bare datasets, + tags={'all'} returns all of them. + """ + rich_yaml = tmp_path / "rich.prompt" + rich_yaml.write_text( + textwrap.dedent("""\ + dataset_name: rich_dataset + harm_categories: + - violence + tags: + - safety + seeds: + - value: rich prompt + data_type: text + """) + ) + bare_yaml = tmp_path / "bare.prompt" + bare_yaml.write_text( + textwrap.dedent("""\ + dataset_name: bare_dataset + seeds: + - value: bare prompt + data_type: text + """) + ) + + rich_cls = self._make_local_cls(rich_yaml) + bare_cls = self._make_local_cls(bare_yaml) + + with patch.dict( + SeedDatasetProvider._registry, + {"Rich": rich_cls, "Bare": bare_cls}, + clear=True, + ): + all_names = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(tags={"all"}), + ) + assert len(all_names) == 2 + assert "bare_dataset" in all_names diff --git a/tests/unit/datasets/test_seed_dataset_metadata.py b/tests/unit/datasets/test_seed_dataset_metadata.py index a99fe2450d..4aaaed1fd4 100644 --- a/tests/unit/datasets/test_seed_dataset_metadata.py +++ b/tests/unit/datasets/test_seed_dataset_metadata.py @@ -26,7 
+26,7 @@ def test_has_no_values(self): assert metadata.tags is None assert metadata.size is None assert metadata.modalities is None - assert metadata.source is None + assert metadata.source_type is None assert metadata.rank is None assert metadata.harm_categories is None @@ -35,7 +35,7 @@ def test_has_some_values(self): assert metadata.tags == {"safety"} assert metadata.size == SeedDatasetSize.LARGE assert metadata.modalities is None - assert metadata.source is None + assert metadata.source_type is None assert metadata.rank is None assert metadata.harm_categories is None @@ -44,14 +44,14 @@ def test_has_all_values(self): tags={"default", "safety"}, size=SeedDatasetSize.MEDIUM, modalities=[SeedDatasetModality.TEXT, SeedDatasetModality.IMAGE], - source=SeedDatasetSourceType.REMOTE, + source_type=SeedDatasetSourceType.REMOTE, rank=SeedDatasetLoadingRank.DEFAULT, harm_categories=["violence", "illegal"], ) assert metadata.tags == {"default", "safety"} assert metadata.size == SeedDatasetSize.MEDIUM assert len(metadata.modalities) == 2 - assert metadata.source == SeedDatasetSourceType.REMOTE + assert metadata.source_type == SeedDatasetSourceType.REMOTE assert metadata.rank == SeedDatasetLoadingRank.DEFAULT assert metadata.harm_categories == ["violence", "illegal"] @@ -67,7 +67,7 @@ def test_has_no_values(self): assert f.tags is None assert f.sizes is None assert f.modalities is None - assert f.sources is None + assert f.source_types is None assert f.ranks is None assert f.harm_categories is None @@ -82,14 +82,14 @@ def test_has_all_values(self): tags={"default"}, sizes=[SeedDatasetSize.SMALL, SeedDatasetSize.MEDIUM], modalities=[SeedDatasetModality.TEXT], - sources=[SeedDatasetSourceType.REMOTE], + source_types=[SeedDatasetSourceType.REMOTE], ranks=[SeedDatasetLoadingRank.DEFAULT], harm_categories=["violence"], ) assert f.tags == {"default"} assert len(f.sizes) == 2 assert f.modalities == [SeedDatasetModality.TEXT] - assert f.sources == [SeedDatasetSourceType.REMOTE] + 
assert f.source_types == [SeedDatasetSourceType.REMOTE] assert f.ranks == [SeedDatasetLoadingRank.DEFAULT] assert f.harm_categories == ["violence"] @@ -110,9 +110,9 @@ def test_loading_rank_value(self): assert metadata.rank == rank def test_source_value(self): - for source in SeedDatasetSourceType: - metadata = SeedDatasetMetadata(source=source) - assert metadata.source == source + for source_type in SeedDatasetSourceType: + metadata = SeedDatasetMetadata(source_type=source_type) + assert metadata.source_type == source_type def test_modality_value(self): for modality in SeedDatasetModality: @@ -147,9 +147,9 @@ def test_loading_ranks_values(self): assert SeedDatasetLoadingRank.SLOW in f.ranks def test_sources_values(self): - f = SeedDatasetFilter(sources=[SeedDatasetSourceType.LOCAL, SeedDatasetSourceType.REMOTE]) - assert SeedDatasetSourceType.LOCAL in f.sources - assert SeedDatasetSourceType.REMOTE in f.sources + f = SeedDatasetFilter(source_types=[SeedDatasetSourceType.LOCAL, SeedDatasetSourceType.REMOTE]) + assert SeedDatasetSourceType.LOCAL in f.source_types + assert SeedDatasetSourceType.REMOTE in f.source_types def test_modalities_values(self): f = SeedDatasetFilter(modalities=[SeedDatasetModality.TEXT, SeedDatasetModality.IMAGE]) diff --git a/tests/unit/datasets/test_seed_dataset_provider.py b/tests/unit/datasets/test_seed_dataset_provider.py index b24d8b56b7..7095ed57cd 100644 --- a/tests/unit/datasets/test_seed_dataset_provider.py +++ b/tests/unit/datasets/test_seed_dataset_provider.py @@ -2,9 +2,12 @@ # Licensed under the MIT license. 
import textwrap +from dataclasses import fields as dc_fields +from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import pytest +import yaml from pyrit.datasets import SeedDatasetProvider from pyrit.datasets.seed_datasets.local.local_dataset_loader import _LocalDatasetLoader @@ -261,7 +264,7 @@ def test_parse_metadata_from_class_attrs(self): assert metadata.modalities == [SeedDatasetModality.TEXT] assert metadata.harm_categories == ["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"] # source and rank are not declared as class attributes on HarmBench - assert metadata.source is None + assert metadata.source_type is None assert metadata.rank is None def test_all_tag(self): @@ -302,14 +305,14 @@ def test_modalities(self): def test_sources(self): """Source filter checks membership.""" - metadata = SeedDatasetMetadata(source=SeedDatasetSourceType.REMOTE) + metadata = SeedDatasetMetadata(source_type=SeedDatasetSourceType.REMOTE) assert SeedDatasetProvider._match_filter( metadata=metadata, - filters=SeedDatasetFilter(sources=[SeedDatasetSourceType.REMOTE]), + filters=SeedDatasetFilter(source_types=[SeedDatasetSourceType.REMOTE]), ) assert not SeedDatasetProvider._match_filter( metadata=metadata, - filters=SeedDatasetFilter(sources=[SeedDatasetSourceType.LOCAL]), + filters=SeedDatasetFilter(source_types=[SeedDatasetSourceType.LOCAL]), ) def test_ranks(self): @@ -472,13 +475,13 @@ def test_modalities(self, tmp_path): assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) def test_sources(self, tmp_path): - """YAML produces source as string; _match_filter compares against enum values.""" + """YAML produces source_type as string; _match_filter compares against enum values.""" yaml_path = self._write_yaml( tmp_path, "test", textwrap.dedent("""\ dataset_name: test - source: remote + source_type: remote seeds: - value: test prompt data_type: text @@ -487,7 +490,7 @@ def test_sources(self, tmp_path): loader = 
self._make_loader(yaml_path) metadata = loader._parse_metadata() assert metadata is not None - filters = SeedDatasetFilter(sources=[SeedDatasetSourceType.REMOTE]) + filters = SeedDatasetFilter(source_types=[SeedDatasetSourceType.REMOTE]) assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) def test_ranks(self, tmp_path): @@ -565,3 +568,81 @@ def test_no_metadata(self, tmp_path): loader = self._make_loader(yaml_path) metadata = loader._parse_metadata() assert metadata is None + + +class TestLocalDatasetMetadataCollisions: + """ + Regression tests that scan every real .prompt file under seed_datasets/local + to verify _parse_metadata does not crash from field-name collisions between + the YAML schema and SeedDatasetMetadata. + + The previous `source` field collision (URLs parsed as SeedDatasetSourceType) + is the motivating example. + """ + + @staticmethod + def _get_local_prompt_files() -> list: + """Collect all .prompt and .yaml files under the local datasets directory.""" + local_dir = Path(__file__).resolve().parents[3] / "pyrit" / "datasets" / "seed_datasets" / "local" + return sorted(local_dir.glob("**/*.prompt")) + sorted(local_dir.glob("**/*.yaml")) + + @pytest.mark.parametrize("prompt_file", _get_local_prompt_files.__func__(), ids=lambda p: p.stem) + def test_parse_metadata_does_not_crash(self, prompt_file): + """_parse_metadata must not raise on any real local dataset file.""" + loader = _LocalDatasetLoader.__new__(_LocalDatasetLoader) + loader.file_path = prompt_file + loader._dataset_name = prompt_file.stem + + # This must not raise — if a YAML key collides with a metadata field + # name but holds an incompatible value, the coercion layer should + # either handle it or skip it gracefully. 
+ metadata = loader._parse_metadata() + # metadata can be None (no matching fields) or a valid SeedDatasetMetadata + if metadata is not None: + assert isinstance(metadata, SeedDatasetMetadata) + + @pytest.mark.parametrize("prompt_file", _get_local_prompt_files.__func__(), ids=lambda p: p.stem) + def test_no_yaml_key_shadows_metadata_field_with_wrong_type(self, prompt_file): + """ + If a YAML top-level key matches a SeedDatasetMetadata field name, the + coerced value must be the correct type (enum, set, list) — not a raw + string or other primitive that would silently break filtering. + """ + with open(prompt_file, encoding="utf-8") as f: + data = yaml.safe_load(f) + + if not isinstance(data, dict): + return + + metadata_field_names = {fld.name for fld in dc_fields(SeedDatasetMetadata)} + overlapping_keys = metadata_field_names & data.keys() + + if not overlapping_keys: + return + + # Coerce and construct — must not raise + loader = _LocalDatasetLoader.__new__(_LocalDatasetLoader) + loader.file_path = prompt_file + loader._dataset_name = prompt_file.stem + + raw = {k: data[k] for k in overlapping_keys} + coerced = _LocalDatasetLoader._coerce_metadata_values(raw_metadata=raw) + metadata = SeedDatasetMetadata(**coerced) + + # Verify coerced types match expectations + expected_types = { + "tags": (set, type(None)), + "size": (SeedDatasetSize, type(None)), + "modalities": (list, type(None)), + "source_type": (SeedDatasetSourceType, type(None)), + "rank": (SeedDatasetLoadingRank, type(None)), + "harm_categories": (list, type(None)), + } + for key in overlapping_keys: + value = getattr(metadata, key) + valid_types = expected_types.get(key) + if valid_types: + assert isinstance(value, valid_types), ( + f"Field '{key}' in {prompt_file.name} has type {type(value).__name__}, " + f"expected one of {valid_types}" + ) From c94a6da2056ec221d5a2dc69d6e50e94f0fd1917 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Fri, 13 Mar 2026 20:43:10 +0000 Subject: [PATCH 08/29] precommit 
--- pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py b/pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py index b2b45c2a33..a622a4a018 100644 --- a/pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py @@ -23,7 +23,7 @@ class _JBBBehaviorsDataset(_RemoteDatasetLoader): and may contain offensive content. Users should check with their legal department before using these prompts against production LLMs. """ - + def __init__( self, *, From 2e7e9375735f702761e38d739766faddaefba00a Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Sat, 14 Mar 2026 00:03:13 +0000 Subject: [PATCH 09/29] utilities scaffolding --- .../seed_datasets/seed_dataset_provider.py | 3 +- pyrit/datasets/seed_datasets/seed_metadata.py | 119 +++++++++++++++++- .../datasets/test_seed_dataset_metadata.py | 8 +- .../datasets/test_seed_dataset_provider.py | 7 +- 4 files changed, 125 insertions(+), 12 deletions(-) diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index 4d65b932c2..f5ef0d2736 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -9,7 +9,7 @@ from tqdm import tqdm -from pyrit.datasets.seed_datasets.seed_metadata import SeedDatasetFilter, SeedDatasetMetadata +from pyrit.datasets.seed_datasets.seed_metadata import SeedDatasetFilter, SeedDatasetLoadingRank, SeedDatasetMetadata from pyrit.models.seeds import SeedDataset logger = logging.getLogger(__name__) @@ -33,6 +33,7 @@ class SeedDatasetProvider(ABC): """ _registry: dict[str, type["SeedDatasetProvider"]] = {} + rank: SeedDatasetLoadingRank = SeedDatasetLoadingRank.UNKNOWN def __init_subclass__(cls, **kwargs: Any) -> None: """ diff --git 
a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index 01c97b8934..d203d7e1cf 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from enum import Enum -from typing import Optional +from typing import Optional, TypedDict """ Contains metadata objects for datasets (i.e. subclasses of SeedDatasetProvider). @@ -33,9 +33,18 @@ class SeedDatasetLoadingRank(Enum): Represents the general difficulty of loading in a dataset. """ + # Default is equivalent to "fastest" in the sense that datasets marked + # with a default rank will always get loaded. DEFAULT = "default" - EXTENDED = "extended" - SLOW = "slow" + + # These represent actual ranks. + PRIMARY = "primary" + SECONDARY = "secondary" + TERTIARY = "tertiary" + + # Unknown corresponds to an untested dataset that won't be loaded. It is the + # default provided in SeedDatasetProvider. + UNKNOWN = "unknown" class SeedDatasetModality(Enum): @@ -84,5 +93,107 @@ class SeedDatasetMetadata: size: Optional[SeedDatasetSize] = None modalities: Optional[list[SeedDatasetModality]] = None source_type: Optional[SeedDatasetSourceType] = None - rank: Optional[SeedDatasetLoadingRank] = None + rank: SeedDatasetLoadingRank = SeedDatasetLoadingRank.UNKNOWN harm_categories: Optional[list[str]] = None + + +class SeedDatasetMetadataUtilities: + """ + Utilities for deriving metadata for datasets. Currently, only static attributes + are supported. + + The default working location for datasets is the in-memory database. + """ + + class Metrics(TypedDict): + """ + Typed dictionary for easier retrieval and calculation of dataset metrics. + """ + + exact_size: int + loading_time_ms: float + modalities_found: set[str] + source_type: str + harm_categories_found: set[str] + tags: set[str] + + # Stores working dataset calculations. 
+ # Maps name to metrics, which are later converted into SeedDatasetMetadata. + _cache: dict[str, Metrics] = {} + + @classmethod + def populate_datasets(cls) -> None: + """ + Populate metadata for all registered datasets. + + WARNING: Because metadata is stored as class attributes, this method can directly + change source files. Be extra careful when running it. + """ + # Get all dataset names + # Calling SeedDatasetProvider would create a circular import, so we do this explicitly + datasets: list[str] = [] + + # Populate cache with empty (name, metrics) pairs + for dataset in datasets: + metrics: SeedDatasetMetadataUtilities.Metrics = { + "exact_size": -1, + "loading_time_ms": -1.0, + "modalities_found": {"None"}, + "source_type": "None", + "harm_categories_found": {"None"}, + "tags": {"None"}, + } + cls._cache[dataset] = metrics + + # Using a list, for each dataset name, load it in depending on class type + # Invoke the appropriate helper to parse it + + # If local, local_helper + + # If remote, remote_helper + + # Get contents from the memory database + # Note that we have to load it into the memory_database to get timing + # We also want the helper to do no initialization, just extract the relevant + # types and get ready to call a timing library + + # Calculate metrics one by one + + # Once out of the loop, calculate metadata fields + + # Loading rank by comparing relative speeds + + # Size by comparing buckets + + # Convert all others to types + + # Update (if update = True) the datasets + + # If remote, write to the file using regex + # E.g. harm_categories: ... should appear in source + + # If local, make sure the .prompt is formatted nicely + + @classmethod + def _local_helper(cls) -> None: + """ + Load local datasets into the working cache. + """ + + @classmethod + def _remote_helper(cls) -> None: + """ + Load remote datasets into the working cache. 
+ """ + + @classmethod + def _remote_writer(cls) -> None: + """ + Write updated metadata to a remote dataset source file. + """ + + @classmethod + def _local_writer(cls) -> None: + """ + Write updated metadata to a local .prompt file. + """ diff --git a/tests/unit/datasets/test_seed_dataset_metadata.py b/tests/unit/datasets/test_seed_dataset_metadata.py index 4aaaed1fd4..73fa3dad3e 100644 --- a/tests/unit/datasets/test_seed_dataset_metadata.py +++ b/tests/unit/datasets/test_seed_dataset_metadata.py @@ -27,7 +27,7 @@ def test_has_no_values(self): assert metadata.size is None assert metadata.modalities is None assert metadata.source_type is None - assert metadata.rank is None + assert metadata.rank == SeedDatasetLoadingRank.UNKNOWN assert metadata.harm_categories is None def test_has_some_values(self): @@ -36,7 +36,7 @@ def test_has_some_values(self): assert metadata.size == SeedDatasetSize.LARGE assert metadata.modalities is None assert metadata.source_type is None - assert metadata.rank is None + assert metadata.rank == SeedDatasetLoadingRank.UNKNOWN assert metadata.harm_categories is None def test_has_all_values(self): @@ -142,9 +142,9 @@ def test_sizes_values(self): assert SeedDatasetSize.LARGE in f.sizes def test_loading_ranks_values(self): - f = SeedDatasetFilter(ranks=[SeedDatasetLoadingRank.DEFAULT, SeedDatasetLoadingRank.SLOW]) + f = SeedDatasetFilter(ranks=[SeedDatasetLoadingRank.DEFAULT, SeedDatasetLoadingRank.TERTIARY]) assert SeedDatasetLoadingRank.DEFAULT in f.ranks - assert SeedDatasetLoadingRank.SLOW in f.ranks + assert SeedDatasetLoadingRank.TERTIARY in f.ranks def test_sources_values(self): f = SeedDatasetFilter(source_types=[SeedDatasetSourceType.LOCAL, SeedDatasetSourceType.REMOTE]) diff --git a/tests/unit/datasets/test_seed_dataset_provider.py b/tests/unit/datasets/test_seed_dataset_provider.py index 7095ed57cd..52029850b7 100644 --- a/tests/unit/datasets/test_seed_dataset_provider.py +++ b/tests/unit/datasets/test_seed_dataset_provider.py @@ 
-263,9 +263,10 @@ def test_parse_metadata_from_class_attrs(self): assert metadata.size == SeedDatasetSize.LARGE assert metadata.modalities == [SeedDatasetModality.TEXT] assert metadata.harm_categories == ["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"] - # source and rank are not declared as class attributes on HarmBench + # source_type is not declared as a class attribute on HarmBench; + # rank inherits the UNKNOWN default from SeedDatasetProvider base class assert metadata.source_type is None - assert metadata.rank is None + assert metadata.rank == SeedDatasetLoadingRank.UNKNOWN def test_all_tag(self): """Filter with tags={'all'} matches any metadata.""" @@ -324,7 +325,7 @@ def test_ranks(self): ) assert not SeedDatasetProvider._match_filter( metadata=metadata, - filters=SeedDatasetFilter(ranks=[SeedDatasetLoadingRank.SLOW]), + filters=SeedDatasetFilter(ranks=[SeedDatasetLoadingRank.TERTIARY]), ) def test_harm_categories(self): From 3c926025904f8b25f6c8fc5d7a23916034955bfc Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Mon, 16 Mar 2026 19:11:20 +0000 Subject: [PATCH 10/29] . 
--- .../local/local_dataset_loader.py | 27 ++++++++++++------- .../seed_datasets/seed_dataset_provider.py | 20 +++++++++----- 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py index 1ef745f628..3682182a4a 100644 --- a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py @@ -46,7 +46,8 @@ def __init__(self, *, file_path: Path): dataset = SeedDataset.from_yaml_file(file_path) # Use the dataset_name from the YAML if available, otherwise use filename self._dataset_name = ( - getattr(dataset, "dataset_name", None) or getattr(dataset, "name", None) or file_path.stem + getattr(dataset, "dataset_name", None) or getattr( + dataset, "name", None) or file_path.stem ) except Exception as e: logger.warning(f"Could not pre-load dataset from {file_path}: {e}") @@ -77,7 +78,8 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: dataset.dataset_name = self.dataset_name return dataset except Exception as e: - logger.error(f"Failed to load local dataset from {self.file_path}: {e}") + logger.error( + f"Failed to load local dataset from {self.file_path}: {e}") raise def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: @@ -98,11 +100,11 @@ def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: with open(self.file_path, encoding="utf-8") as f: dataset = yaml.safe_load(f) except Exception as e: - logger.error(f"Failed to load local dataset from {self.file_path}: {e}") + logger.error( + f"Failed to load local dataset from {self.file_path}: {e}") raise - if not isinstance(dataset, dict): - return None + dataset = self. 
raw = {k: v for k, v in dataset.items() if k in valid_fields} if not raw: @@ -165,21 +167,26 @@ def _register_local_datasets() -> None: def make_init(path: Path) -> Callable[[Any], None]: def __init__(self: Any) -> None: # noqa: N807 - super(self.__class__, self).__init__(file_path=path) + super(self.__class__, self).__init__( + file_path=path) return __init__ type( class_name, (_LocalDatasetLoader,), - {"__init__": make_init(yaml_file), "should_register": True, "__module__": __name__}, + {"__init__": make_init( + yaml_file), "should_register": True, "__module__": __name__}, ) - logger.debug(f"Registered local dataset loader: {class_name} for {yaml_file.name}") + logger.debug( + f"Registered local dataset loader: {class_name} for {yaml_file.name}") except Exception as e: - logger.warning(f"Failed to register local dataset {yaml_file}: {e}") + logger.warning( + f"Failed to register local dataset {yaml_file}: {e}") else: - logger.warning(f"Seed datasets directory not found: {seed_datasets_path}") + logger.warning( + f"Seed datasets directory not found: {seed_datasets_path}") # Execute registration diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index f5ef0d2736..22c8b0089f 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -138,7 +138,8 @@ def get_all_dataset_names(cls, filters: Optional[SeedDatasetFilter] = None) -> l dataset_names.add(provider.dataset_name) except Exception as e: - raise ValueError(f"Could not get dataset name from {provider_class.__name__}: {e}") from e + raise ValueError( + f"Could not get dataset name from {provider_class.__name__}: {e}") from e return sorted(dataset_names) @classmethod @@ -156,7 +157,7 @@ def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter filters (SeedDatasetFilter): The filter object provided by the user to get_all_dataset_names. 
Returns: - bool: Whether or not the filters match or not. + bool: Whether or not the filters match. """ # Tags if filters.tags and "all" in filters.tags: @@ -234,9 +235,11 @@ async def fetch_datasets_async( # Validate dataset names if specified if dataset_names is not None: available_names = cls.get_all_dataset_names() - invalid_names = [name for name in dataset_names if name not in available_names] + invalid_names = [ + name for name in dataset_names if name not in available_names] if invalid_names: - raise ValueError(f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") + raise ValueError( + f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") async def fetch_single_dataset( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -262,7 +265,8 @@ async def fetch_single_dataset( # Progress tracking total_count = len(cls._registry) - pbar = tqdm(total=total_count, desc="Loading datasets - this can take a few minutes", unit="dataset") + pbar = tqdm(total=total_count, + desc="Loading datasets - this can take a few minutes", unit="dataset") async def fetch_with_semaphore( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -300,10 +304,12 @@ async def fetch_with_semaphore( logger.info(f"Merging multiple sources for {dataset_name}.") existing_dataset = datasets[dataset_name] - combined_seeds = list(existing_dataset.seeds) + list(dataset.seeds) + combined_seeds = list( + existing_dataset.seeds) + list(dataset.seeds) existing_dataset.seeds = combined_seeds else: datasets[dataset_name] = dataset - logger.info(f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") + logger.info( + f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") return list(datasets.values()) From 2224fd2ff402413e32366292e8721e5807b43df9 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Mon, 16 Mar 2026 21:04:30 +0000 Subject: [PATCH 11/29] . 
--- pyrit/datasets/__init__.py | 10 +- .../local/local_dataset_loader.py | 67 +++--- .../seed_datasets/remote/harmbench_dataset.py | 9 +- pyrit/datasets/seed_datasets/seed_metadata.py | 192 ++++-------------- .../test_seed_dataset_provider_integration.py | 40 ++-- .../datasets/test_seed_dataset_metadata.py | 87 ++++---- .../datasets/test_seed_dataset_provider.py | 59 +++--- 7 files changed, 165 insertions(+), 299 deletions(-) diff --git a/pyrit/datasets/__init__.py b/pyrit/datasets/__init__.py index 8f4b543238..35b6ba9f53 100644 --- a/pyrit/datasets/__init__.py +++ b/pyrit/datasets/__init__.py @@ -10,20 +10,14 @@ from pyrit.datasets.seed_datasets.seed_dataset_provider import SeedDatasetProvider from pyrit.datasets.seed_datasets.seed_metadata import ( SeedDatasetFilter, - SeedDatasetLoadingRank, + SeedDatasetLoadTime, SeedDatasetMetadata, - SeedDatasetModality, - SeedDatasetSize, - SeedDatasetSourceType, ) __all__ = [ "SeedDatasetFilter", "SeedDatasetMetadata", - "SeedDatasetLoadingRank", - "SeedDatasetModality", - "SeedDatasetSize", - "SeedDatasetSourceType", + "SeedDatasetLoadTime", "SeedDatasetProvider", "TextJailBreak", ] diff --git a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py index 3682182a4a..350ee68830 100644 --- a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py @@ -11,11 +11,8 @@ from pyrit.datasets.seed_datasets.seed_dataset_provider import SeedDatasetProvider from pyrit.datasets.seed_datasets.seed_metadata import ( - SeedDatasetLoadingRank, + SeedDatasetLoadTime, SeedDatasetMetadata, - SeedDatasetModality, - SeedDatasetSize, - SeedDatasetSourceType, ) from pyrit.models import SeedDataset @@ -46,8 +43,7 @@ def __init__(self, *, file_path: Path): dataset = SeedDataset.from_yaml_file(file_path) # Use the dataset_name from the YAML if available, otherwise use filename self._dataset_name = ( - getattr(dataset, 
"dataset_name", None) or getattr( - dataset, "name", None) or file_path.stem + getattr(dataset, "dataset_name", None) or getattr(dataset, "name", None) or file_path.stem ) except Exception as e: logger.warning(f"Could not pre-load dataset from {file_path}: {e}") @@ -78,8 +74,7 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: dataset.dataset_name = self.dataset_name return dataset except Exception as e: - logger.error( - f"Failed to load local dataset from {self.file_path}: {e}") + logger.error(f"Failed to load local dataset from {self.file_path}: {e}") raise def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: @@ -100,11 +95,11 @@ def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: with open(self.file_path, encoding="utf-8") as f: dataset = yaml.safe_load(f) except Exception as e: - logger.error( - f"Failed to load local dataset from {self.file_path}: {e}") + logger.error(f"Failed to load local dataset from {self.file_path}: {e}") raise - dataset = self. + if not isinstance(dataset, dict): + return None raw = {k: v for k, v in dataset.items() if k in valid_fields} if not raw: @@ -116,7 +111,10 @@ def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: @staticmethod def _coerce_metadata_values(*, raw_metadata: dict[str, Any]) -> dict[str, Any]: """ - Convert YAML primitive values into the enum/set types expected by SeedDatasetMetadata. + Convert YAML primitive values into the types expected by SeedDatasetMetadata. + + Applies .lower().strip() normalization to string values for size, modalities, + source_type, and harm_categories to prevent case/whitespace mismatches. Args: raw_metadata (dict[str, Any]): Dictionary of field names to raw YAML-parsed values. 
@@ -127,21 +125,27 @@ def _coerce_metadata_values(*, raw_metadata: dict[str, Any]) -> dict[str, Any]: coerced: dict[str, Any] = {} for key, value in raw_metadata.items(): if key == "tags" and isinstance(value, list): - coerced[key] = set(value) - elif key == "size" and isinstance(value, str): - coerced[key] = SeedDatasetSize(value) - elif key == "source_type" and isinstance(value, str): - coerced[key] = SeedDatasetSourceType(value) - elif key == "rank" and isinstance(value, str): - coerced[key] = SeedDatasetLoadingRank(value) + coerced[key] = {v.strip().lower() if isinstance(v, str) else v for v in value} + elif key == "tags" and isinstance(value, str): + coerced[key] = {value.strip().lower()} + elif key == "size" and isinstance(value, str) or key == "source_type" and isinstance(value, str): + coerced[key] = value.strip().lower() + elif key == "load_time" and isinstance(value, str): + coerced[key] = SeedDatasetLoadTime(value.strip().lower()) elif key == "modalities" and isinstance(value, list): - coerced[key] = [SeedDatasetModality(v) for v in value] + coerced[key] = [v.strip().lower() if isinstance(v, str) else v for v in value] + elif key == "modalities" and isinstance(value, str): + coerced[key] = [value.strip().lower()] + elif key == "harm_categories" and isinstance(value, list): + coerced[key] = [v.strip().lower() if isinstance(v, str) else v for v in value] elif key == "harm_categories" and isinstance(value, str): - coerced[key] = [value] - elif key == "tags" and isinstance(value, str): - coerced[key] = {value} + coerced[key] = [value.strip().lower()] else: - coerced[key] = value + # Unexpected type for a metadata field — skip it with a warning + # rather than passing garbage into SeedDatasetMetadata. 
+ logger.warning( + f"Skipping metadata field '{key}' with unexpected type {type(value).__name__} (value: {value!r})" + ) return coerced @@ -167,26 +171,21 @@ def _register_local_datasets() -> None: def make_init(path: Path) -> Callable[[Any], None]: def __init__(self: Any) -> None: # noqa: N807 - super(self.__class__, self).__init__( - file_path=path) + super(self.__class__, self).__init__(file_path=path) return __init__ type( class_name, (_LocalDatasetLoader,), - {"__init__": make_init( - yaml_file), "should_register": True, "__module__": __name__}, + {"__init__": make_init(yaml_file), "should_register": True, "__module__": __name__}, ) - logger.debug( - f"Registered local dataset loader: {class_name} for {yaml_file.name}") + logger.debug(f"Registered local dataset loader: {class_name} for {yaml_file.name}") except Exception as e: - logger.warning( - f"Failed to register local dataset {yaml_file}: {e}") + logger.warning(f"Failed to register local dataset {yaml_file}: {e}") else: - logger.warning( - f"Seed datasets directory not found: {seed_datasets_path}") + logger.warning(f"Seed datasets directory not found: {seed_datasets_path}") # Execute registration diff --git a/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py b/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py index 5d30fac2c4..a1e59f68e1 100644 --- a/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py @@ -6,10 +6,6 @@ from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, ) -from pyrit.datasets.seed_datasets.seed_metadata import ( - SeedDatasetModality, - SeedDatasetSize, -) from pyrit.models import SeedDataset, SeedObjective @@ -25,9 +21,8 @@ class _HarmBenchDataset(_RemoteDatasetLoader): # Metadata harm_categories: list[str] = ["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"] - modalities: list[SeedDatasetModality] = [SeedDatasetModality.TEXT] - size: 
SeedDatasetSize = SeedDatasetSize.LARGE # 504 seeds - # "default" means included in curated set + modalities: list[str] = ["text"] + size: str = "large" # 504 seeds tags: set[str] = {"default", "safety"} def __init__( diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index d203d7e1cf..6810329e21 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from enum import Enum -from typing import Optional, TypedDict +from typing import Literal, Optional """ Contains metadata objects for datasets (i.e. subclasses of SeedDatasetProvider). @@ -15,56 +15,32 @@ SeedDatasetFilter is the user-facing filter schema consumed by SeedDatasetProvider.get_all_dataset_names(). -""" - - -class SeedDatasetSize(Enum): - """Ordinal size (by bucket) of the dataset.""" - TINY = "tiny" # < 10 - SMALL = "small" # >= 10, < 100 - MEDIUM = "medium" # >= 100, < 500 - LARGE = "large" # >= 500, < 5000 - HUGE = "huge" # >= 5000 - - -class SeedDatasetLoadingRank(Enum): - """ - Represents the general difficulty of loading in a dataset. - """ +Size and modality are string literals rather than enums for usability — callers +don't need to import extra types to construct a filter. +""" - # Default is equivalent to "fastest" in the sense that datasets marked - # with a default rank will always get loaded. - DEFAULT = "default" +# Documented expected values for string-typed metadata fields. +# These are not enforced at runtime but serve as documentation. +SeedDatasetSizeLiteral = Literal["tiny", "small", "medium", "large", "huge"] +"""tiny (<10), small (10-99), medium (100-499), large (500-4999), huge (5000+)""" - # These represent actual ranks. 
- PRIMARY = "primary" - SECONDARY = "secondary" - TERTIARY = "tertiary" +SeedDatasetModalityLiteral = Literal["text", "image", "video", "audio"] - # Unknown corresponds to an untested dataset that won't be loaded. It is the - # default provided in SeedDatasetProvider. - UNKNOWN = "unknown" +SeedDatasetSourceTypeLiteral = Literal["remote", "local"] -class SeedDatasetModality(Enum): +class SeedDatasetLoadTime(Enum): """ - Type of data contained in the dataset. + Approximate time to load a dataset. Used to skip slow datasets in fast runs. """ - TEXT = "text" - IMAGE = "image" - VIDEO = "video" - AUDIO = "audio" - + FAST = "fast" + NORMAL = "normal" + SLOW = "slow" -class SeedDatasetSourceType(Enum): - """ - Where the dataset is pulled from. - """ - - REMOTE = "remote" - LOCAL = "local" + # Default value for datasets whose load time hasn't been measured. + UNINITIALIZED = "uninitialized" @dataclass @@ -72,128 +48,38 @@ class SeedDatasetFilter: """ Filter object for datasets. Passed to `get_all_dataset_names` in SeedDatasetProvider. + + Most fields are optional. None means "don't filter on this axis." + + Exception for load_times, which defaults to UNINITIALIZED. + + By default, filtering is OR-wise across filter categories and OR-wise within + filter categories. """ + # Tags are a top-level set of labels that assist with filtering. + # The tag "all" will return every discoverable dataset. + # The tag "default" will return every dataset with an initialized + # load_time (i.e., SeedDatasetLoadTime != UNINITIALIZED.) or an explicit + # "default" tag (think of this like a pinned or starred item). 
tags: Optional[set[str]] = None - sizes: Optional[list[SeedDatasetSize]] = None - modalities: Optional[list[SeedDatasetModality]] = None - source_types: Optional[list[SeedDatasetSourceType]] = None - ranks: Optional[list[SeedDatasetLoadingRank]] = None + sizes: Optional[list[str]] = None + modalities: Optional[list[str]] = None + source_types: Optional[list[str]] = None + load_times: Optional[list[SeedDatasetLoadTime]] = None harm_categories: Optional[list[str]] = None @dataclass(frozen=True) class SeedDatasetMetadata: """ - Metadata object for datasets. Holds the same fields as the filter - object. + Internal schema for dataset metadata. Constructed by _parse_metadata() + implementations on each provider type. """ tags: Optional[set[str]] = None - size: Optional[SeedDatasetSize] = None - modalities: Optional[list[SeedDatasetModality]] = None - source_type: Optional[SeedDatasetSourceType] = None - rank: SeedDatasetLoadingRank = SeedDatasetLoadingRank.UNKNOWN + size: Optional[str] = None + modalities: Optional[list[str]] = None + source_type: Optional[str] = None + load_time: SeedDatasetLoadTime = SeedDatasetLoadTime.UNINITIALIZED harm_categories: Optional[list[str]] = None - - -class SeedDatasetMetadataUtilities: - """ - Utilities for deriving metadata for datasets. Currently, only static attributes - are supported. - - The default working location for datasets is the in-memory database. - """ - - class Metrics(TypedDict): - """ - Typed dictionary for easier retrieval and calculation of dataset metrics. - """ - - exact_size: int - loading_time_ms: float - modalities_found: set[str] - source_type: str - harm_categories_found: set[str] - tags: set[str] - - # Stores working dataset calculations. - # Maps name to metrics, which are later converted into SeedDatasetMetadata. - _cache: dict[str, Metrics] = {} - - @classmethod - def populate_datasets(cls) -> None: - """ - Populate metadata for all registered datasets. 
- - WARNING: Because metadata is stored as class attributes, this method can directly - change source files. Be extra careful when running it. - """ - # Get all dataset names - # Calling SeedDatasetProvider would create a circular import, so we do this explicitly - datasets: list[str] = [] - - # Populate cache with empty (name, metrics) pairs - for dataset in datasets: - metrics: SeedDatasetMetadataUtilities.Metrics = { - "exact_size": -1, - "loading_time_ms": -1.0, - "modalities_found": {"None"}, - "source_type": "None", - "harm_categories_found": {"None"}, - "tags": {"None"}, - } - cls._cache[dataset] = metrics - - # Using a list, for each dataset name, load it in depending on class type - # Invoke the appropriate helper to parse it - - # If local, local_helper - - # If remote, remote_helper - - # Get contents from the memory database - # Note that we have to load it into the memory_database to get timing - # We also want the helper to do no initialization, just extract the relevant - # types and get ready to call a timing library - - # Calculate metrics one by one - - # Once out of the loop, calculate metadata fields - - # Loading rank by comparing relative speeds - - # Size by comparing buckets - - # Convert all others to types - - # Update (if update = True) the datasets - - # If remote, write to the file using regex - # E.g. harm_categories: ... should appear in source - - # If local, make sure the .prompt is formatted nicely - - @classmethod - def _local_helper(cls) -> None: - """ - Load local datasets into the working cache. - """ - - @classmethod - def _remote_helper(cls) -> None: - """ - Load remote datasets into the working cache. - """ - - @classmethod - def _remote_writer(cls) -> None: - """ - Write updated metadata to a remote dataset source file. - """ - - @classmethod - def _local_writer(cls) -> None: - """ - Write updated metadata to a local .prompt file. 
- """ diff --git a/tests/integration/datasets/test_seed_dataset_provider_integration.py b/tests/integration/datasets/test_seed_dataset_provider_integration.py index f2da3a292a..fdd58ec97f 100644 --- a/tests/integration/datasets/test_seed_dataset_provider_integration.py +++ b/tests/integration/datasets/test_seed_dataset_provider_integration.py @@ -13,8 +13,6 @@ from pyrit.datasets.seed_datasets.remote import _VLSUMultimodalDataset from pyrit.datasets.seed_datasets.seed_metadata import ( SeedDatasetFilter, - SeedDatasetModality, - SeedDatasetSize, ) from pyrit.models import SeedDataset, SeedPrompt @@ -80,7 +78,7 @@ def _make_remote_provider_cls( *, name: str, tags: set, - size: SeedDatasetSize, + size: str, modalities: list, harm_categories: list, ) -> type: @@ -115,15 +113,15 @@ def test_filter_matches_correct_remote_provider(self): large_cls = self._make_remote_provider_cls( name="large_ds", tags={"default"}, - size=SeedDatasetSize.LARGE, - modalities=[SeedDatasetModality.TEXT], + size="large", + modalities=["text"], harm_categories=["violence"], ) small_cls = self._make_remote_provider_cls( name="small_ds", tags={"default"}, - size=SeedDatasetSize.SMALL, - modalities=[SeedDatasetModality.TEXT], + size="small", + modalities=["text"], harm_categories=["cybercrime"], ) @@ -133,7 +131,7 @@ def test_filter_matches_correct_remote_provider(self): clear=True, ): names = SeedDatasetProvider.get_all_dataset_names( - filters=SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE]), + filters=SeedDatasetFilter(sizes=["large"]), ) assert names == ["large_ds"] @@ -142,15 +140,15 @@ def test_filter_all_tag_returns_everything(self): cls1 = self._make_remote_provider_cls( name="ds_a", tags={"safety"}, - size=SeedDatasetSize.TINY, - modalities=[SeedDatasetModality.TEXT], + size="tiny", + modalities=["text"], harm_categories=[], ) cls2 = self._make_remote_provider_cls( name="ds_b", tags={"custom"}, - size=SeedDatasetSize.HUGE, - modalities=[SeedDatasetModality.IMAGE], + size="huge", + 
modalities=["image"], harm_categories=["violence"], ) @@ -169,15 +167,15 @@ def test_multi_axis_filter(self): cls1 = self._make_remote_provider_cls( name="text_large", tags={"default"}, - size=SeedDatasetSize.LARGE, - modalities=[SeedDatasetModality.TEXT], + size="large", + modalities=["text"], harm_categories=["violence"], ) cls2 = self._make_remote_provider_cls( name="image_large", tags={"default"}, - size=SeedDatasetSize.LARGE, - modalities=[SeedDatasetModality.IMAGE], + size="large", + modalities=["image"], harm_categories=["violence"], ) @@ -188,8 +186,8 @@ def test_multi_axis_filter(self): ): names = SeedDatasetProvider.get_all_dataset_names( filters=SeedDatasetFilter( - sizes=[SeedDatasetSize.LARGE], - modalities=[SeedDatasetModality.TEXT], + sizes=["large"], + modalities=["text"], ), ) assert names == ["text_large"] @@ -255,7 +253,7 @@ def test_local_filter_by_size(self, tmp_path): clear=True, ): names = SeedDatasetProvider.get_all_dataset_names( - filters=SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE]), + filters=SeedDatasetFilter(sizes=["large"]), ) # dataset_name falls back to file stem when SeedDataset.from_yaml_file # rejects extra keys like "size" during __init__ pre-loading @@ -541,13 +539,13 @@ def test_all_tag_ignores_other_filter_axes(self, tmp_path): ): # Size filter alone would exclude it size_filtered = SeedDatasetProvider.get_all_dataset_names( - filters=SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE]), + filters=SeedDatasetFilter(sizes=["large"]), ) assert size_filtered == [] # 'all' tag overrides the size filter all_names = SeedDatasetProvider.get_all_dataset_names( - filters=SeedDatasetFilter(tags={"all"}, sizes=[SeedDatasetSize.LARGE]), + filters=SeedDatasetFilter(tags={"all"}, sizes=["large"]), ) assert "small" in all_names diff --git a/tests/unit/datasets/test_seed_dataset_metadata.py b/tests/unit/datasets/test_seed_dataset_metadata.py index 73fa3dad3e..e9096a470b 100644 --- a/tests/unit/datasets/test_seed_dataset_metadata.py +++ 
b/tests/unit/datasets/test_seed_dataset_metadata.py @@ -7,11 +7,8 @@ from pyrit.datasets.seed_datasets.seed_metadata import ( SeedDatasetFilter, - SeedDatasetLoadingRank, + SeedDatasetLoadTime, SeedDatasetMetadata, - SeedDatasetModality, - SeedDatasetSize, - SeedDatasetSourceType, ) @@ -27,32 +24,32 @@ def test_has_no_values(self): assert metadata.size is None assert metadata.modalities is None assert metadata.source_type is None - assert metadata.rank == SeedDatasetLoadingRank.UNKNOWN + assert metadata.load_time == SeedDatasetLoadTime.UNINITIALIZED assert metadata.harm_categories is None def test_has_some_values(self): - metadata = SeedDatasetMetadata(tags={"safety"}, size=SeedDatasetSize.LARGE) + metadata = SeedDatasetMetadata(tags={"safety"}, size="large") assert metadata.tags == {"safety"} - assert metadata.size == SeedDatasetSize.LARGE + assert metadata.size == "large" assert metadata.modalities is None assert metadata.source_type is None - assert metadata.rank == SeedDatasetLoadingRank.UNKNOWN + assert metadata.load_time == SeedDatasetLoadTime.UNINITIALIZED assert metadata.harm_categories is None def test_has_all_values(self): metadata = SeedDatasetMetadata( tags={"default", "safety"}, - size=SeedDatasetSize.MEDIUM, - modalities=[SeedDatasetModality.TEXT, SeedDatasetModality.IMAGE], - source_type=SeedDatasetSourceType.REMOTE, - rank=SeedDatasetLoadingRank.DEFAULT, + size="medium", + modalities=["text", "image"], + source_type="remote", + load_time=SeedDatasetLoadTime.FAST, harm_categories=["violence", "illegal"], ) assert metadata.tags == {"default", "safety"} - assert metadata.size == SeedDatasetSize.MEDIUM + assert metadata.size == "medium" assert len(metadata.modalities) == 2 - assert metadata.source_type == SeedDatasetSourceType.REMOTE - assert metadata.rank == SeedDatasetLoadingRank.DEFAULT + assert metadata.source_type == "remote" + assert metadata.load_time == SeedDatasetLoadTime.FAST assert metadata.harm_categories == ["violence", "illegal"] @@ -68,29 
+65,29 @@ def test_has_no_values(self): assert f.sizes is None assert f.modalities is None assert f.source_types is None - assert f.ranks is None + assert f.load_times is None assert f.harm_categories is None def test_has_some_values(self): - f = SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE]) - assert f.sizes == [SeedDatasetSize.LARGE] + f = SeedDatasetFilter(sizes=["large"]) + assert f.sizes == ["large"] assert f.tags is None assert f.modalities is None def test_has_all_values(self): f = SeedDatasetFilter( tags={"default"}, - sizes=[SeedDatasetSize.SMALL, SeedDatasetSize.MEDIUM], - modalities=[SeedDatasetModality.TEXT], - source_types=[SeedDatasetSourceType.REMOTE], - ranks=[SeedDatasetLoadingRank.DEFAULT], + sizes=["small", "medium"], + modalities=["text"], + source_types=["remote"], + load_times=[SeedDatasetLoadTime.FAST], harm_categories=["violence"], ) assert f.tags == {"default"} assert len(f.sizes) == 2 - assert f.modalities == [SeedDatasetModality.TEXT] - assert f.source_types == [SeedDatasetSourceType.REMOTE] - assert f.ranks == [SeedDatasetLoadingRank.DEFAULT] + assert f.modalities == ["text"] + assert f.source_types == ["remote"] + assert f.load_times == [SeedDatasetLoadTime.FAST] assert f.harm_categories == ["violence"] @@ -100,22 +97,22 @@ class TestMetadataProperties: """ def test_size_value(self): - for size in SeedDatasetSize: + for size in ["tiny", "small", "medium", "large", "huge"]: metadata = SeedDatasetMetadata(size=size) assert metadata.size == size - def test_loading_rank_value(self): - for rank in SeedDatasetLoadingRank: - metadata = SeedDatasetMetadata(rank=rank) - assert metadata.rank == rank + def test_load_time_value(self): + for lt in SeedDatasetLoadTime: + metadata = SeedDatasetMetadata(load_time=lt) + assert metadata.load_time == lt def test_source_value(self): - for source_type in SeedDatasetSourceType: + for source_type in ["remote", "local"]: metadata = SeedDatasetMetadata(source_type=source_type) assert metadata.source_type == 
source_type def test_modality_value(self): - for modality in SeedDatasetModality: + for modality in ["text", "image", "video", "audio"]: metadata = SeedDatasetMetadata(modalities=[modality]) assert modality in metadata.modalities @@ -137,24 +134,24 @@ class TestFilterProperties: """ def test_sizes_values(self): - f = SeedDatasetFilter(sizes=[SeedDatasetSize.SMALL, SeedDatasetSize.LARGE]) - assert SeedDatasetSize.SMALL in f.sizes - assert SeedDatasetSize.LARGE in f.sizes + f = SeedDatasetFilter(sizes=["small", "large"]) + assert "small" in f.sizes + assert "large" in f.sizes - def test_loading_ranks_values(self): - f = SeedDatasetFilter(ranks=[SeedDatasetLoadingRank.DEFAULT, SeedDatasetLoadingRank.TERTIARY]) - assert SeedDatasetLoadingRank.DEFAULT in f.ranks - assert SeedDatasetLoadingRank.TERTIARY in f.ranks + def test_load_times_values(self): + f = SeedDatasetFilter(load_times=[SeedDatasetLoadTime.FAST, SeedDatasetLoadTime.SLOW]) + assert SeedDatasetLoadTime.FAST in f.load_times + assert SeedDatasetLoadTime.SLOW in f.load_times def test_sources_values(self): - f = SeedDatasetFilter(source_types=[SeedDatasetSourceType.LOCAL, SeedDatasetSourceType.REMOTE]) - assert SeedDatasetSourceType.LOCAL in f.source_types - assert SeedDatasetSourceType.REMOTE in f.source_types + f = SeedDatasetFilter(source_types=["local", "remote"]) + assert "local" in f.source_types + assert "remote" in f.source_types def test_modalities_values(self): - f = SeedDatasetFilter(modalities=[SeedDatasetModality.TEXT, SeedDatasetModality.IMAGE]) - assert SeedDatasetModality.TEXT in f.modalities - assert SeedDatasetModality.IMAGE in f.modalities + f = SeedDatasetFilter(modalities=["text", "image"]) + assert "text" in f.modalities + assert "image" in f.modalities def test_tags_values(self): f = SeedDatasetFilter(tags={"safety", "default"}) diff --git a/tests/unit/datasets/test_seed_dataset_provider.py b/tests/unit/datasets/test_seed_dataset_provider.py index 52029850b7..bf7a54a3ca 100644 --- 
a/tests/unit/datasets/test_seed_dataset_provider.py +++ b/tests/unit/datasets/test_seed_dataset_provider.py @@ -15,11 +15,8 @@ from pyrit.datasets.seed_datasets.remote.harmbench_dataset import _HarmBenchDataset from pyrit.datasets.seed_datasets.seed_metadata import ( SeedDatasetFilter, - SeedDatasetLoadingRank, + SeedDatasetLoadTime, SeedDatasetMetadata, - SeedDatasetModality, - SeedDatasetSize, - SeedDatasetSourceType, ) from pyrit.models import SeedDataset, SeedObjective, SeedPrompt @@ -260,13 +257,13 @@ def test_parse_metadata_from_class_attrs(self): metadata = loader._parse_metadata() assert metadata is not None assert metadata.tags == {"default", "safety"} - assert metadata.size == SeedDatasetSize.LARGE - assert metadata.modalities == [SeedDatasetModality.TEXT] + assert metadata.size == "large" + assert metadata.modalities == ["text"] assert metadata.harm_categories == ["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"] # source_type is not declared as a class attribute on HarmBench; - # rank inherits the UNKNOWN default from SeedDatasetProvider base class + # load_time inherits the UNINITIALIZED default from SeedDatasetProvider base class assert metadata.source_type is None - assert metadata.rank == SeedDatasetLoadingRank.UNKNOWN + assert metadata.load_time == SeedDatasetLoadTime.UNINITIALIZED def test_all_tag(self): """Filter with tags={'all'} matches any metadata.""" @@ -282,50 +279,50 @@ def test_tags(self): def test_sizes(self): """Size filter checks membership in the sizes list.""" - metadata = SeedDatasetMetadata(size=SeedDatasetSize.LARGE) + metadata = SeedDatasetMetadata(size="large") assert SeedDatasetProvider._match_filter( metadata=metadata, - filters=SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE, SeedDatasetSize.HUGE]), + filters=SeedDatasetFilter(sizes=["large", "huge"]), ) assert not SeedDatasetProvider._match_filter( metadata=metadata, - filters=SeedDatasetFilter(sizes=[SeedDatasetSize.SMALL]), + 
filters=SeedDatasetFilter(sizes=["small"]), ) def test_modalities(self): """Modality filter uses set intersection.""" - metadata = SeedDatasetMetadata(modalities=[SeedDatasetModality.TEXT, SeedDatasetModality.IMAGE]) + metadata = SeedDatasetMetadata(modalities=["text", "image"]) assert SeedDatasetProvider._match_filter( metadata=metadata, - filters=SeedDatasetFilter(modalities=[SeedDatasetModality.TEXT]), + filters=SeedDatasetFilter(modalities=["text"]), ) assert not SeedDatasetProvider._match_filter( metadata=metadata, - filters=SeedDatasetFilter(modalities=[SeedDatasetModality.AUDIO]), + filters=SeedDatasetFilter(modalities=["audio"]), ) def test_sources(self): """Source filter checks membership.""" - metadata = SeedDatasetMetadata(source_type=SeedDatasetSourceType.REMOTE) + metadata = SeedDatasetMetadata(source_type="remote") assert SeedDatasetProvider._match_filter( metadata=metadata, - filters=SeedDatasetFilter(source_types=[SeedDatasetSourceType.REMOTE]), + filters=SeedDatasetFilter(source_types=["remote"]), ) assert not SeedDatasetProvider._match_filter( metadata=metadata, - filters=SeedDatasetFilter(source_types=[SeedDatasetSourceType.LOCAL]), + filters=SeedDatasetFilter(source_types=["local"]), ) def test_ranks(self): - """Rank filter checks membership.""" - metadata = SeedDatasetMetadata(rank=SeedDatasetLoadingRank.DEFAULT) + """Load time filter checks membership.""" + metadata = SeedDatasetMetadata(load_time=SeedDatasetLoadTime.FAST) assert SeedDatasetProvider._match_filter( metadata=metadata, - filters=SeedDatasetFilter(ranks=[SeedDatasetLoadingRank.DEFAULT]), + filters=SeedDatasetFilter(load_times=[SeedDatasetLoadTime.FAST]), ) assert not SeedDatasetProvider._match_filter( metadata=metadata, - filters=SeedDatasetFilter(ranks=[SeedDatasetLoadingRank.TERTIARY]), + filters=SeedDatasetFilter(load_times=[SeedDatasetLoadTime.SLOW]), ) def test_harm_categories(self): @@ -342,7 +339,7 @@ def test_harm_categories(self): def test_empty_filter(self): """Empty 
filter (all None) matches any metadata.""" - metadata = SeedDatasetMetadata(tags={"safety"}, size=SeedDatasetSize.LARGE) + metadata = SeedDatasetMetadata(tags={"safety"}, size="large") filters = SeedDatasetFilter() assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) @@ -452,7 +449,7 @@ def test_sizes(self, tmp_path): loader = self._make_loader(yaml_path) metadata = loader._parse_metadata() assert metadata is not None - filters = SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE]) + filters = SeedDatasetFilter(sizes=["large"]) assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) def test_modalities(self, tmp_path): @@ -472,7 +469,7 @@ def test_modalities(self, tmp_path): loader = self._make_loader(yaml_path) metadata = loader._parse_metadata() assert metadata is not None - filters = SeedDatasetFilter(modalities=[SeedDatasetModality.TEXT]) + filters = SeedDatasetFilter(modalities=["text"]) assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) def test_sources(self, tmp_path): @@ -491,17 +488,17 @@ def test_sources(self, tmp_path): loader = self._make_loader(yaml_path) metadata = loader._parse_metadata() assert metadata is not None - filters = SeedDatasetFilter(source_types=[SeedDatasetSourceType.REMOTE]) + filters = SeedDatasetFilter(source_types=["remote"]) assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) def test_ranks(self, tmp_path): - """YAML produces rank as string; _match_filter compares against enum values.""" + """YAML produces load_time as string; _match_filter compares against enum values.""" yaml_path = self._write_yaml( tmp_path, "test", textwrap.dedent("""\ dataset_name: test - rank: default + load_time: fast seeds: - value: test prompt data_type: text @@ -510,7 +507,7 @@ def test_ranks(self, tmp_path): loader = self._make_loader(yaml_path) metadata = loader._parse_metadata() assert metadata is not None - filters = 
SeedDatasetFilter(ranks=[SeedDatasetLoadingRank.DEFAULT]) + filters = SeedDatasetFilter(load_times=[SeedDatasetLoadTime.FAST]) assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) def test_harm_categories(self, tmp_path): @@ -633,10 +630,10 @@ def test_no_yaml_key_shadows_metadata_field_with_wrong_type(self, prompt_file): # Verify coerced types match expectations expected_types = { "tags": (set, type(None)), - "size": (SeedDatasetSize, type(None)), + "size": (str, type(None)), "modalities": (list, type(None)), - "source_type": (SeedDatasetSourceType, type(None)), - "rank": (SeedDatasetLoadingRank, type(None)), + "source_type": (str, type(None)), + "load_time": (SeedDatasetLoadTime, type(None)), "harm_categories": (list, type(None)), } for key in overlapping_keys: From 0d22a3d2b7b019ed2c7d0969bf33544e50eec55f Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Mon, 16 Mar 2026 21:04:53 +0000 Subject: [PATCH 12/29] . --- .../seed_datasets/seed_dataset_provider.py | 40 +++++++++---------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index 22c8b0089f..ca38fce820 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -9,7 +9,7 @@ from tqdm import tqdm -from pyrit.datasets.seed_datasets.seed_metadata import SeedDatasetFilter, SeedDatasetLoadingRank, SeedDatasetMetadata +from pyrit.datasets.seed_datasets.seed_metadata import SeedDatasetFilter, SeedDatasetLoadTime, SeedDatasetMetadata from pyrit.models.seeds import SeedDataset logger = logging.getLogger(__name__) @@ -33,7 +33,7 @@ class SeedDatasetProvider(ABC): """ _registry: dict[str, type["SeedDatasetProvider"]] = {} - rank: SeedDatasetLoadingRank = SeedDatasetLoadingRank.UNKNOWN + load_time: SeedDatasetLoadTime = SeedDatasetLoadTime.UNINITIALIZED def __init_subclass__(cls, **kwargs: Any) -> 
None: """ @@ -138,18 +138,21 @@ def get_all_dataset_names(cls, filters: Optional[SeedDatasetFilter] = None) -> l dataset_names.add(provider.dataset_name) except Exception as e: - raise ValueError( - f"Could not get dataset name from {provider_class.__name__}: {e}") from e + raise ValueError(f"Could not get dataset name from {provider_class.__name__}: {e}") from e return sorted(dataset_names) @classmethod def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter) -> bool: """ + Match filters against provider metadata. - Match the filter(s) with the metadata provided by the SeedDatasetProvider subclass. - By default, filters across dimensions (e.g. size, harm categories) are treated as AND - requirements. Filters within a dimension (e.g. SeedDatasetSize.SMALL, - SeedDatasetSize.LARGE) are treated as OR requirements. + Across dimensions (e.g. size + harm_categories): AND — all specified conditions must match. + Within a dimension (e.g. sizes=["small", "large"]): OR — metadata needs to overlap with + at least one value. + + Special tags: + - "all": bypasses all filtering, returns every dataset. + - "default": matches datasets that have tagged themselves as part of the curated set. Args: metadata (SeedDatasetMetadata): The metadata object extracted from the SeedDatasetProvider @@ -157,7 +160,7 @@ def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter filters (SeedDatasetFilter): The filter object provided by the user to get_all_dataset_names. Returns: - bool: Whether or not the filters match. + bool: Whether the filters match. 
""" # Tags if filters.tags and "all" in filters.tags: @@ -189,8 +192,8 @@ def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter if metadata.modalities and filters.modalities and not set(metadata.modalities) & set(filters.modalities): # noqa: SIM103 return False - # Rank - if metadata.rank and filters.ranks and metadata.rank not in filters.ranks: # noqa: SIM103 + # Load Time + if metadata.load_time and filters.load_times and metadata.load_time not in filters.load_times: # noqa: SIM103 return False return True @@ -235,11 +238,9 @@ async def fetch_datasets_async( # Validate dataset names if specified if dataset_names is not None: available_names = cls.get_all_dataset_names() - invalid_names = [ - name for name in dataset_names if name not in available_names] + invalid_names = [name for name in dataset_names if name not in available_names] if invalid_names: - raise ValueError( - f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") + raise ValueError(f"Dataset(s) not found: {invalid_names}. 
Available datasets: {available_names}") async def fetch_single_dataset( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -265,8 +266,7 @@ async def fetch_single_dataset( # Progress tracking total_count = len(cls._registry) - pbar = tqdm(total=total_count, - desc="Loading datasets - this can take a few minutes", unit="dataset") + pbar = tqdm(total=total_count, desc="Loading datasets - this can take a few minutes", unit="dataset") async def fetch_with_semaphore( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -304,12 +304,10 @@ async def fetch_with_semaphore( logger.info(f"Merging multiple sources for {dataset_name}.") existing_dataset = datasets[dataset_name] - combined_seeds = list( - existing_dataset.seeds) + list(dataset.seeds) + combined_seeds = list(existing_dataset.seeds) + list(dataset.seeds) existing_dataset.seeds = combined_seeds else: datasets[dataset_name] = dataset - logger.info( - f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") + logger.info(f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") return list(datasets.values()) From 45079b87d6c36eccb105cb8294311635d0c31625 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Mon, 16 Mar 2026 21:31:38 +0000 Subject: [PATCH 13/29] review --- .../seed_datasets/seed_dataset_provider.py | 29 ++++++++++++------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index ca38fce820..1b4cbc99c1 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -123,18 +123,19 @@ def get_all_dataset_names(cls, filters: Optional[SeedDatasetFilter] = None) -> l # Parser ensures a standard metadata format metadata = provider._parse_metadata() - # "all" bypasses metadata filtering and returns every dataset. 
- if filters and filters.tags and "all" in filters.tags: - dataset_names.add(provider.dataset_name) - continue + if filters: + # "all" bypasses metadata filtering and returns every dataset + if filters.tags and "all" in filters.tags: + dataset_names.add(provider.dataset_name) + continue - if filters and not metadata: - # Datasets without metadata are skipped unless we want "all" - continue + # Datasets without metadata are skipped for all other filters + if not metadata: + continue - # Filters detected but no match -> don't add this dataset - if filters and metadata and not cls._match_filter(metadata=metadata, filters=filters): - continue + # Filters detected but no match -> don't add this dataset + if not cls._match_filter(metadata=metadata, filters=filters): + continue dataset_names.add(provider.dataset_name) except Exception as e: @@ -163,9 +164,17 @@ def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter bool: Whether the filters match. """ # Tags + # "all" matches every discovered dataset, bypassing all other filter axes. if filters.tags and "all" in filters.tags: return True + # "default" matches datasets carrying the "default" curation tag or an initialized load time. + if filters.tags and "default" in filters.tags and metadata.tags and "default" in metadata.tags: + return True + + if filters.tags and "default" in filters.tags and metadata.load_time != SeedDatasetLoadTime.UNINITIALIZED: + return True + # These lines all disable SIM103 because metadata and filters tags can be optional, so # directly checking for membership breaks type checking.
From c04fb8780a93e2450ad7b954096f4d568d1068fc Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Tue, 17 Mar 2026 18:22:35 +0000 Subject: [PATCH 14/29] nits --- pyrit/datasets/seed_datasets/seed_metadata.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index 6810329e21..e5a38e844b 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -5,6 +5,8 @@ from enum import Enum from typing import Literal, Optional +from pyrit.models.literals import PromptDataType + """ Contains metadata objects for datasets (i.e. subclasses of SeedDatasetProvider). @@ -22,12 +24,10 @@ # Documented expected values for string-typed metadata fields. # These are not enforced at runtime but serve as documentation. -SeedDatasetSizeLiteral = Literal["tiny", "small", "medium", "large", "huge"] +SeedDatasetSizeCategory = Literal["tiny", "small", "medium", "large", "huge"] """tiny (<10), small (10-99), medium (100-499), large (500-4999), huge (5000+)""" -SeedDatasetModalityLiteral = Literal["text", "image", "video", "audio"] - -SeedDatasetSourceTypeLiteral = Literal["remote", "local"] +SeedDatasetSourceType = Literal["remote", "local"] class SeedDatasetLoadTime(Enum): @@ -64,8 +64,8 @@ class SeedDatasetFilter: # "default" tag (think of this like a pinned or starred item). 
tags: Optional[set[str]] = None sizes: Optional[list[str]] = None - modalities: Optional[list[str]] = None - source_types: Optional[list[str]] = None + modalities: Optional[list[PromptDataType]] = None + source_types: Optional[list[SeedDatasetSourceType]] = None load_times: Optional[list[SeedDatasetLoadTime]] = None harm_categories: Optional[list[str]] = None @@ -78,8 +78,8 @@ class SeedDatasetMetadata: """ tags: Optional[set[str]] = None - size: Optional[str] = None - modalities: Optional[list[str]] = None - source_type: Optional[str] = None + size: Optional[SeedDatasetSizeCategory] = None + modalities: Optional[list[PromptDataType]] = None + source_type: Optional[SeedDatasetSourceType] = None load_time: SeedDatasetLoadTime = SeedDatasetLoadTime.UNINITIALIZED harm_categories: Optional[list[str]] = None From ba2a6d2397303f3ab80c9de83aa5e84841975a5d Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Tue, 17 Mar 2026 19:04:36 +0000 Subject: [PATCH 15/29] precommit --- doc/code/datasets/1_loading_datasets.ipynb | 174 ++---------------- doc/code/datasets/1_loading_datasets.py | 12 +- .../local/local_dataset_loader.py | 45 +---- .../remote/remote_dataset_loader.py | 2 +- .../seed_datasets/seed_dataset_provider.py | 12 +- pyrit/datasets/seed_datasets/seed_metadata.py | 50 ++++- .../test_seed_dataset_provider_integration.py | 30 +-- .../datasets/test_seed_dataset_metadata.py | 61 ++++++ .../datasets/test_seed_dataset_provider.py | 100 +++++----- .../unit/setup/test_load_default_datasets.py | 2 +- 10 files changed, 211 insertions(+), 277 deletions(-) diff --git a/doc/code/datasets/1_loading_datasets.ipynb b/doc/code/datasets/1_loading_datasets.ipynb index 7bc6aa60a2..b09f5ad1c7 100644 --- a/doc/code/datasets/1_loading_datasets.ipynb +++ b/doc/code/datasets/1_loading_datasets.ipynb @@ -52,74 +52,13 @@ "execution_count": null, "id": "1", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['adv_bench',\n", - " 'aegis_content_safety',\n", - " 
'airt_fairness',\n", - " 'airt_fairness_yes_no',\n", - " 'airt_harassment',\n", - " 'airt_harms',\n", - " 'airt_hate',\n", - " 'airt_illegal',\n", - " 'airt_imminent_crisis',\n", - " 'airt_leakage',\n", - " 'airt_malware',\n", - " 'airt_misinformation',\n", - " 'airt_scams',\n", - " 'airt_sexual',\n", - " 'airt_violence',\n", - " 'aya_redteaming',\n", - " 'babelscape_alert',\n", - " 'beaver_tails',\n", - " 'cbt_bench',\n", - " 'ccp_sensitive_prompts',\n", - " 'dark_bench',\n", - " 'equitymedqa',\n", - " 'forbidden_questions',\n", - " 'garak_access_shell_commands',\n", - " 'garak_slur_terms_en',\n", - " 'garak_web_html_js',\n", - " 'harmbench',\n", - " 'harmbench_multimodal',\n", - " 'harmful_qa',\n", - " 'jbb_behaviors',\n", - " 'librai_do_not_answer',\n", - " 'llm_lat_harmful',\n", - " 'medsafetybench',\n", - " 'mental_health_crisis_multiturn_example',\n", - " 'ml_vlsu',\n", - " 'mlcommons_ailuminate',\n", - " 'multilingual_vulnerability',\n", - " 'or_bench_80k',\n", - " 'or_bench_hard',\n", - " 'or_bench_toxic',\n", - " 'pku_safe_rlhf',\n", - " 'promptintel',\n", - " 'psfuzz_steal_system_prompt',\n", - " 'pyrit_example_dataset',\n", - " 'red_team_social_bias',\n", - " 'salad_bench',\n", - " 'simple_safety_tests',\n", - " 'sorry_bench',\n", - " 'sosbench',\n", - " 'tdc23_redteaming',\n", - " 'toxic_chat',\n", - " 'transphobia_awareness',\n", - " 'xstest']" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "from pyrit.datasets import SeedDatasetProvider\n", + "from pyrit.memory import CentralMemory\n", + "from pyrit.setup.initialization import IN_MEMORY, initialize_pyrit_async\n", "\n", - "SeedDatasetProvider.get_all_dataset_names()" + "await SeedDatasetProvider.get_all_dataset_names_async()" ] }, { @@ -137,64 +76,10 @@ "execution_count": null, "id": "3", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\r", - "Loading datasets - this can take 
a few minutes: 0%| | 0/58 [00:00 SeedDataset: logger.error(f"Failed to load local dataset from {self.file_path}: {e}") raise - def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: + async def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: """ Extract metadata from a local YAML file and coerce raw values into typed schema fields. @@ -105,49 +104,9 @@ def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: if not raw: return None - coerced = self._coerce_metadata_values(raw_metadata=raw) + coerced = SeedDatasetMetadata._coerce_metadata_values(raw_metadata=raw) return SeedDatasetMetadata(**coerced) - @staticmethod - def _coerce_metadata_values(*, raw_metadata: dict[str, Any]) -> dict[str, Any]: - """ - Convert YAML primitive values into the types expected by SeedDatasetMetadata. - - Applies .lower().strip() normalization to string values for size, modalities, - source_type, and harm_categories to prevent case/whitespace mismatches. - - Args: - raw_metadata (dict[str, Any]): Dictionary of field names to raw YAML-parsed values. - - Returns: - dict[str, Any]: Dictionary with values coerced to the correct types. 
- """ - coerced: dict[str, Any] = {} - for key, value in raw_metadata.items(): - if key == "tags" and isinstance(value, list): - coerced[key] = {v.strip().lower() if isinstance(v, str) else v for v in value} - elif key == "tags" and isinstance(value, str): - coerced[key] = {value.strip().lower()} - elif key == "size" and isinstance(value, str) or key == "source_type" and isinstance(value, str): - coerced[key] = value.strip().lower() - elif key == "load_time" and isinstance(value, str): - coerced[key] = SeedDatasetLoadTime(value.strip().lower()) - elif key == "modalities" and isinstance(value, list): - coerced[key] = [v.strip().lower() if isinstance(v, str) else v for v in value] - elif key == "modalities" and isinstance(value, str): - coerced[key] = [value.strip().lower()] - elif key == "harm_categories" and isinstance(value, list): - coerced[key] = [v.strip().lower() if isinstance(v, str) else v for v in value] - elif key == "harm_categories" and isinstance(value, str): - coerced[key] = [value.strip().lower()] - else: - # Unexpected type for a metadata field — skip it with a warning - # rather than passing garbage into SeedDatasetMetadata. - logger.warning( - f"Skipping metadata field '{key}' with unexpected type {type(value).__name__} (value: {value!r})" - ) - return coerced - def _register_local_datasets() -> None: """ diff --git a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py index 9587a743f0..b1110db6a4 100644 --- a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py @@ -288,7 +288,7 @@ def _load_dataset_sync() -> Any: logger.error(f"Failed to load HuggingFace dataset {dataset_name}: {e}") raise - def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: + async def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: """ Extract metadata from class attributes and format into SeedDatasetMetadata schema. 
diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index 1b4cbc99c1..ceb57edf06 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -73,7 +73,7 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: Exception: If the dataset cannot be fetched or processed. """ - def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: + async def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: """ Parse provider-specific metadata into the shared schema. @@ -97,7 +97,7 @@ def get_all_providers(cls) -> dict[str, type["SeedDatasetProvider"]]: return cls._registry.copy() @classmethod - def get_all_dataset_names(cls, filters: Optional[SeedDatasetFilter] = None) -> list[str]: + async def get_all_dataset_names_async(cls, filters: Optional[SeedDatasetFilter] = None) -> list[str]: """ Get the names of all registered datasets. @@ -111,7 +111,7 @@ def get_all_dataset_names(cls, filters: Optional[SeedDatasetFilter] = None) -> l ValueError: If no providers are registered or if providers cannot be instantiated. Example: - >>> names = SeedDatasetProvider.get_all_dataset_names() + >>> names = await SeedDatasetProvider.get_all_dataset_names_async() >>> print(f"Available datasets: {', '.join(names)}") """ dataset_names = set() @@ -121,7 +121,7 @@ def get_all_dataset_names(cls, filters: Optional[SeedDatasetFilter] = None) -> l provider = provider_class() # Parser ensures a standard metadata format - metadata = provider._parse_metadata() + metadata = await provider._parse_metadata() if filters: # "all" bypasses metadata filtering and returns every dataset @@ -158,7 +158,7 @@ def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter Args: metadata (SeedDatasetMetadata): The metadata object extracted from the SeedDatasetProvider subclass. 
- filters (SeedDatasetFilter): The filter object provided by the user to get_all_dataset_names. + filters (SeedDatasetFilter): The filter object provided by the user to get_all_dataset_names_async. Returns: bool: Whether the filters match. @@ -246,7 +246,7 @@ async def fetch_datasets_async( """ # Validate dataset names if specified if dataset_names is not None: - available_names = cls.get_all_dataset_names() + available_names = await cls.get_all_dataset_names_async() invalid_names = [name for name in dataset_names if name not in available_names] if invalid_names: raise ValueError(f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index e5a38e844b..30192fb846 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -1,12 +1,16 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import logging from dataclasses import dataclass from enum import Enum -from typing import Literal, Optional +from typing import Any, Literal, Optional from pyrit.models.literals import PromptDataType +logger = logging.getLogger(__name__) + + """ Contains metadata objects for datasets (i.e. subclasses of SeedDatasetProvider). @@ -16,7 +20,7 @@ - Local prompt files that store metadata at the top level SeedDatasetFilter is the user-facing filter schema consumed by -SeedDatasetProvider.get_all_dataset_names(). +SeedDatasetProvider.get_all_dataset_names_async(). Size and modality are string literals rather than enums for usability — callers don't need to import extra types to construct a filter. @@ -46,7 +50,7 @@ class SeedDatasetLoadTime(Enum): @dataclass class SeedDatasetFilter: """ - Filter object for datasets. Passed to `get_all_dataset_names` in + Filter object for datasets. Passed to `get_all_dataset_names_async` in SeedDatasetProvider. Most fields are optional. 
None means "don't filter on this axis." @@ -83,3 +87,43 @@ class SeedDatasetMetadata: source_type: Optional[SeedDatasetSourceType] = None load_time: SeedDatasetLoadTime = SeedDatasetLoadTime.UNINITIALIZED harm_categories: Optional[list[str]] = None + + @staticmethod + def _coerce_metadata_values(*, raw_metadata: dict[str, Any]) -> dict[str, Any]: + """ + Convert YAML primitive values into the types expected by SeedDatasetMetadata. + + Applies .lower().strip() normalization to string values for size, modalities, + source_type, and harm_categories to prevent case/whitespace mismatches. + + Args: + raw_metadata (dict[str, Any]): Dictionary of field names to raw YAML-parsed values. + + Returns: + dict[str, Any]: Dictionary with values coerced to the correct types. + """ + coerced: dict[str, Any] = {} + for key, value in raw_metadata.items(): + if key == "tags" and isinstance(value, list): + coerced[key] = {v.strip().lower() if isinstance(v, str) else v for v in value} + elif key == "tags" and isinstance(value, str): + coerced[key] = {value.strip().lower()} + elif key == "size" and isinstance(value, str) or key == "source_type" and isinstance(value, str): + coerced[key] = value.strip().lower() + elif key == "load_time" and isinstance(value, str): + coerced[key] = SeedDatasetLoadTime(value.strip().lower()) + elif key == "modalities" and isinstance(value, list): + coerced[key] = [v.strip().lower() if isinstance(v, str) else v for v in value] + elif key == "modalities" and isinstance(value, str): + coerced[key] = [value.strip().lower()] + elif key == "harm_categories" and isinstance(value, list): + coerced[key] = [v.strip().lower() if isinstance(v, str) else v for v in value] + elif key == "harm_categories" and isinstance(value, str): + coerced[key] = [value.strip().lower()] + else: + # Unexpected type for a metadata field — skip it with a warning + # rather than passing garbage into SeedDatasetMetadata. 
+ logger.warning( + f"Skipping metadata field '{key}' with unexpected type {type(value).__name__} (value: {value!r})" + ) + return coerced diff --git a/tests/integration/datasets/test_seed_dataset_provider_integration.py b/tests/integration/datasets/test_seed_dataset_provider_integration.py index fdd58ec97f..438f213444 100644 --- a/tests/integration/datasets/test_seed_dataset_provider_integration.py +++ b/tests/integration/datasets/test_seed_dataset_provider_integration.py @@ -130,7 +130,7 @@ def test_filter_matches_correct_remote_provider(self): {"Large": large_cls, "Small": small_cls}, clear=True, ): - names = SeedDatasetProvider.get_all_dataset_names( + names = SeedDatasetProvider.get_all_dataset_names_async( filters=SeedDatasetFilter(sizes=["large"]), ) assert names == ["large_ds"] @@ -157,7 +157,7 @@ def test_filter_all_tag_returns_everything(self): {"A": cls1, "B": cls2}, clear=True, ): - names = SeedDatasetProvider.get_all_dataset_names( + names = SeedDatasetProvider.get_all_dataset_names_async( filters=SeedDatasetFilter(tags={"all"}), ) assert sorted(names) == ["ds_a", "ds_b"] @@ -184,7 +184,7 @@ def test_multi_axis_filter(self): {"TL": cls1, "IL": cls2}, clear=True, ): - names = SeedDatasetProvider.get_all_dataset_names( + names = SeedDatasetProvider.get_all_dataset_names_async( filters=SeedDatasetFilter( sizes=["large"], modalities=["text"], @@ -252,7 +252,7 @@ def test_local_filter_by_size(self, tmp_path): {"Large": large_cls, "Small": small_cls}, clear=True, ): - names = SeedDatasetProvider.get_all_dataset_names( + names = SeedDatasetProvider.get_all_dataset_names_async( filters=SeedDatasetFilter(sizes=["large"]), ) # dataset_name falls back to file stem when SeedDataset.from_yaml_file @@ -284,12 +284,12 @@ def test_local_filter_by_tags(self, tmp_path): ): # dataset_name falls back to file stem ("tagged") when # SeedDataset.from_yaml_file rejects extra keys like "tags" - matched = SeedDatasetProvider.get_all_dataset_names( + matched = 
SeedDatasetProvider.get_all_dataset_names_async( filters=SeedDatasetFilter(tags={"safety"}), ) assert matched == ["tagged"] - not_matched = SeedDatasetProvider.get_all_dataset_names( + not_matched = SeedDatasetProvider.get_all_dataset_names_async( filters=SeedDatasetFilter(tags={"unrelated"}), ) assert not_matched == [] @@ -313,11 +313,11 @@ def test_local_no_metadata_skipped(self, tmp_path): clear=True, ): # Without filters, the dataset is included - all_names = SeedDatasetProvider.get_all_dataset_names() + all_names = SeedDatasetProvider.get_all_dataset_names_async() assert "bare_local" in all_names # With filters, it's skipped (no metadata to match against) - filtered = SeedDatasetProvider.get_all_dataset_names( + filtered = SeedDatasetProvider.get_all_dataset_names_async( filters=SeedDatasetFilter(tags={"safety"}), ) assert filtered == [] @@ -392,7 +392,7 @@ async def test_user_discovers_and_fetches_filtered_dataset(self, tmp_path): clear=True, ): # --- Step 1: User filters by harm_categories --- - names = SeedDatasetProvider.get_all_dataset_names( + names = SeedDatasetProvider.get_all_dataset_names_async( filters=SeedDatasetFilter(harm_categories=["cybercrime"]), ) assert len(names) == 1 @@ -447,7 +447,7 @@ async def test_user_fetches_unfiltered(self, tmp_path): {"One": cls1, "Two": cls2}, clear=True, ): - names = SeedDatasetProvider.get_all_dataset_names() + names = SeedDatasetProvider.get_all_dataset_names_async() assert len(names) == 2 datasets = await SeedDatasetProvider.fetch_datasets_async() @@ -502,13 +502,13 @@ def test_all_tag_includes_datasets_without_metadata(self, tmp_path): clear=True, ): # Normal filter skips it - filtered = SeedDatasetProvider.get_all_dataset_names( + filtered = SeedDatasetProvider.get_all_dataset_names_async( filters=SeedDatasetFilter(tags={"safety"}), ) assert filtered == [] # 'all' includes it - all_names = SeedDatasetProvider.get_all_dataset_names( + all_names = SeedDatasetProvider.get_all_dataset_names_async( 
filters=SeedDatasetFilter(tags={"all"}), ) assert "bare_dataset" in all_names @@ -538,13 +538,13 @@ def test_all_tag_ignores_other_filter_axes(self, tmp_path): clear=True, ): # Size filter alone would exclude it - size_filtered = SeedDatasetProvider.get_all_dataset_names( + size_filtered = SeedDatasetProvider.get_all_dataset_names_async( filters=SeedDatasetFilter(sizes=["large"]), ) assert size_filtered == [] # 'all' tag overrides the size filter - all_names = SeedDatasetProvider.get_all_dataset_names( + all_names = SeedDatasetProvider.get_all_dataset_names_async( filters=SeedDatasetFilter(tags={"all"}, sizes=["large"]), ) assert "small" in all_names @@ -585,7 +585,7 @@ def test_all_tag_with_mixed_metadata_and_bare_datasets(self, tmp_path): {"Rich": rich_cls, "Bare": bare_cls}, clear=True, ): - all_names = SeedDatasetProvider.get_all_dataset_names( + all_names = SeedDatasetProvider.get_all_dataset_names_async( filters=SeedDatasetFilter(tags={"all"}), ) assert len(all_names) == 2 diff --git a/tests/unit/datasets/test_seed_dataset_metadata.py b/tests/unit/datasets/test_seed_dataset_metadata.py index e9096a470b..f3635b7570 100644 --- a/tests/unit/datasets/test_seed_dataset_metadata.py +++ b/tests/unit/datasets/test_seed_dataset_metadata.py @@ -128,6 +128,67 @@ def test_harm_categories_value(self): assert "cybercrime" in metadata.harm_categories +class TestMetadataCoercion: + """ + Test that _coerce_metadata_values correctly normalizes raw YAML + values into the types expected by SeedDatasetMetadata. 
+ """ + + def test_tags_list_coerced_to_set(self): + result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"tags": ["safety", "default"]}) + assert result["tags"] == {"safety", "default"} + assert isinstance(result["tags"], set) + + def test_tags_string_coerced_to_set(self): + result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"tags": "safety"}) + assert result["tags"] == {"safety"} + assert isinstance(result["tags"], set) + + def test_tags_normalized_lower_strip(self): + result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"tags": [" Safety ", " DEFAULT"]}) + assert result["tags"] == {"safety", "default"} + + def test_size_coerced_to_lowercase_string(self): + result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"size": " Large "}) + assert result["size"] == "large" + + def test_source_type_coerced_to_lowercase_string(self): + result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"source_type": " Remote "}) + assert result["source_type"] == "remote" + + def test_load_time_coerced_to_enum(self): + result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"load_time": "fast"}) + assert result["load_time"] == SeedDatasetLoadTime.FAST + + def test_load_time_normalized_strip_lower(self): + result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"load_time": " Slow "}) + assert result["load_time"] == SeedDatasetLoadTime.SLOW + + def test_modalities_list_coerced_lowercase(self): + result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"modalities": ["Text", " IMAGE "]}) + assert result["modalities"] == ["text", "image"] + + def test_modalities_string_coerced_to_list(self): + result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"modalities": "text"}) + assert result["modalities"] == ["text"] + + def test_harm_categories_list_coerced_lowercase(self): + result = SeedDatasetMetadata._coerce_metadata_values( + raw_metadata={"harm_categories": ["Violence", " 
Cybercrime "]} + ) + assert result["harm_categories"] == ["violence", "cybercrime"] + + def test_harm_categories_string_coerced_to_list(self): + result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"harm_categories": "violence"}) + assert result["harm_categories"] == ["violence"] + + def test_unknown_type_skipped_with_warning(self, caplog): + """Unexpected types are dropped and logged, not passed through.""" + result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"tags": 12345}) + assert "tags" not in result + assert "Skipping metadata field" in caplog.text + + class TestFilterProperties: """ Test that the filter fields populate correctly. diff --git a/tests/unit/datasets/test_seed_dataset_provider.py b/tests/unit/datasets/test_seed_dataset_provider.py index bf7a54a3ca..1306177db7 100644 --- a/tests/unit/datasets/test_seed_dataset_provider.py +++ b/tests/unit/datasets/test_seed_dataset_provider.py @@ -70,29 +70,33 @@ async def fetch_dataset(self): assert "DynamicTestProvider" in providers assert providers["DynamicTestProvider"] == DynamicTestProvider - def test_get_all_dataset_names(self): + @pytest.mark.asyncio + async def test_get_all_dataset_names(self): """Test getting all dataset names.""" # Mock the registry to ensure deterministic results - mock_provider_cls = MagicMock() + mock_provider_cls = MagicMock(__name__="TestProvider") mock_provider_instance = mock_provider_cls.return_value mock_provider_instance.dataset_name = "test_dataset" + mock_provider_instance._parse_metadata = AsyncMock(return_value=None) with patch.dict(SeedDatasetProvider._registry, {"TestProvider": mock_provider_cls}, clear=True): - names = SeedDatasetProvider.get_all_dataset_names() + names = await SeedDatasetProvider.get_all_dataset_names_async() assert names == ["test_dataset"] @pytest.mark.asyncio async def test_fetch_datasets_async(self): """Test fetching all datasets.""" # Mock providers - mock_provider1 = MagicMock() + mock_provider1 = 
MagicMock(__name__="P1") mock_provider1.return_value.dataset_name = "d1" + mock_provider1.return_value._parse_metadata = AsyncMock(return_value=None) mock_provider1.return_value.fetch_dataset = AsyncMock( return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) - mock_provider2 = MagicMock() + mock_provider2 = MagicMock(__name__="P2") mock_provider2.return_value.dataset_name = "d2" + mock_provider2.return_value._parse_metadata = AsyncMock(return_value=None) mock_provider2.return_value.fetch_dataset = AsyncMock( return_value=SeedDataset(seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") ) @@ -104,14 +108,16 @@ async def test_fetch_datasets_async(self): @pytest.mark.asyncio async def test_fetch_datasets_async_with_filter(self): """Test fetching datasets with filter.""" - mock_provider1 = MagicMock() + mock_provider1 = MagicMock(__name__="P1") mock_provider1.return_value.dataset_name = "d1" + mock_provider1.return_value._parse_metadata = AsyncMock(return_value=None) mock_provider1.return_value.fetch_dataset = AsyncMock( return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) - mock_provider2 = MagicMock() + mock_provider2 = MagicMock(__name__="P2") mock_provider2.return_value.dataset_name = "d2" + mock_provider2.return_value._parse_metadata = AsyncMock(return_value=None) mock_provider2.return_value.fetch_dataset = AsyncMock(side_effect=Exception("Should not be called")) with patch.dict(SeedDatasetProvider._registry, {"P1": mock_provider1, "P2": mock_provider2}, clear=True): @@ -122,14 +128,16 @@ async def test_fetch_datasets_async_with_filter(self): @pytest.mark.asyncio async def test_fetch_datasets_async_invalid_dataset_name(self): """Test that fetch_datasets_async raises ValueError for invalid dataset names.""" - mock_provider1 = MagicMock() + mock_provider1 = MagicMock(__name__="P1") mock_provider1.return_value.dataset_name = "d1" + 
mock_provider1.return_value._parse_metadata = AsyncMock(return_value=None) mock_provider1.return_value.fetch_dataset = AsyncMock( return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) - mock_provider2 = MagicMock() + mock_provider2 = MagicMock(__name__="P2") mock_provider2.return_value.dataset_name = "d2" + mock_provider2.return_value._parse_metadata = AsyncMock(return_value=None) mock_provider2.return_value.fetch_dataset = AsyncMock( return_value=SeedDataset(seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") ) @@ -251,10 +259,11 @@ async def test_fetch_dataset_with_custom_config(self, mock_darkbench_data): class TestMetadataParsingRemote: """Test metadata parsing and filter matching for remote providers.""" - def test_parse_metadata_from_class_attrs(self): + @pytest.mark.asyncio + async def test_parse_metadata_from_class_attrs(self): """Test _parse_metadata correctly extracts class-level metadata attributes.""" loader = _HarmBenchDataset() - metadata = loader._parse_metadata() + metadata = await loader._parse_metadata() assert metadata is not None assert metadata.tags == {"default", "safety"} assert metadata.size == "large" @@ -343,15 +352,16 @@ def test_empty_filter(self): filters = SeedDatasetFilter() assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) - def test_no_metadata(self): + @pytest.mark.asyncio + async def test_no_metadata(self): """Provider without metadata is skipped when filters are applied.""" - mock_provider_cls = MagicMock() + mock_provider_cls = MagicMock(__name__="NoProv") mock_provider_instance = mock_provider_cls.return_value mock_provider_instance.dataset_name = "no_metadata" - mock_provider_instance._parse_metadata.return_value = None + mock_provider_instance._parse_metadata = AsyncMock(return_value=None) with patch.dict(SeedDatasetProvider._registry, {"NoProv": mock_provider_cls}, clear=True): - names = 
SeedDatasetProvider.get_all_dataset_names(filters=SeedDatasetFilter(tags={"safety"})) + names = await SeedDatasetProvider.get_all_dataset_names_async(filters=SeedDatasetFilter(tags={"safety"})) assert names == [] @@ -371,7 +381,8 @@ def _write_yaml(self, tmp_path, name, content): path.write_text(content) return path - def test_parse_metadata_extracts_fields(self, tmp_path): + @pytest.mark.asyncio + async def test_parse_metadata_extracts_fields(self, tmp_path): """Test _parse_metadata correctly extracts metadata fields from YAML.""" yaml_path = self._write_yaml( tmp_path, @@ -386,11 +397,12 @@ def test_parse_metadata_extracts_fields(self, tmp_path): """), ) loader = self._make_loader(yaml_path) - metadata = loader._parse_metadata() + metadata = await loader._parse_metadata() assert metadata is not None assert metadata.harm_categories == ["violence"] - def test_all_tag(self, tmp_path): + @pytest.mark.asyncio + async def test_all_tag(self, tmp_path): """Filter with tags={'all'} matches regardless of metadata types.""" yaml_path = self._write_yaml( tmp_path, @@ -407,12 +419,13 @@ def test_all_tag(self, tmp_path): """), ) loader = self._make_loader(yaml_path) - metadata = loader._parse_metadata() + metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(tags={"all"}) assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) - def test_tags(self, tmp_path): + @pytest.mark.asyncio + async def test_tags(self, tmp_path): """YAML produces tags as list; set intersection in _match_filter expects a set.""" yaml_path = self._write_yaml( tmp_path, @@ -428,12 +441,13 @@ def test_tags(self, tmp_path): """), ) loader = self._make_loader(yaml_path) - metadata = loader._parse_metadata() + metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(tags={"safety"}) assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) - def test_sizes(self, tmp_path): + 
@pytest.mark.asyncio + async def test_sizes(self, tmp_path): """YAML produces size as string; _match_filter compares against enum values.""" yaml_path = self._write_yaml( tmp_path, @@ -447,12 +461,13 @@ def test_sizes(self, tmp_path): """), ) loader = self._make_loader(yaml_path) - metadata = loader._parse_metadata() + metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(sizes=["large"]) assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) - def test_modalities(self, tmp_path): + @pytest.mark.asyncio + async def test_modalities(self, tmp_path): """YAML produces modalities as list of strings; _match_filter uses enum values.""" yaml_path = self._write_yaml( tmp_path, @@ -467,12 +482,13 @@ def test_modalities(self, tmp_path): """), ) loader = self._make_loader(yaml_path) - metadata = loader._parse_metadata() + metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(modalities=["text"]) assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) - def test_sources(self, tmp_path): + @pytest.mark.asyncio + async def test_sources(self, tmp_path): """YAML produces source_type as string; _match_filter compares against enum values.""" yaml_path = self._write_yaml( tmp_path, @@ -486,12 +502,13 @@ def test_sources(self, tmp_path): """), ) loader = self._make_loader(yaml_path) - metadata = loader._parse_metadata() + metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(source_types=["remote"]) assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) - def test_ranks(self, tmp_path): + @pytest.mark.asyncio + async def test_ranks(self, tmp_path): """YAML produces load_time as string; _match_filter compares against enum values.""" yaml_path = self._write_yaml( tmp_path, @@ -505,12 +522,13 @@ def test_ranks(self, tmp_path): """), ) loader = self._make_loader(yaml_path) - metadata = 
loader._parse_metadata() + metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(load_times=[SeedDatasetLoadTime.FAST]) assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) - def test_harm_categories(self, tmp_path): + @pytest.mark.asyncio + async def test_harm_categories(self, tmp_path): """Both YAML and filter use list[str], so intersection works correctly.""" yaml_path = self._write_yaml( tmp_path, @@ -526,12 +544,13 @@ def test_harm_categories(self, tmp_path): """), ) loader = self._make_loader(yaml_path) - metadata = loader._parse_metadata() + metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(harm_categories=["violence"]) assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) - def test_empty_filter(self, tmp_path): + @pytest.mark.asyncio + async def test_empty_filter(self, tmp_path): """Empty filter matches everything.""" yaml_path = self._write_yaml( tmp_path, @@ -546,12 +565,13 @@ def test_empty_filter(self, tmp_path): """), ) loader = self._make_loader(yaml_path) - metadata = loader._parse_metadata() + metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter() assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) - def test_no_metadata(self, tmp_path): + @pytest.mark.asyncio + async def test_no_metadata(self, tmp_path): """YAML without any metadata fields returns None from _parse_metadata.""" yaml_path = self._write_yaml( tmp_path, @@ -564,7 +584,7 @@ def test_no_metadata(self, tmp_path): """), ) loader = self._make_loader(yaml_path) - metadata = loader._parse_metadata() + metadata = await loader._parse_metadata() assert metadata is None @@ -585,16 +605,14 @@ def _get_local_prompt_files() -> list: return sorted(local_dir.glob("**/*.prompt")) + sorted(local_dir.glob("**/*.yaml")) @pytest.mark.parametrize("prompt_file", _get_local_prompt_files.__func__(), 
ids=lambda p: p.stem) - def test_parse_metadata_does_not_crash(self, prompt_file): + @pytest.mark.asyncio + async def test_parse_metadata_does_not_crash(self, prompt_file): """_parse_metadata must not raise on any real local dataset file.""" loader = _LocalDatasetLoader.__new__(_LocalDatasetLoader) loader.file_path = prompt_file loader._dataset_name = prompt_file.stem - # This must not raise — if a YAML key collides with a metadata field - # name but holds an incompatible value, the coercion layer should - # either handle it or skip it gracefully. - metadata = loader._parse_metadata() + metadata = await loader._parse_metadata() # metadata can be None (no matching fields) or a valid SeedDatasetMetadata if metadata is not None: assert isinstance(metadata, SeedDatasetMetadata) @@ -624,7 +642,7 @@ def test_no_yaml_key_shadows_metadata_field_with_wrong_type(self, prompt_file): loader._dataset_name = prompt_file.stem raw = {k: data[k] for k in overlapping_keys} - coerced = _LocalDatasetLoader._coerce_metadata_values(raw_metadata=raw) + coerced = SeedDatasetMetadata._coerce_metadata_values(raw_metadata=raw) metadata = SeedDatasetMetadata(**coerced) # Verify coerced types match expectations diff --git a/tests/unit/setup/test_load_default_datasets.py b/tests/unit/setup/test_load_default_datasets.py index 54655f86e5..d701a64450 100644 --- a/tests/unit/setup/test_load_default_datasets.py +++ b/tests/unit/setup/test_load_default_datasets.py @@ -187,7 +187,7 @@ async def test_all_required_datasets_available_in_seed_provider(self) -> None: exists in the SeedDatasetProvider registry. 
""" # Get all available dataset names from SeedDatasetProvider - available_datasets = set(SeedDatasetProvider.get_all_dataset_names()) + available_datasets = set(await SeedDatasetProvider.get_all_dataset_names_async()) # Get ScenarioRegistry to discover all scenarios registry = ScenarioRegistry.get_registry_singleton() From 29726ed5e1568de65ccb197a6d5176fb6e057364 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Tue, 17 Mar 2026 20:08:35 +0000 Subject: [PATCH 16/29] . --- .../seed_datasets/seed_dataset_provider.py | 70 +++--- pyrit/datasets/seed_datasets/seed_metadata.py | 52 ++++- .../datasets/test_seed_dataset_provider.py | 206 ++++++++++++++++++ 3 files changed, 293 insertions(+), 35 deletions(-) diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index ceb57edf06..4cfd4cd3ad 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -147,9 +147,16 @@ def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter """ Match filters against provider metadata. - Across dimensions (e.g. size + harm_categories): AND — all specified conditions must match. - Within a dimension (e.g. sizes=["small", "large"]): OR — metadata needs to overlap with - at least one value. + When strict_match is False (default): + - Across dimensions (e.g. size + harm_categories): AND — all specified conditions must match. + - Within a dimension (e.g. sizes=["small", "large"]): OR — metadata needs to overlap with + at least one value. + + When strict_match is True: + - Across dimensions: AND (same as default). + - Within set-like dimensions (tags, harm_categories, modalities): AND — metadata must + contain ALL requested values, not just one. + - Within scalar dimensions (size, source_type, load_time): unchanged (membership check). Special tags: - "all": bypasses all filtering, returns every dataset. 
@@ -164,42 +171,53 @@ def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter bool: Whether the filters match. """ # Tags - # "all" defaults to all discovered datasets. + # "all" always bypasses all filtering, regardless of strict_match. if filters.tags and "all" in filters.tags: return True - # "default" checks for an initialized loading rank or the "default" curation tag. - if filters.tags and "default" in filters.tags and metadata.tags and "default" in metadata.tags: - return True - - if filters.tags and "default" in filters.tags and metadata.load_time != SeedDatasetLoadTime.UNINITIALIZED: - return True - - # These lines all disable SIM103 because metadata and filters tags can be optional, so - # directly checking for membership breaks type checking. - - if metadata.tags and filters.tags and not (filters.tags & metadata.tags): # noqa: SIM103 - return False + # "default" tag handling depends on strict_match: + # - Without strict_match: "default" alone is enough to match if the dataset + # has a "default" tag or an initialized load_time. + # - With strict_match: "default" is treated like any other tag — ALL + # requested tags (including "default") must be present in the dataset. 
+ if not filters.strict_match: + if filters.tags and "default" in filters.tags and metadata.tags and "default" in metadata.tags: + return True + + if filters.tags and "default" in filters.tags and metadata.load_time != SeedDatasetLoadTime.UNINITIALIZED: + return True + + if metadata.tags and filters.tags: + if filters.strict_match: + # All requested tags must be present in the dataset + if not filters.tags <= metadata.tags: + return False + elif not (filters.tags & metadata.tags): + return False # Size - if metadata.size and filters.sizes and metadata.size not in filters.sizes: # noqa: SIM103 + if metadata.size and filters.sizes and metadata.size not in filters.sizes: return False # Harm Categories - if ( - metadata.harm_categories - and filters.harm_categories - and not set(metadata.harm_categories) & set(filters.harm_categories) - ): # noqa: SIM103 - return False + if metadata.harm_categories and filters.harm_categories: + if filters.strict_match: + if not set(filters.harm_categories) <= set(metadata.harm_categories): + return False + elif not set(metadata.harm_categories) & set(filters.harm_categories): + return False # Source Type - if metadata.source_type and filters.source_types and metadata.source_type not in filters.source_types: # noqa: SIM103 + if metadata.source_type and filters.source_types and metadata.source_type not in filters.source_types: return False # Modalities - if metadata.modalities and filters.modalities and not set(metadata.modalities) & set(filters.modalities): # noqa: SIM103 - return False + if metadata.modalities and filters.modalities: + if filters.strict_match: + if not set(filters.modalities) <= set(metadata.modalities): + return False + elif not set(metadata.modalities) & set(filters.modalities): + return False # Load Time if metadata.load_time and filters.load_times and metadata.load_time not in filters.load_times: # noqa: SIM103 diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py 
b/pyrit/datasets/seed_datasets/seed_metadata.py index 30192fb846..2af12fe01a 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -55,17 +55,22 @@ class SeedDatasetFilter: Most fields are optional. None means "don't filter on this axis." - Exception for load_times, which defaults to UNINITIALIZED. - - By default, filtering is OR-wise across filter categories and OR-wise within - filter categories. + By default, filtering is AND across categories (all must match) and + OR within categories (any overlap is sufficient). + + Setting strict_match=True changes within-category behavior to AND + for set-like fields (tags, harm_categories, modalities). + + Special tag behavior: + - "all": A magic bypass that returns every discoverable dataset. When "all" + is present, ALL other filter fields and strict_match are ignored. This + operates at the get_all_dataset_names_async level — _match_filter is not + even called for datasets without metadata. + - "default": Matches datasets that have "default" in their tags or have an + initialized load_time. With strict_match=True, "default" loses its + special shortcut behavior and is treated as a normal tag. """ - # Tags are a top-level set of labels that assist with filtering. - # The tag "all" will return every discoverable dataset. - # The tag "default" will return every dataset with an initialized - # load_time (i.e., SeedDatasetLoadTime != UNINITIALIZED.) or an explicit - # "default" tag (think of this like a pinned or starred item). tags: Optional[set[str]] = None sizes: Optional[list[str]] = None modalities: Optional[list[PromptDataType]] = None @@ -73,6 +78,35 @@ class SeedDatasetFilter: load_times: Optional[list[SeedDatasetLoadTime]] = None harm_categories: Optional[list[str]] = None + # Setting this to True forces AND-wise filtering within set-like categories. + # "all" tag still bypasses everything regardless of this flag. 
+ strict_match: bool = False + + def __post_init__(self) -> None: + """Validate filter configuration.""" + if self.tags and "all" in self.tags and len(self.tags) > 1: + logger.warning( + "Filter has 'all' combined with other tags %s. " + "'all' bypasses all filtering — other tags will be ignored.", + self.tags - {"all"}, + ) + if self.tags and "all" in self.tags and self.strict_match: + logger.warning( + "Filter has 'all' with strict_match=True. 'all' bypasses all filtering — strict_match has no effect." + ) + if ( + self.tags + and "all" in self.tags + and any( + f is not None + for f in [self.sizes, self.modalities, self.source_types, self.load_times, self.harm_categories] + ) + ): + logger.warning( + "Filter has 'all' combined with other filter fields. " + "'all' bypasses all filtering — other fields will be ignored." + ) + @dataclass(frozen=True) class SeedDatasetMetadata: diff --git a/tests/unit/datasets/test_seed_dataset_provider.py b/tests/unit/datasets/test_seed_dataset_provider.py index 1306177db7..05ea5f86ca 100644 --- a/tests/unit/datasets/test_seed_dataset_provider.py +++ b/tests/unit/datasets/test_seed_dataset_provider.py @@ -365,6 +365,212 @@ async def test_no_metadata(self): assert names == [] +class TestStrictMatchFiltering: + """Test strict_match behavior in SeedDatasetFilter.""" + + def test_strict_tags_all_present_matches(self): + """strict_match requires ALL filter tags to be present in metadata.""" + metadata = SeedDatasetMetadata(tags={"safety", "default", "curated"}) + filters = SeedDatasetFilter(tags={"safety", "default"}, strict_match=True) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_strict_tags_partial_overlap_fails(self): + """strict_match rejects if metadata is missing any requested tag.""" + metadata = SeedDatasetMetadata(tags={"safety"}) + filters = SeedDatasetFilter(tags={"safety", "default"}, strict_match=True) + assert not SeedDatasetProvider._match_filter(metadata=metadata, 
filters=filters) + + def test_nonstrict_tags_partial_overlap_passes(self): + """Without strict_match, any tag overlap is sufficient.""" + metadata = SeedDatasetMetadata(tags={"safety"}) + filters = SeedDatasetFilter(tags={"safety", "default"}, strict_match=False) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_strict_harm_categories_all_present_matches(self): + """strict_match requires ALL filter harm_categories present in metadata.""" + metadata = SeedDatasetMetadata(harm_categories=["violence", "cybercrime", "illegal"]) + filters = SeedDatasetFilter(harm_categories=["violence", "cybercrime"], strict_match=True) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_strict_harm_categories_partial_fails(self): + """strict_match rejects if metadata is missing any requested harm category.""" + metadata = SeedDatasetMetadata(harm_categories=["violence"]) + filters = SeedDatasetFilter(harm_categories=["violence", "cybercrime"], strict_match=True) + assert not SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_strict_modalities_all_present_matches(self): + """strict_match requires ALL filter modalities present in metadata.""" + metadata = SeedDatasetMetadata(modalities=["text", "image", "audio"]) + filters = SeedDatasetFilter(modalities=["text", "image"], strict_match=True) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_strict_modalities_partial_fails(self): + """strict_match rejects if metadata is missing any requested modality.""" + metadata = SeedDatasetMetadata(modalities=["text"]) + filters = SeedDatasetFilter(modalities=["text", "image"], strict_match=True) + assert not SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_strict_size_unchanged(self): + """strict_match doesn't change size behavior — still membership check.""" + metadata = SeedDatasetMetadata(size="large") + assert 
SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(sizes=["large"], strict_match=True), + ) + assert not SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(sizes=["small"], strict_match=True), + ) + + def test_strict_cross_axis_and(self): + """strict_match with multiple axes: all must match.""" + metadata = SeedDatasetMetadata( + tags={"safety", "default"}, + size="large", + harm_categories=["violence", "cybercrime"], + ) + # Both axes satisfied + assert SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter( + tags={"safety"}, + harm_categories=["violence"], + strict_match=True, + ), + ) + # harm_categories axis fails (missing "illegal") + assert not SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter( + tags={"safety"}, + harm_categories=["violence", "illegal"], + strict_match=True, + ), + ) + + def test_strict_all_tag_still_bypasses(self): + """tags={'all'} still bypasses everything even with strict_match.""" + metadata = SeedDatasetMetadata(tags={"safety"}) + filters = SeedDatasetFilter(tags={"all"}, strict_match=True) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_strict_default_plus_other_tags_requires_both(self): + """With strict_match, 'default' is a normal tag — all must be present.""" + metadata = SeedDatasetMetadata(tags={"default", "safety"}) + # Both present → match + assert SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(tags={"default", "safety"}, strict_match=True), + ) + # Missing "curated" → reject + assert not SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(tags={"default", "safety", "curated"}, strict_match=True), + ) + + def test_nonstrict_default_is_shortcut(self): + """Without strict_match, 'default' in filter tags is a shortcut match.""" + # Dataset has "default" tag → matches even without other filter 
tags present + metadata = SeedDatasetMetadata(tags={"default"}) + assert SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(tags={"default", "nonexistent"}), + ) + + def test_strict_default_without_tag_on_dataset_fails(self): + """With strict_match, dataset must actually have 'default' in tags.""" + metadata = SeedDatasetMetadata(tags={"safety"}, load_time=SeedDatasetLoadTime.FAST) + # Without strict, "default" would match via initialized load_time + assert SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(tags={"default"}), + ) + # With strict, "default" must be in metadata.tags + assert not SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(tags={"default"}, strict_match=True), + ) + + +class TestFilterValidation: + """Test that invalid or contradictory filter configurations are caught early.""" + + def test_all_with_strict_match_warns(self, caplog): + """'all' + strict_match logs a warning since strict has no effect.""" + SeedDatasetFilter(tags={"all"}, strict_match=True) + assert "strict_match has no effect" in caplog.text + + def test_all_with_other_tags_warns(self, caplog): + """'all' combined with other tags logs a warning.""" + SeedDatasetFilter(tags={"all", "safety"}) + assert "other tags will be ignored" in caplog.text + + def test_all_with_other_fields_warns(self, caplog): + """'all' combined with size/modality/etc logs a warning.""" + SeedDatasetFilter(tags={"all"}, sizes=["large"]) + assert "other fields will be ignored" in caplog.text + + def test_all_alone_no_warning(self, caplog): + """'all' by itself does not warn.""" + SeedDatasetFilter(tags={"all"}) + assert caplog.text == "" + + def test_all_bypasses_match_filter_entirely(self): + """'all' returns True from _match_filter regardless of metadata content.""" + # Metadata with no overlap to any filter field + metadata = SeedDatasetMetadata( + tags={"unrelated"}, + size="tiny", + modalities=["audio"], 
+ harm_categories=["nothing"], + ) + # Filter that would normally reject everything about this metadata + filters = SeedDatasetFilter( + tags={"all"}, + sizes=["huge"], + modalities=["text"], + harm_categories=["violence"], + strict_match=True, + ) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + @pytest.mark.asyncio + async def test_all_includes_datasets_without_metadata(self): + """'all' in get_all_dataset_names_async includes providers with no metadata.""" + mock_cls = MagicMock(__name__="BareProv") + mock_cls.return_value.dataset_name = "bare" + mock_cls.return_value._parse_metadata = AsyncMock(return_value=None) + + with patch.dict(SeedDatasetProvider._registry, {"Bare": mock_cls}, clear=True): + # Without 'all', bare datasets are skipped + names = await SeedDatasetProvider.get_all_dataset_names_async( + filters=SeedDatasetFilter(tags={"safety"}), + ) + assert names == [] + + # With 'all', bare datasets are included + names = await SeedDatasetProvider.get_all_dataset_names_async( + filters=SeedDatasetFilter(tags={"all"}), + ) + assert names == ["bare"] + + @pytest.mark.asyncio + async def test_all_skips_match_filter_call(self): + """'all' in get_all_dataset_names_async doesn't call _match_filter at all.""" + mock_cls = MagicMock(__name__="Prov") + mock_cls.return_value.dataset_name = "test" + mock_cls.return_value._parse_metadata = AsyncMock(return_value=None) + + with ( + patch.dict(SeedDatasetProvider._registry, {"P": mock_cls}, clear=True), + patch.object(SeedDatasetProvider, "_match_filter") as mock_match, + ): + await SeedDatasetProvider.get_all_dataset_names_async( + filters=SeedDatasetFilter(tags={"all"}), + ) + mock_match.assert_not_called() + + class TestMetadataParsingLocal: """Test metadata parsing and filter matching for local YAML providers.""" From a8ae673017c831e0e44d3c69a322b5ca32acd8c5 Mon Sep 17 00:00:00 2001 From: Victor Valbuena <50061128+ValbuenaVC@users.noreply.github.com> Date: Wed, 18 Mar 2026 13:23:28 
-0700 Subject: [PATCH 17/29] Update pyrit/datasets/seed_datasets/seed_dataset_provider.py Co-authored-by: hannahwestra25 --- pyrit/datasets/seed_datasets/seed_dataset_provider.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index 4cfd4cd3ad..2ba5ae0954 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -190,7 +190,7 @@ def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter if metadata.tags and filters.tags: if filters.strict_match: # All requested tags must be present in the dataset - if not filters.tags <= metadata.tags: + if filters.tags > metadata.tags: return False elif not (filters.tags & metadata.tags): return False From 9e4b51b13cea6d1512801d84d0b99812112319af Mon Sep 17 00:00:00 2001 From: Victor Valbuena <50061128+ValbuenaVC@users.noreply.github.com> Date: Wed, 18 Mar 2026 13:23:47 -0700 Subject: [PATCH 18/29] Update pyrit/datasets/seed_datasets/seed_metadata.py Co-authored-by: hannahwestra25 --- pyrit/datasets/seed_datasets/seed_metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index 2af12fe01a..c3e0cb4ad5 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -127,7 +127,7 @@ def _coerce_metadata_values(*, raw_metadata: dict[str, Any]) -> dict[str, Any]: """ Convert YAML primitive values into the types expected by SeedDatasetMetadata. - Applies .lower().strip() normalization to string values for size, modalities, + Applies .lower().strip() normalization to string values for tags, load_time, size, modalities, source_type, and harm_categories to prevent case/whitespace mismatches. 
Args: From a9d30c830cd495e4cd98eb4bd00ef08b97af1997 Mon Sep 17 00:00:00 2001 From: Victor Valbuena <50061128+ValbuenaVC@users.noreply.github.com> Date: Wed, 18 Mar 2026 13:31:03 -0700 Subject: [PATCH 19/29] Update pyrit/datasets/seed_datasets/seed_metadata.py Co-authored-by: hannahwestra25 --- pyrit/datasets/seed_datasets/seed_metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index c3e0cb4ad5..83cf1d582a 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -29,7 +29,7 @@ # Documented expected values for string-typed metadata fields. # These are not enforced at runtime but serve as documentation. SeedDatasetSizeCategory = Literal["tiny", "small", "medium", "large", "huge"] -"""tiny (<10), small (10-99), medium (100-499), large (500-4999), huge (5000+)""" +# tiny (<10), small (10-99), medium (100-499), large (500-4999), huge (5000+) SeedDatasetSourceType = Literal["remote", "local"] From 5d4c53fda615c3d173a3b9f7822b095ddbbf2e50 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 18 Mar 2026 21:14:29 +0000 Subject: [PATCH 20/29] nits --- .../seed_datasets/seed_dataset_provider.py | 39 ++-- pyrit/datasets/seed_datasets/seed_metadata.py | 10 +- .../datasets/test_seed_dataset_provider.py | 185 +++++++++++------- 3 files changed, 146 insertions(+), 88 deletions(-) diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index 2ba5ae0954..10c977fc9e 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -134,18 +134,20 @@ async def get_all_dataset_names_async(cls, filters: Optional[SeedDatasetFilter] continue # Filters detected but no match -> don't add this dataset - if not cls._match_filter(metadata=metadata, filters=filters): + if not 
cls._match_filter_to_metadata(metadata=metadata, filters=filters): continue dataset_names.add(provider.dataset_name) except Exception as e: - raise ValueError(f"Could not get dataset name from {provider_class.__name__}: {e}") from e + raise ValueError( + f"Could not get dataset name from {provider_class.__name__}: {e}") from e return sorted(dataset_names) @classmethod - def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter) -> bool: + def _match_filter_to_metadata(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter) -> bool: """ - Match filters against provider metadata. + Match filters against provider metadata. The filter is used to determine whether a match is found + against the metadat object. When strict_match is False (default): - Across dimensions (e.g. size + harm_categories): AND — all specified conditions must match. @@ -178,14 +180,12 @@ def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter # "default" tag handling depends on strict_match: # - Without strict_match: "default" alone is enough to match if the dataset # has a "default" tag or an initialized load_time. - # - With strict_match: "default" is treated like any other tag — ALL - # requested tags (including "default") must be present in the dataset. - if not filters.strict_match: - if filters.tags and "default" in filters.tags and metadata.tags and "default" in metadata.tags: - return True - - if filters.tags and "default" in filters.tags and metadata.load_time != SeedDatasetLoadTime.UNINITIALIZED: - return True + # - With strict_match: "default" is treated like any other tag. Every + # requested tag (including "default") must be present in the dataset + # to count as a match. 
+ if not filters.strict_match and filters.tags and "default" in filters.tags and \ + metadata.tags and "default" in metadata.tags: + return True if metadata.tags and filters.tags: if filters.strict_match: @@ -265,9 +265,11 @@ async def fetch_datasets_async( # Validate dataset names if specified if dataset_names is not None: available_names = await cls.get_all_dataset_names_async() - invalid_names = [name for name in dataset_names if name not in available_names] + invalid_names = [ + name for name in dataset_names if name not in available_names] if invalid_names: - raise ValueError(f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") + raise ValueError( + f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") async def fetch_single_dataset( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -293,7 +295,8 @@ async def fetch_single_dataset( # Progress tracking total_count = len(cls._registry) - pbar = tqdm(total=total_count, desc="Loading datasets - this can take a few minutes", unit="dataset") + pbar = tqdm(total=total_count, + desc="Loading datasets - this can take a few minutes", unit="dataset") async def fetch_with_semaphore( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -331,10 +334,12 @@ async def fetch_with_semaphore( logger.info(f"Merging multiple sources for {dataset_name}.") existing_dataset = datasets[dataset_name] - combined_seeds = list(existing_dataset.seeds) + list(dataset.seeds) + combined_seeds = list( + existing_dataset.seeds) + list(dataset.seeds) existing_dataset.seeds = combined_seeds else: datasets[dataset_name] = dataset - logger.info(f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") + logger.info( + f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") return list(datasets.values()) diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py 
b/pyrit/datasets/seed_datasets/seed_metadata.py index 83cf1d582a..0dc114faf0 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -50,6 +50,7 @@ class SeedDatasetLoadTime(Enum): @dataclass class SeedDatasetFilter: """ + Filter object for datasets. Passed to `get_all_dataset_names_async` in SeedDatasetProvider. @@ -139,7 +140,8 @@ def _coerce_metadata_values(*, raw_metadata: dict[str, Any]) -> dict[str, Any]: coerced: dict[str, Any] = {} for key, value in raw_metadata.items(): if key == "tags" and isinstance(value, list): - coerced[key] = {v.strip().lower() if isinstance(v, str) else v for v in value} + coerced[key] = {v.strip().lower() if isinstance( + v, str) else v for v in value} elif key == "tags" and isinstance(value, str): coerced[key] = {value.strip().lower()} elif key == "size" and isinstance(value, str) or key == "source_type" and isinstance(value, str): @@ -147,11 +149,13 @@ def _coerce_metadata_values(*, raw_metadata: dict[str, Any]) -> dict[str, Any]: elif key == "load_time" and isinstance(value, str): coerced[key] = SeedDatasetLoadTime(value.strip().lower()) elif key == "modalities" and isinstance(value, list): - coerced[key] = [v.strip().lower() if isinstance(v, str) else v for v in value] + coerced[key] = [v.strip().lower() if isinstance(v, str) + else v for v in value] elif key == "modalities" and isinstance(value, str): coerced[key] = [value.strip().lower()] elif key == "harm_categories" and isinstance(value, list): - coerced[key] = [v.strip().lower() if isinstance(v, str) else v for v in value] + coerced[key] = [v.strip().lower() if isinstance(v, str) + else v for v in value] elif key == "harm_categories" and isinstance(value, str): coerced[key] = [value.strip().lower()] else: diff --git a/tests/unit/datasets/test_seed_dataset_provider.py b/tests/unit/datasets/test_seed_dataset_provider.py index 05ea5f86ca..3668229ee5 100644 --- a/tests/unit/datasets/test_seed_dataset_provider.py +++ 
b/tests/unit/datasets/test_seed_dataset_provider.py @@ -89,16 +89,20 @@ async def test_fetch_datasets_async(self): # Mock providers mock_provider1 = MagicMock(__name__="P1") mock_provider1.return_value.dataset_name = "d1" - mock_provider1.return_value._parse_metadata = AsyncMock(return_value=None) + mock_provider1.return_value._parse_metadata = AsyncMock( + return_value=None) mock_provider1.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") + return_value=SeedDataset( + seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) mock_provider2 = MagicMock(__name__="P2") mock_provider2.return_value.dataset_name = "d2" - mock_provider2.return_value._parse_metadata = AsyncMock(return_value=None) + mock_provider2.return_value._parse_metadata = AsyncMock( + return_value=None) mock_provider2.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset(seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") + return_value=SeedDataset( + seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") ) with patch.dict(SeedDatasetProvider._registry, {"P1": mock_provider1, "P2": mock_provider2}, clear=True): @@ -110,15 +114,19 @@ async def test_fetch_datasets_async_with_filter(self): """Test fetching datasets with filter.""" mock_provider1 = MagicMock(__name__="P1") mock_provider1.return_value.dataset_name = "d1" - mock_provider1.return_value._parse_metadata = AsyncMock(return_value=None) + mock_provider1.return_value._parse_metadata = AsyncMock( + return_value=None) mock_provider1.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") + return_value=SeedDataset( + seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) mock_provider2 = MagicMock(__name__="P2") mock_provider2.return_value.dataset_name = "d2" - mock_provider2.return_value._parse_metadata = 
AsyncMock(return_value=None) - mock_provider2.return_value.fetch_dataset = AsyncMock(side_effect=Exception("Should not be called")) + mock_provider2.return_value._parse_metadata = AsyncMock( + return_value=None) + mock_provider2.return_value.fetch_dataset = AsyncMock( + side_effect=Exception("Should not be called")) with patch.dict(SeedDatasetProvider._registry, {"P1": mock_provider1, "P2": mock_provider2}, clear=True): datasets = await SeedDatasetProvider.fetch_datasets_async(dataset_names=["d1"]) @@ -130,16 +138,20 @@ async def test_fetch_datasets_async_invalid_dataset_name(self): """Test that fetch_datasets_async raises ValueError for invalid dataset names.""" mock_provider1 = MagicMock(__name__="P1") mock_provider1.return_value.dataset_name = "d1" - mock_provider1.return_value._parse_metadata = AsyncMock(return_value=None) + mock_provider1.return_value._parse_metadata = AsyncMock( + return_value=None) mock_provider1.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") + return_value=SeedDataset( + seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) mock_provider2 = MagicMock(__name__="P2") mock_provider2.return_value.dataset_name = "d2" - mock_provider2.return_value._parse_metadata = AsyncMock(return_value=None) + mock_provider2.return_value._parse_metadata = AsyncMock( + return_value=None) mock_provider2.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset(seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") + return_value=SeedDataset( + seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") ) with patch.dict(SeedDatasetProvider._registry, {"P1": mock_provider1, "P2": mock_provider2}, clear=True): @@ -268,7 +280,8 @@ async def test_parse_metadata_from_class_attrs(self): assert metadata.tags == {"default", "safety"} assert metadata.size == "large" assert metadata.modalities == ["text"] - assert 
metadata.harm_categories == ["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"] + assert metadata.harm_categories == [ + "cybercrime", "illegal", "harmful", "chemical_biological", "harassment"] # source_type is not declared as a class attribute on HarmBench; # load_time inherits the UNINITIALIZED default from SeedDatasetProvider base class assert metadata.source_type is None @@ -278,22 +291,25 @@ def test_all_tag(self): """Filter with tags={'all'} matches any metadata.""" metadata = SeedDatasetMetadata(tags={"safety"}) filters = SeedDatasetFilter(tags={"all"}) - assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, filters=filters) def test_tags(self): """Tag filter uses set intersection.""" metadata = SeedDatasetMetadata(tags={"safety", "default"}) - assert SeedDatasetProvider._match_filter(metadata=metadata, filters=SeedDatasetFilter(tags={"safety"})) - assert not SeedDatasetProvider._match_filter(metadata=metadata, filters=SeedDatasetFilter(tags={"unrelated"})) + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, filters=SeedDatasetFilter(tags={"safety"})) + assert not SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, filters=SeedDatasetFilter(tags={"unrelated"})) def test_sizes(self): """Size filter checks membership in the sizes list.""" metadata = SeedDatasetMetadata(size="large") - assert SeedDatasetProvider._match_filter( + assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, filters=SeedDatasetFilter(sizes=["large", "huge"]), ) - assert not SeedDatasetProvider._match_filter( + assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, filters=SeedDatasetFilter(sizes=["small"]), ) @@ -301,11 +317,11 @@ def test_sizes(self): def test_modalities(self): """Modality filter uses set intersection.""" metadata = SeedDatasetMetadata(modalities=["text", "image"]) - 
assert SeedDatasetProvider._match_filter( + assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, filters=SeedDatasetFilter(modalities=["text"]), ) - assert not SeedDatasetProvider._match_filter( + assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, filters=SeedDatasetFilter(modalities=["audio"]), ) @@ -313,11 +329,11 @@ def test_modalities(self): def test_sources(self): """Source filter checks membership.""" metadata = SeedDatasetMetadata(source_type="remote") - assert SeedDatasetProvider._match_filter( + assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, filters=SeedDatasetFilter(source_types=["remote"]), ) - assert not SeedDatasetProvider._match_filter( + assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, filters=SeedDatasetFilter(source_types=["local"]), ) @@ -325,23 +341,24 @@ def test_sources(self): def test_ranks(self): """Load time filter checks membership.""" metadata = SeedDatasetMetadata(load_time=SeedDatasetLoadTime.FAST) - assert SeedDatasetProvider._match_filter( + assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, filters=SeedDatasetFilter(load_times=[SeedDatasetLoadTime.FAST]), ) - assert not SeedDatasetProvider._match_filter( + assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, filters=SeedDatasetFilter(load_times=[SeedDatasetLoadTime.SLOW]), ) def test_harm_categories(self): """Harm category filter uses set intersection.""" - metadata = SeedDatasetMetadata(harm_categories=["violence", "cybercrime"]) - assert SeedDatasetProvider._match_filter( + metadata = SeedDatasetMetadata( + harm_categories=["violence", "cybercrime"]) + assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, filters=SeedDatasetFilter(harm_categories=["violence"]), ) - assert not SeedDatasetProvider._match_filter( + assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, 
filters=SeedDatasetFilter(harm_categories=["unrelated"]), ) @@ -350,7 +367,8 @@ def test_empty_filter(self): """Empty filter (all None) matches any metadata.""" metadata = SeedDatasetMetadata(tags={"safety"}, size="large") filters = SeedDatasetFilter() - assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, filters=filters) @pytest.mark.asyncio async def test_no_metadata(self): @@ -371,53 +389,68 @@ class TestStrictMatchFiltering: def test_strict_tags_all_present_matches(self): """strict_match requires ALL filter tags to be present in metadata.""" metadata = SeedDatasetMetadata(tags={"safety", "default", "curated"}) - filters = SeedDatasetFilter(tags={"safety", "default"}, strict_match=True) - assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + filters = SeedDatasetFilter( + tags={"safety", "default"}, strict_match=True) + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, filters=filters) def test_strict_tags_partial_overlap_fails(self): """strict_match rejects if metadata is missing any requested tag.""" metadata = SeedDatasetMetadata(tags={"safety"}) - filters = SeedDatasetFilter(tags={"safety", "default"}, strict_match=True) - assert not SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + filters = SeedDatasetFilter( + tags={"safety", "default"}, strict_match=True) + assert not SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, filters=filters) def test_nonstrict_tags_partial_overlap_passes(self): """Without strict_match, any tag overlap is sufficient.""" metadata = SeedDatasetMetadata(tags={"safety"}) - filters = SeedDatasetFilter(tags={"safety", "default"}, strict_match=False) - assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + filters = SeedDatasetFilter( + tags={"safety", "default"}, strict_match=False) + assert 
SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, filters=filters) def test_strict_harm_categories_all_present_matches(self): """strict_match requires ALL filter harm_categories present in metadata.""" - metadata = SeedDatasetMetadata(harm_categories=["violence", "cybercrime", "illegal"]) - filters = SeedDatasetFilter(harm_categories=["violence", "cybercrime"], strict_match=True) - assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + metadata = SeedDatasetMetadata( + harm_categories=["violence", "cybercrime", "illegal"]) + filters = SeedDatasetFilter( + harm_categories=["violence", "cybercrime"], strict_match=True) + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, filters=filters) def test_strict_harm_categories_partial_fails(self): """strict_match rejects if metadata is missing any requested harm category.""" metadata = SeedDatasetMetadata(harm_categories=["violence"]) - filters = SeedDatasetFilter(harm_categories=["violence", "cybercrime"], strict_match=True) - assert not SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + filters = SeedDatasetFilter( + harm_categories=["violence", "cybercrime"], strict_match=True) + assert not SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, filters=filters) def test_strict_modalities_all_present_matches(self): """strict_match requires ALL filter modalities present in metadata.""" metadata = SeedDatasetMetadata(modalities=["text", "image", "audio"]) - filters = SeedDatasetFilter(modalities=["text", "image"], strict_match=True) - assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + filters = SeedDatasetFilter( + modalities=["text", "image"], strict_match=True) + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, filters=filters) def test_strict_modalities_partial_fails(self): """strict_match rejects if metadata is missing any requested modality.""" metadata = 
SeedDatasetMetadata(modalities=["text"]) - filters = SeedDatasetFilter(modalities=["text", "image"], strict_match=True) - assert not SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + filters = SeedDatasetFilter( + modalities=["text", "image"], strict_match=True) + assert not SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, filters=filters) def test_strict_size_unchanged(self): """strict_match doesn't change size behavior — still membership check.""" metadata = SeedDatasetMetadata(size="large") - assert SeedDatasetProvider._match_filter( + assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, filters=SeedDatasetFilter(sizes=["large"], strict_match=True), ) - assert not SeedDatasetProvider._match_filter( + assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, filters=SeedDatasetFilter(sizes=["small"], strict_match=True), ) @@ -430,7 +463,7 @@ def test_strict_cross_axis_and(self): harm_categories=["violence", "cybercrime"], ) # Both axes satisfied - assert SeedDatasetProvider._match_filter( + assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, filters=SeedDatasetFilter( tags={"safety"}, @@ -439,7 +472,7 @@ def test_strict_cross_axis_and(self): ), ) # harm_categories axis fails (missing "illegal") - assert not SeedDatasetProvider._match_filter( + assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, filters=SeedDatasetFilter( tags={"safety"}, @@ -452,41 +485,45 @@ def test_strict_all_tag_still_bypasses(self): """tags={'all'} still bypasses everything even with strict_match.""" metadata = SeedDatasetMetadata(tags={"safety"}) filters = SeedDatasetFilter(tags={"all"}, strict_match=True) - assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, filters=filters) def test_strict_default_plus_other_tags_requires_both(self): """With strict_match, 'default' 
is a normal tag — all must be present.""" metadata = SeedDatasetMetadata(tags={"default", "safety"}) # Both present → match - assert SeedDatasetProvider._match_filter( + assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(tags={"default", "safety"}, strict_match=True), + filters=SeedDatasetFilter( + tags={"default", "safety"}, strict_match=True), ) # Missing "curated" → reject - assert not SeedDatasetProvider._match_filter( + assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(tags={"default", "safety", "curated"}, strict_match=True), + filters=SeedDatasetFilter( + tags={"default", "safety", "curated"}, strict_match=True), ) def test_nonstrict_default_is_shortcut(self): """Without strict_match, 'default' in filter tags is a shortcut match.""" # Dataset has "default" tag → matches even without other filter tags present metadata = SeedDatasetMetadata(tags={"default"}) - assert SeedDatasetProvider._match_filter( + assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, filters=SeedDatasetFilter(tags={"default", "nonexistent"}), ) def test_strict_default_without_tag_on_dataset_fails(self): """With strict_match, dataset must actually have 'default' in tags.""" - metadata = SeedDatasetMetadata(tags={"safety"}, load_time=SeedDatasetLoadTime.FAST) + metadata = SeedDatasetMetadata( + tags={"safety"}, load_time=SeedDatasetLoadTime.FAST) # Without strict, "default" would match via initialized load_time - assert SeedDatasetProvider._match_filter( + assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, filters=SeedDatasetFilter(tags={"default"}), ) # With strict, "default" must be in metadata.tags - assert not SeedDatasetProvider._match_filter( + assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, filters=SeedDatasetFilter(tags={"default"}, strict_match=True), ) @@ -532,7 +569,8 @@ def 
test_all_bypasses_match_filter_entirely(self): harm_categories=["violence"], strict_match=True, ) - assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, filters=filters) @pytest.mark.asyncio async def test_all_includes_datasets_without_metadata(self): @@ -562,7 +600,8 @@ async def test_all_skips_match_filter_call(self): mock_cls.return_value._parse_metadata = AsyncMock(return_value=None) with ( - patch.dict(SeedDatasetProvider._registry, {"P": mock_cls}, clear=True), + patch.dict(SeedDatasetProvider._registry, + {"P": mock_cls}, clear=True), patch.object(SeedDatasetProvider, "_match_filter") as mock_match, ): await SeedDatasetProvider.get_all_dataset_names_async( @@ -628,7 +667,8 @@ async def test_all_tag(self, tmp_path): metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(tags={"all"}) - assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, filters=filters) @pytest.mark.asyncio async def test_tags(self, tmp_path): @@ -650,7 +690,8 @@ async def test_tags(self, tmp_path): metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(tags={"safety"}) - assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, filters=filters) @pytest.mark.asyncio async def test_sizes(self, tmp_path): @@ -670,7 +711,8 @@ async def test_sizes(self, tmp_path): metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(sizes=["large"]) - assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, filters=filters) @pytest.mark.asyncio async def test_modalities(self, tmp_path): @@ -691,7 
+733,8 @@ async def test_modalities(self, tmp_path): metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(modalities=["text"]) - assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, filters=filters) @pytest.mark.asyncio async def test_sources(self, tmp_path): @@ -711,7 +754,8 @@ async def test_sources(self, tmp_path): metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(source_types=["remote"]) - assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, filters=filters) @pytest.mark.asyncio async def test_ranks(self, tmp_path): @@ -731,7 +775,8 @@ async def test_ranks(self, tmp_path): metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(load_times=[SeedDatasetLoadTime.FAST]) - assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, filters=filters) @pytest.mark.asyncio async def test_harm_categories(self, tmp_path): @@ -753,7 +798,8 @@ async def test_harm_categories(self, tmp_path): metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(harm_categories=["violence"]) - assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata( + metadata=metadata, filters=filters) @pytest.mark.asyncio async def test_empty_filter(self, tmp_path): @@ -774,7 +820,8 @@ async def test_empty_filter(self, tmp_path): metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter() - assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata( + 
metadata=metadata, filters=filters) @pytest.mark.asyncio async def test_no_metadata(self, tmp_path): @@ -807,7 +854,8 @@ class TestLocalDatasetMetadataCollisions: @staticmethod def _get_local_prompt_files() -> list: """Collect all .prompt and .yaml files under the local datasets directory.""" - local_dir = Path(__file__).resolve().parents[3] / "pyrit" / "datasets" / "seed_datasets" / "local" + local_dir = Path(__file__).resolve( + ).parents[3] / "pyrit" / "datasets" / "seed_datasets" / "local" return sorted(local_dir.glob("**/*.prompt")) + sorted(local_dir.glob("**/*.yaml")) @pytest.mark.parametrize("prompt_file", _get_local_prompt_files.__func__(), ids=lambda p: p.stem) @@ -836,7 +884,8 @@ def test_no_yaml_key_shadows_metadata_field_with_wrong_type(self, prompt_file): if not isinstance(data, dict): return - metadata_field_names = {fld.name for fld in dc_fields(SeedDatasetMetadata)} + metadata_field_names = { + fld.name for fld in dc_fields(SeedDatasetMetadata)} overlapping_keys = metadata_field_names & data.keys() if not overlapping_keys: From b0d01186767fdb624aa728b64387f0683e1d78f8 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 18 Mar 2026 21:41:15 +0000 Subject: [PATCH 21/29] precommit --- .../seed_datasets/seed_dataset_provider.py | 27 ++-- pyrit/datasets/seed_datasets/seed_metadata.py | 9 +- .../datasets/test_seed_dataset_provider.py | 147 ++++++------------ 3 files changed, 67 insertions(+), 116 deletions(-) diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index 10c977fc9e..a444f07080 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -139,8 +139,7 @@ async def get_all_dataset_names_async(cls, filters: Optional[SeedDatasetFilter] dataset_names.add(provider.dataset_name) except Exception as e: - raise ValueError( - f"Could not get dataset name from {provider_class.__name__}: {e}") from e + raise 
ValueError(f"Could not get dataset name from {provider_class.__name__}: {e}") from e return sorted(dataset_names) @classmethod @@ -183,8 +182,13 @@ def _match_filter_to_metadata(cls, metadata: SeedDatasetMetadata, filters: SeedD # - With strict_match: "default" is treated like any other tag. Every # requested tag (including "default") must be present in the dataset # to count as a match. - if not filters.strict_match and filters.tags and "default" in filters.tags and \ - metadata.tags and "default" in metadata.tags: + if ( + not filters.strict_match + and filters.tags + and "default" in filters.tags + and metadata.tags + and "default" in metadata.tags + ): return True if metadata.tags and filters.tags: @@ -265,11 +269,9 @@ async def fetch_datasets_async( # Validate dataset names if specified if dataset_names is not None: available_names = await cls.get_all_dataset_names_async() - invalid_names = [ - name for name in dataset_names if name not in available_names] + invalid_names = [name for name in dataset_names if name not in available_names] if invalid_names: - raise ValueError( - f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") + raise ValueError(f"Dataset(s) not found: {invalid_names}. 
Available datasets: {available_names}") async def fetch_single_dataset( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -295,8 +297,7 @@ async def fetch_single_dataset( # Progress tracking total_count = len(cls._registry) - pbar = tqdm(total=total_count, - desc="Loading datasets - this can take a few minutes", unit="dataset") + pbar = tqdm(total=total_count, desc="Loading datasets - this can take a few minutes", unit="dataset") async def fetch_with_semaphore( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -334,12 +335,10 @@ async def fetch_with_semaphore( logger.info(f"Merging multiple sources for {dataset_name}.") existing_dataset = datasets[dataset_name] - combined_seeds = list( - existing_dataset.seeds) + list(dataset.seeds) + combined_seeds = list(existing_dataset.seeds) + list(dataset.seeds) existing_dataset.seeds = combined_seeds else: datasets[dataset_name] = dataset - logger.info( - f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") + logger.info(f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") return list(datasets.values()) diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index 0dc114faf0..0facfcc667 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -140,8 +140,7 @@ def _coerce_metadata_values(*, raw_metadata: dict[str, Any]) -> dict[str, Any]: coerced: dict[str, Any] = {} for key, value in raw_metadata.items(): if key == "tags" and isinstance(value, list): - coerced[key] = {v.strip().lower() if isinstance( - v, str) else v for v in value} + coerced[key] = {v.strip().lower() if isinstance(v, str) else v for v in value} elif key == "tags" and isinstance(value, str): coerced[key] = {value.strip().lower()} elif key == "size" and isinstance(value, str) or key == "source_type" and isinstance(value, str): @@ -149,13 
+148,11 @@ def _coerce_metadata_values(*, raw_metadata: dict[str, Any]) -> dict[str, Any]: elif key == "load_time" and isinstance(value, str): coerced[key] = SeedDatasetLoadTime(value.strip().lower()) elif key == "modalities" and isinstance(value, list): - coerced[key] = [v.strip().lower() if isinstance(v, str) - else v for v in value] + coerced[key] = [v.strip().lower() if isinstance(v, str) else v for v in value] elif key == "modalities" and isinstance(value, str): coerced[key] = [value.strip().lower()] elif key == "harm_categories" and isinstance(value, list): - coerced[key] = [v.strip().lower() if isinstance(v, str) - else v for v in value] + coerced[key] = [v.strip().lower() if isinstance(v, str) else v for v in value] elif key == "harm_categories" and isinstance(value, str): coerced[key] = [value.strip().lower()] else: diff --git a/tests/unit/datasets/test_seed_dataset_provider.py b/tests/unit/datasets/test_seed_dataset_provider.py index 3668229ee5..e63ce56e9c 100644 --- a/tests/unit/datasets/test_seed_dataset_provider.py +++ b/tests/unit/datasets/test_seed_dataset_provider.py @@ -89,20 +89,16 @@ async def test_fetch_datasets_async(self): # Mock providers mock_provider1 = MagicMock(__name__="P1") mock_provider1.return_value.dataset_name = "d1" - mock_provider1.return_value._parse_metadata = AsyncMock( - return_value=None) + mock_provider1.return_value._parse_metadata = AsyncMock(return_value=None) mock_provider1.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset( - seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") + return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) mock_provider2 = MagicMock(__name__="P2") mock_provider2.return_value.dataset_name = "d2" - mock_provider2.return_value._parse_metadata = AsyncMock( - return_value=None) + mock_provider2.return_value._parse_metadata = AsyncMock(return_value=None) mock_provider2.return_value.fetch_dataset = AsyncMock( - 
return_value=SeedDataset( - seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") + return_value=SeedDataset(seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") ) with patch.dict(SeedDatasetProvider._registry, {"P1": mock_provider1, "P2": mock_provider2}, clear=True): @@ -114,19 +110,15 @@ async def test_fetch_datasets_async_with_filter(self): """Test fetching datasets with filter.""" mock_provider1 = MagicMock(__name__="P1") mock_provider1.return_value.dataset_name = "d1" - mock_provider1.return_value._parse_metadata = AsyncMock( - return_value=None) + mock_provider1.return_value._parse_metadata = AsyncMock(return_value=None) mock_provider1.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset( - seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") + return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) mock_provider2 = MagicMock(__name__="P2") mock_provider2.return_value.dataset_name = "d2" - mock_provider2.return_value._parse_metadata = AsyncMock( - return_value=None) - mock_provider2.return_value.fetch_dataset = AsyncMock( - side_effect=Exception("Should not be called")) + mock_provider2.return_value._parse_metadata = AsyncMock(return_value=None) + mock_provider2.return_value.fetch_dataset = AsyncMock(side_effect=Exception("Should not be called")) with patch.dict(SeedDatasetProvider._registry, {"P1": mock_provider1, "P2": mock_provider2}, clear=True): datasets = await SeedDatasetProvider.fetch_datasets_async(dataset_names=["d1"]) @@ -138,20 +130,16 @@ async def test_fetch_datasets_async_invalid_dataset_name(self): """Test that fetch_datasets_async raises ValueError for invalid dataset names.""" mock_provider1 = MagicMock(__name__="P1") mock_provider1.return_value.dataset_name = "d1" - mock_provider1.return_value._parse_metadata = AsyncMock( - return_value=None) + mock_provider1.return_value._parse_metadata = AsyncMock(return_value=None) 
mock_provider1.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset( - seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") + return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) mock_provider2 = MagicMock(__name__="P2") mock_provider2.return_value.dataset_name = "d2" - mock_provider2.return_value._parse_metadata = AsyncMock( - return_value=None) + mock_provider2.return_value._parse_metadata = AsyncMock(return_value=None) mock_provider2.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset( - seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") + return_value=SeedDataset(seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") ) with patch.dict(SeedDatasetProvider._registry, {"P1": mock_provider1, "P2": mock_provider2}, clear=True): @@ -280,8 +268,7 @@ async def test_parse_metadata_from_class_attrs(self): assert metadata.tags == {"default", "safety"} assert metadata.size == "large" assert metadata.modalities == ["text"] - assert metadata.harm_categories == [ - "cybercrime", "illegal", "harmful", "chemical_biological", "harassment"] + assert metadata.harm_categories == ["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"] # source_type is not declared as a class attribute on HarmBench; # load_time inherits the UNINITIALIZED default from SeedDatasetProvider base class assert metadata.source_type is None @@ -291,16 +278,17 @@ def test_all_tag(self): """Filter with tags={'all'} matches any metadata.""" metadata = SeedDatasetMetadata(tags={"safety"}) filters = SeedDatasetFilter(tags={"all"}) - assert SeedDatasetProvider._match_filter_to_metadata( - metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) def test_tags(self): """Tag filter uses set intersection.""" metadata = SeedDatasetMetadata(tags={"safety", "default"}) assert 
SeedDatasetProvider._match_filter_to_metadata( - metadata=metadata, filters=SeedDatasetFilter(tags={"safety"})) + metadata=metadata, filters=SeedDatasetFilter(tags={"safety"}) + ) assert not SeedDatasetProvider._match_filter_to_metadata( - metadata=metadata, filters=SeedDatasetFilter(tags={"unrelated"})) + metadata=metadata, filters=SeedDatasetFilter(tags={"unrelated"}) + ) def test_sizes(self): """Size filter checks membership in the sizes list.""" @@ -352,8 +340,7 @@ def test_ranks(self): def test_harm_categories(self): """Harm category filter uses set intersection.""" - metadata = SeedDatasetMetadata( - harm_categories=["violence", "cybercrime"]) + metadata = SeedDatasetMetadata(harm_categories=["violence", "cybercrime"]) assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, filters=SeedDatasetFilter(harm_categories=["violence"]), @@ -367,8 +354,7 @@ def test_empty_filter(self): """Empty filter (all None) matches any metadata.""" metadata = SeedDatasetMetadata(tags={"safety"}, size="large") filters = SeedDatasetFilter() - assert SeedDatasetProvider._match_filter_to_metadata( - metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) @pytest.mark.asyncio async def test_no_metadata(self): @@ -389,59 +375,44 @@ class TestStrictMatchFiltering: def test_strict_tags_all_present_matches(self): """strict_match requires ALL filter tags to be present in metadata.""" metadata = SeedDatasetMetadata(tags={"safety", "default", "curated"}) - filters = SeedDatasetFilter( - tags={"safety", "default"}, strict_match=True) - assert SeedDatasetProvider._match_filter_to_metadata( - metadata=metadata, filters=filters) + filters = SeedDatasetFilter(tags={"safety", "default"}, strict_match=True) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) def test_strict_tags_partial_overlap_fails(self): """strict_match rejects if metadata is missing any requested 
tag.""" metadata = SeedDatasetMetadata(tags={"safety"}) - filters = SeedDatasetFilter( - tags={"safety", "default"}, strict_match=True) - assert not SeedDatasetProvider._match_filter_to_metadata( - metadata=metadata, filters=filters) + filters = SeedDatasetFilter(tags={"safety", "default"}, strict_match=True) + assert not SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) def test_nonstrict_tags_partial_overlap_passes(self): """Without strict_match, any tag overlap is sufficient.""" metadata = SeedDatasetMetadata(tags={"safety"}) - filters = SeedDatasetFilter( - tags={"safety", "default"}, strict_match=False) - assert SeedDatasetProvider._match_filter_to_metadata( - metadata=metadata, filters=filters) + filters = SeedDatasetFilter(tags={"safety", "default"}, strict_match=False) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) def test_strict_harm_categories_all_present_matches(self): """strict_match requires ALL filter harm_categories present in metadata.""" - metadata = SeedDatasetMetadata( - harm_categories=["violence", "cybercrime", "illegal"]) - filters = SeedDatasetFilter( - harm_categories=["violence", "cybercrime"], strict_match=True) - assert SeedDatasetProvider._match_filter_to_metadata( - metadata=metadata, filters=filters) + metadata = SeedDatasetMetadata(harm_categories=["violence", "cybercrime", "illegal"]) + filters = SeedDatasetFilter(harm_categories=["violence", "cybercrime"], strict_match=True) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) def test_strict_harm_categories_partial_fails(self): """strict_match rejects if metadata is missing any requested harm category.""" metadata = SeedDatasetMetadata(harm_categories=["violence"]) - filters = SeedDatasetFilter( - harm_categories=["violence", "cybercrime"], strict_match=True) - assert not SeedDatasetProvider._match_filter_to_metadata( - metadata=metadata, filters=filters) + filters = 
SeedDatasetFilter(harm_categories=["violence", "cybercrime"], strict_match=True) + assert not SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) def test_strict_modalities_all_present_matches(self): """strict_match requires ALL filter modalities present in metadata.""" metadata = SeedDatasetMetadata(modalities=["text", "image", "audio"]) - filters = SeedDatasetFilter( - modalities=["text", "image"], strict_match=True) - assert SeedDatasetProvider._match_filter_to_metadata( - metadata=metadata, filters=filters) + filters = SeedDatasetFilter(modalities=["text", "image"], strict_match=True) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) def test_strict_modalities_partial_fails(self): """strict_match rejects if metadata is missing any requested modality.""" metadata = SeedDatasetMetadata(modalities=["text"]) - filters = SeedDatasetFilter( - modalities=["text", "image"], strict_match=True) - assert not SeedDatasetProvider._match_filter_to_metadata( - metadata=metadata, filters=filters) + filters = SeedDatasetFilter(modalities=["text", "image"], strict_match=True) + assert not SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) def test_strict_size_unchanged(self): """strict_match doesn't change size behavior — still membership check.""" @@ -485,8 +456,7 @@ def test_strict_all_tag_still_bypasses(self): """tags={'all'} still bypasses everything even with strict_match.""" metadata = SeedDatasetMetadata(tags={"safety"}) filters = SeedDatasetFilter(tags={"all"}, strict_match=True) - assert SeedDatasetProvider._match_filter_to_metadata( - metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) def test_strict_default_plus_other_tags_requires_both(self): """With strict_match, 'default' is a normal tag — all must be present.""" @@ -494,14 +464,12 @@ def 
test_strict_default_plus_other_tags_requires_both(self): # Both present → match assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter( - tags={"default", "safety"}, strict_match=True), + filters=SeedDatasetFilter(tags={"default", "safety"}, strict_match=True), ) # Missing "curated" → reject assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter( - tags={"default", "safety", "curated"}, strict_match=True), + filters=SeedDatasetFilter(tags={"default", "safety", "curated"}, strict_match=True), ) def test_nonstrict_default_is_shortcut(self): @@ -515,8 +483,7 @@ def test_nonstrict_default_is_shortcut(self): def test_strict_default_without_tag_on_dataset_fails(self): """With strict_match, dataset must actually have 'default' in tags.""" - metadata = SeedDatasetMetadata( - tags={"safety"}, load_time=SeedDatasetLoadTime.FAST) + metadata = SeedDatasetMetadata(tags={"safety"}, load_time=SeedDatasetLoadTime.FAST) # Without strict, "default" would match via initialized load_time assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, @@ -569,8 +536,7 @@ def test_all_bypasses_match_filter_entirely(self): harm_categories=["violence"], strict_match=True, ) - assert SeedDatasetProvider._match_filter_to_metadata( - metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) @pytest.mark.asyncio async def test_all_includes_datasets_without_metadata(self): @@ -600,8 +566,7 @@ async def test_all_skips_match_filter_call(self): mock_cls.return_value._parse_metadata = AsyncMock(return_value=None) with ( - patch.dict(SeedDatasetProvider._registry, - {"P": mock_cls}, clear=True), + patch.dict(SeedDatasetProvider._registry, {"P": mock_cls}, clear=True), patch.object(SeedDatasetProvider, "_match_filter") as mock_match, ): await SeedDatasetProvider.get_all_dataset_names_async( @@ -667,8 +632,7 @@ async def 
test_all_tag(self, tmp_path): metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(tags={"all"}) - assert SeedDatasetProvider._match_filter_to_metadata( - metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) @pytest.mark.asyncio async def test_tags(self, tmp_path): @@ -690,8 +654,7 @@ async def test_tags(self, tmp_path): metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(tags={"safety"}) - assert SeedDatasetProvider._match_filter_to_metadata( - metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) @pytest.mark.asyncio async def test_sizes(self, tmp_path): @@ -711,8 +674,7 @@ async def test_sizes(self, tmp_path): metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(sizes=["large"]) - assert SeedDatasetProvider._match_filter_to_metadata( - metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) @pytest.mark.asyncio async def test_modalities(self, tmp_path): @@ -733,8 +695,7 @@ async def test_modalities(self, tmp_path): metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(modalities=["text"]) - assert SeedDatasetProvider._match_filter_to_metadata( - metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) @pytest.mark.asyncio async def test_sources(self, tmp_path): @@ -754,8 +715,7 @@ async def test_sources(self, tmp_path): metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(source_types=["remote"]) - assert SeedDatasetProvider._match_filter_to_metadata( - metadata=metadata, filters=filters) + assert 
SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) @pytest.mark.asyncio async def test_ranks(self, tmp_path): @@ -775,8 +735,7 @@ async def test_ranks(self, tmp_path): metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(load_times=[SeedDatasetLoadTime.FAST]) - assert SeedDatasetProvider._match_filter_to_metadata( - metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) @pytest.mark.asyncio async def test_harm_categories(self, tmp_path): @@ -798,8 +757,7 @@ async def test_harm_categories(self, tmp_path): metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(harm_categories=["violence"]) - assert SeedDatasetProvider._match_filter_to_metadata( - metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) @pytest.mark.asyncio async def test_empty_filter(self, tmp_path): @@ -820,8 +778,7 @@ async def test_empty_filter(self, tmp_path): metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter() - assert SeedDatasetProvider._match_filter_to_metadata( - metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) @pytest.mark.asyncio async def test_no_metadata(self, tmp_path): @@ -854,8 +811,7 @@ class TestLocalDatasetMetadataCollisions: @staticmethod def _get_local_prompt_files() -> list: """Collect all .prompt and .yaml files under the local datasets directory.""" - local_dir = Path(__file__).resolve( - ).parents[3] / "pyrit" / "datasets" / "seed_datasets" / "local" + local_dir = Path(__file__).resolve().parents[3] / "pyrit" / "datasets" / "seed_datasets" / "local" return sorted(local_dir.glob("**/*.prompt")) + sorted(local_dir.glob("**/*.yaml")) @pytest.mark.parametrize("prompt_file", 
_get_local_prompt_files.__func__(), ids=lambda p: p.stem) @@ -884,8 +840,7 @@ def test_no_yaml_key_shadows_metadata_field_with_wrong_type(self, prompt_file): if not isinstance(data, dict): return - metadata_field_names = { - fld.name for fld in dc_fields(SeedDatasetMetadata)} + metadata_field_names = {fld.name for fld in dc_fields(SeedDatasetMetadata)} overlapping_keys = metadata_field_names & data.keys() if not overlapping_keys: From 294ab317cf2f013cdb2ebc2cefcf917cc845c1c6 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 18 Mar 2026 22:40:29 +0000 Subject: [PATCH 22/29] . --- .../datasets/seed_datasets/seed_dataset_provider.py | 4 ++-- tests/unit/datasets/test_seed_dataset_provider.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index a444f07080..908e50d47d 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -193,8 +193,8 @@ def _match_filter_to_metadata(cls, metadata: SeedDatasetMetadata, filters: SeedD if metadata.tags and filters.tags: if filters.strict_match: - # All requested tags must be present in the dataset - if filters.tags > metadata.tags: + # Reject if any requested tag is missing from the dataset + if filters.tags - metadata.tags: return False elif not (filters.tags & metadata.tags): return False diff --git a/tests/unit/datasets/test_seed_dataset_provider.py b/tests/unit/datasets/test_seed_dataset_provider.py index e63ce56e9c..a430020dcc 100644 --- a/tests/unit/datasets/test_seed_dataset_provider.py +++ b/tests/unit/datasets/test_seed_dataset_provider.py @@ -483,16 +483,16 @@ def test_nonstrict_default_is_shortcut(self): def test_strict_default_without_tag_on_dataset_fails(self): """With strict_match, dataset must actually have 'default' in tags.""" - metadata = SeedDatasetMetadata(tags={"safety"}, 
load_time=SeedDatasetLoadTime.FAST) - # Without strict, "default" would match via initialized load_time + metadata = SeedDatasetMetadata(tags={"default", "safety"}, load_time=SeedDatasetLoadTime.FAST) + # Without strict, "default" shortcut matches because metadata has "default" tag assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(tags={"default"}), + filters=SeedDatasetFilter(tags={"default", "curated"}), ) - # With strict, "default" must be in metadata.tags + # With strict, ALL filter tags must be in metadata — "curated" is missing assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(tags={"default"}, strict_match=True), + filters=SeedDatasetFilter(tags={"default", "curated"}, strict_match=True), ) @@ -567,7 +567,7 @@ async def test_all_skips_match_filter_call(self): with ( patch.dict(SeedDatasetProvider._registry, {"P": mock_cls}, clear=True), - patch.object(SeedDatasetProvider, "_match_filter") as mock_match, + patch.object(SeedDatasetProvider, "_match_filter_to_metadata") as mock_match, ): await SeedDatasetProvider.get_all_dataset_names_async( filters=SeedDatasetFilter(tags={"all"}), From 7c4b8d7ea6d931cee2917e821cab4ce59d0380ab Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 18 Mar 2026 23:25:15 +0000 Subject: [PATCH 23/29] refactor --- .../local/local_dataset_loader.py | 6 +- .../remote/remote_dataset_loader.py | 26 +- .../seed_datasets/seed_dataset_provider.py | 118 ++++---- pyrit/datasets/seed_datasets/seed_metadata.py | 285 +++++++++++------- .../datasets/test_seed_dataset_metadata.py | 210 +++++++------ .../datasets/test_seed_dataset_provider.py | 100 +++--- 6 files changed, 416 insertions(+), 329 deletions(-) diff --git a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py index 58064d81e1..a78491509c 100644 --- a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py +++ 
b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py @@ -105,7 +105,11 @@ async def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: return None coerced = SeedDatasetMetadata._coerce_metadata_values(raw_metadata=raw) - return SeedDatasetMetadata(**coerced) + result = SeedDatasetMetadata(**coerced) + # Validation after coercion: raw values are strings/lists, not sets. + # _validate_singular_fields needs sets to check cardinality. + SeedDatasetMetadata._validate_singular_fields(metadata=result) + return result def _register_local_datasets() -> None: diff --git a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py index b1110db6a4..944dd513f3 100644 --- a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py @@ -290,7 +290,10 @@ def _load_dataset_sync() -> Any: async def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: """ - Extract metadata from class attributes and format into SeedDatasetMetadata schema. + Extract metadata from class attributes, wrap in sets, and format into SeedDatasetMetadata. + + Class attributes may be singular values (str, enum), lists, or sets. + All are normalized into sets for the unified SeedDatasetMetadata schema. Returns: Optional[SeedDatasetMetadata]: Parsed metadata if available, otherwise None. 
@@ -298,10 +301,19 @@ async def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: valid_fields = [f.name for f in fields(SeedDatasetMetadata)] provider_class = type(self) - self_metadata = { - key: getattr(provider_class, key) for key in valid_fields if getattr(provider_class, key, None) is not None - } - - if not self_metadata: + raw = {} + for key in valid_fields: + value = getattr(provider_class, key, None) + if value is None: + continue + raw[key] = value + + if not raw: return None - return SeedDatasetMetadata(**self_metadata) + + coerced = SeedDatasetMetadata._coerce_metadata_values(raw_metadata=raw) + # Validation must happen after coercion because raw values are strings/lists, + # not sets. _validate_singular_fields checks set cardinality (len > 1). + result = SeedDatasetMetadata(**coerced) + SeedDatasetMetadata._validate_singular_fields(metadata=result) + return result diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index 908e50d47d..e37f207666 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -5,6 +5,7 @@ import inspect import logging from abc import ABC, abstractmethod +from dataclasses import fields as dc_fields from typing import Any, Optional from tqdm import tqdm @@ -125,7 +126,7 @@ async def get_all_dataset_names_async(cls, filters: Optional[SeedDatasetFilter] if filters: # "all" bypasses metadata filtering and returns every dataset - if filters.tags and "all" in filters.tags: + if filters.has_all_tag: dataset_names.add(provider.dataset_name) continue @@ -145,87 +146,80 @@ async def get_all_dataset_names_async(cls, filters: Optional[SeedDatasetFilter] @classmethod def _match_filter_to_metadata(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter) -> bool: """ - Match filters against provider metadata. 
The filter is used to determine whether a match is found - against the metadat object. + Match a dataset's metadata against filter criteria. - When strict_match is False (default): - - Across dimensions (e.g. size + harm_categories): AND — all specified conditions must match. - - Within a dimension (e.g. sizes=["small", "large"]): OR — metadata needs to overlap with - at least one value. - - When strict_match is True: - - Across dimensions: AND (same as default). - - Within set-like dimensions (tags, harm_categories, modalities): AND — metadata must - contain ALL requested values, not just one. - - Within scalar dimensions (size, source_type, load_time): unchanged (membership check). + A dataset matches if ANY criterion in filters.criteria matches (OR across + criteria). Within each criterion, ALL specified fields must match (AND + across fields). Within each field: + - strict_match=False: any overlap suffices (set intersection) + - strict_match=True: all filter values must be present (filter is subset) Special tags: - - "all": bypasses all filtering, returns every dataset. - - "default": matches datasets that have tagged themselves as part of the curated set. + - "all": bypasses all filtering, returns True immediately. + - "default": without strict_match, matches if the dataset has "default" tag. Args: - metadata (SeedDatasetMetadata): The metadata object extracted from the SeedDatasetProvider - subclass. - filters (SeedDatasetFilter): The filter object provided by the user to get_all_dataset_names_async. + metadata: The dataset's metadata. + filters: The user-provided filter. Returns: - bool: Whether the filters match. + Whether the metadata matches any criterion. """ - # Tags - # "all" always bypasses all filtering, regardless of strict_match. 
- if filters.tags and "all" in filters.tags: + # "all" always bypasses + if filters.has_all_tag: return True - # "default" tag handling depends on strict_match: - # - Without strict_match: "default" alone is enough to match if the dataset - # has a "default" tag or an initialized load_time. - # - With strict_match: "default" is treated like any other tag. Every - # requested tag (including "default") must be present in the dataset - # to count as a match. + return any( + cls._match_single_criterion(metadata=metadata, criterion=c, strict_match=filters.strict_match) + for c in filters.criteria + ) + + @classmethod + def _match_single_criterion( + cls, + *, + metadata: SeedDatasetMetadata, + criterion: SeedDatasetMetadata, + strict_match: bool, + ) -> bool: + """ + Match a single SeedDatasetMetadata criterion against dataset metadata. + + Args: + metadata: The dataset's real metadata. + criterion: A single filter criterion. + strict_match: Whether to require all filter values (AND) vs any overlap (OR). + + Returns: + Whether the metadata satisfies this criterion. + """ + # "default" shortcut (only without strict_match): + # When the filter asks for "default" and the dataset has "default" in its + # tags, match immediately. This lets "default" act as a curated-set marker + # that bypasses other filter axes. With strict_match, "default" is treated + # as a normal tag and must satisfy the full subset check. 
if ( - not filters.strict_match - and filters.tags - and "default" in filters.tags + not strict_match + and criterion.tags + and "default" in criterion.tags and metadata.tags and "default" in metadata.tags ): return True - if metadata.tags and filters.tags: - if filters.strict_match: - # Reject if any requested tag is missing from the dataset - if filters.tags - metadata.tags: - return False - elif not (filters.tags & metadata.tags): - return False + for field in dc_fields(SeedDatasetMetadata): + filter_vals = getattr(criterion, field.name) + meta_vals = getattr(metadata, field.name) - # Size - if metadata.size and filters.sizes and metadata.size not in filters.sizes: - return False + if filter_vals is None or meta_vals is None: + continue - # Harm Categories - if metadata.harm_categories and filters.harm_categories: - if filters.strict_match: - if not set(filters.harm_categories) <= set(metadata.harm_categories): + if strict_match: + if filter_vals - meta_vals: return False - elif not set(metadata.harm_categories) & set(filters.harm_categories): - return False - - # Source Type - if metadata.source_type and filters.source_types and metadata.source_type not in filters.source_types: - return False - - # Modalities - if metadata.modalities and filters.modalities: - if filters.strict_match: - if not set(filters.modalities) <= set(metadata.modalities): + else: + if not (filter_vals & meta_vals): return False - elif not set(metadata.modalities) & set(filters.modalities): - return False - - # Load Time - if metadata.load_time and filters.load_times and metadata.load_time not in filters.load_times: # noqa: SIM103 - return False return True diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index 0facfcc667..369b8a0c74 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -2,11 +2,9 @@ # Licensed under the MIT license. 
import logging -from dataclasses import dataclass +from dataclasses import dataclass, fields from enum import Enum -from typing import Any, Literal, Optional - -from pyrit.models.literals import PromptDataType +from typing import Any, ClassVar, Literal, Optional logger = logging.getLogger(__name__) @@ -14,20 +12,15 @@ """ Contains metadata objects for datasets (i.e. subclasses of SeedDatasetProvider). -SeedDatasetMetadata is the internal schema used to normalize metadata fields -from different sources: -- Remote providers that declare metadata as class attributes -- Local prompt files that store metadata at the top level - -SeedDatasetFilter is the user-facing filter schema consumed by -SeedDatasetProvider.get_all_dataset_names_async(). +SeedDatasetMetadata is the canonical schema for dataset metadata. All fields are +optional sets. Parsers wrap singular values into single-element sets; filters may +have multiple values per field for OR matching. -Size and modality are string literals rather than enums for usability — callers -don't need to import extra types to construct a filter. +SeedDatasetFilter accepts either flat kwargs (simple use) or a list of +SeedDatasetMetadata criteria (composable use). Internally it always stores +criteria as list[SeedDatasetMetadata]. """ -# Documented expected values for string-typed metadata fields. -# These are not enforced at runtime but serve as documentation. SeedDatasetSizeCategory = Literal["tiny", "small", "medium", "large", "huge"] # tiny (<10), small (10-99), medium (100-499), large (500-4999), huge (5000+) @@ -35,130 +28,202 @@ class SeedDatasetLoadTime(Enum): - """ - Approximate time to load a dataset. Used to skip slow datasets in fast runs. - """ + """Approximate time to load a dataset. Used to skip slow datasets in fast runs.""" FAST = "fast" NORMAL = "normal" SLOW = "slow" - - # Default value for datasets whose load time hasn't been measured. 
UNINITIALIZED = "uninitialized" -@dataclass -class SeedDatasetFilter: - """ - - Filter object for datasets. Passed to `get_all_dataset_names_async` in - SeedDatasetProvider. - - Most fields are optional. None means "don't filter on this axis." - - By default, filtering is AND across categories (all must match) and - OR within categories (any overlap is sufficient). - - Setting strict_match=True changes within-category behavior to AND - for set-like fields (tags, harm_categories, modalities). - - Special tag behavior: - - "all": A magic bypass that returns every discoverable dataset. When "all" - is present, ALL other filter fields and strict_match are ignored. This - operates at the get_all_dataset_names_async level — _match_filter is not - even called for datasets without metadata. - - "default": Matches datasets that have "default" in their tags or have an - initialized load_time. With strict_match=True, "default" loses its - special shortcut behavior and is treated as a normal tag. - """ - - tags: Optional[set[str]] = None - sizes: Optional[list[str]] = None - modalities: Optional[list[PromptDataType]] = None - source_types: Optional[list[SeedDatasetSourceType]] = None - load_times: Optional[list[SeedDatasetLoadTime]] = None - harm_categories: Optional[list[str]] = None - - # Setting this to True forces AND-wise filtering within set-like categories. - # "all" tag still bypasses everything regardless of this flag. - strict_match: bool = False - - def __post_init__(self) -> None: - """Validate filter configuration.""" - if self.tags and "all" in self.tags and len(self.tags) > 1: - logger.warning( - "Filter has 'all' combined with other tags %s. " - "'all' bypasses all filtering — other tags will be ignored.", - self.tags - {"all"}, - ) - if self.tags and "all" in self.tags and self.strict_match: - logger.warning( - "Filter has 'all' with strict_match=True. 'all' bypasses all filtering — strict_match has no effect." 
- ) - if ( - self.tags - and "all" in self.tags - and any( - f is not None - for f in [self.sizes, self.modalities, self.source_types, self.load_times, self.harm_categories] - ) - ): - logger.warning( - "Filter has 'all' combined with other filter fields. " - "'all' bypasses all filtering — other fields will be ignored." - ) - - @dataclass(frozen=True) class SeedDatasetMetadata: """ - Internal schema for dataset metadata. Constructed by _parse_metadata() - implementations on each provider type. + Unified schema for dataset metadata and filter criteria. + + All fields are optional sets. When used for real dataset metadata, parsers + wrap singular values into single-element sets. When used as filter criteria, + multiple values per field express "match any of these" (OR within axis). """ + # All fields are optional sets to support both real metadata (single-element) + # and filter criteria (multi-element). SINGULAR_FIELDS enforces that parsers + # only produce single-element sets for size and source_type. tags: Optional[set[str]] = None - size: Optional[SeedDatasetSizeCategory] = None - modalities: Optional[list[PromptDataType]] = None - source_type: Optional[SeedDatasetSourceType] = None - load_time: SeedDatasetLoadTime = SeedDatasetLoadTime.UNINITIALIZED - harm_categories: Optional[list[str]] = None + size: Optional[set[str]] = None + modalities: Optional[set[str]] = None + source_type: Optional[set[str]] = None + load_time: Optional[set[SeedDatasetLoadTime]] = None + harm_categories: Optional[set[str]] = None + + # Fields that must have at most 1 element in real dataset metadata. + SINGULAR_FIELDS: ClassVar[frozenset[str]] = frozenset({"size", "source_type"}) @staticmethod def _coerce_metadata_values(*, raw_metadata: dict[str, Any]) -> dict[str, Any]: """ - Convert YAML primitive values into the types expected by SeedDatasetMetadata. + Convert raw values (from YAML or class attributes) into sets for SeedDatasetMetadata. 
- Applies .lower().strip() normalization to string values for tags, load_time, size, modalities, - source_type, and harm_categories to prevent case/whitespace mismatches. + Applies .lower().strip() normalization to all string values. Handles str, + list, set inputs for all fields, plus SeedDatasetLoadTime enum for load_time. Args: - raw_metadata (dict[str, Any]): Dictionary of field names to raw YAML-parsed values. + raw_metadata: Dictionary of field names to raw values. Returns: - dict[str, Any]: Dictionary with values coerced to the correct types. + Dictionary with all values coerced to sets. """ coerced: dict[str, Any] = {} for key, value in raw_metadata.items(): - if key == "tags" and isinstance(value, list): + if key == "load_time": + if isinstance(value, str): + coerced[key] = {SeedDatasetLoadTime(value.strip().lower())} + elif isinstance(value, SeedDatasetLoadTime): + coerced[key] = {value} + else: + logger.warning( + f"Skipping metadata field '{key}' with unexpected type " + f"{type(value).__name__} (value: {value!r})" + ) + elif isinstance(value, (list, set)): coerced[key] = {v.strip().lower() if isinstance(v, str) else v for v in value} - elif key == "tags" and isinstance(value, str): + elif isinstance(value, str): coerced[key] = {value.strip().lower()} - elif key == "size" and isinstance(value, str) or key == "source_type" and isinstance(value, str): - coerced[key] = value.strip().lower() - elif key == "load_time" and isinstance(value, str): - coerced[key] = SeedDatasetLoadTime(value.strip().lower()) - elif key == "modalities" and isinstance(value, list): - coerced[key] = [v.strip().lower() if isinstance(v, str) else v for v in value] - elif key == "modalities" and isinstance(value, str): - coerced[key] = [value.strip().lower()] - elif key == "harm_categories" and isinstance(value, list): - coerced[key] = [v.strip().lower() if isinstance(v, str) else v for v in value] - elif key == "harm_categories" and isinstance(value, str): - coerced[key] = 
[value.strip().lower()] else: - # Unexpected type for a metadata field — skip it with a warning - # rather than passing garbage into SeedDatasetMetadata. logger.warning( f"Skipping metadata field '{key}' with unexpected type {type(value).__name__} (value: {value!r})" ) return coerced + + @staticmethod + def _validate_singular_fields(*, metadata: "SeedDatasetMetadata") -> None: + """ + Validate that singular fields have at most 1 element. + + Call this from parsers when constructing real dataset metadata, NOT when + constructing filter criteria where multiple values are valid. + + Raises: + ValueError: If a singular field has more than 1 element. + """ + for field_name in SeedDatasetMetadata.SINGULAR_FIELDS: + value = getattr(metadata, field_name) + if value is not None and len(value) > 1: + raise ValueError( + f"Metadata field '{field_name}' must have at most 1 value " + f"for real dataset metadata, got {len(value)}: {value}" + ) + + +class SeedDatasetFilter: + """ + Filter for discovering datasets by metadata criteria. + + Supports two construction patterns: + + Simple (flat kwargs — wraps into a single SeedDatasetMetadata criterion):: + + f = SeedDatasetFilter(tags={"safety"}, size={"small", "large"}) + + Composed (explicit criteria list — OR across criteria, AND within each):: + + f = SeedDatasetFilter(criteria=[ + SeedDatasetMetadata(size={"small"}, modalities={"text"}), + SeedDatasetMetadata(size={"large"}, modalities={"image"}), + ]) + + Passing both flat kwargs and criteria raises ValueError. + + Special tags: + - "all": Magic bypass. Returns every dataset, ignores all other fields. + - "default": Matches datasets with "default" in their tags. With + strict_match=True, loses its shortcut and is treated as a normal tag. + + Args: + criteria: Explicit list of SeedDatasetMetadata to OR-match against. + strict_match: If True, within-axis matching uses AND (all filter values + must be present) instead of OR (any overlap suffices). 
+ **kwargs: Flat metadata fields (tags, size, modalities, etc.) for simple use. + """ + + def __init__( + self, + *, + criteria: Optional[list[SeedDatasetMetadata]] = None, + strict_match: bool = False, + **kwargs: Any, + ) -> None: + """ + Construct a filter from flat metadata kwargs or an explicit criteria list. + + Simple usage (flat kwargs — wraps into a single SeedDatasetMetadata):: + + f = SeedDatasetFilter(tags={"safety"}, size={"small", "large"}) + + Composed usage (explicit criteria — OR across criteria, AND within each):: + + f = SeedDatasetFilter(criteria=[ + SeedDatasetMetadata(size={"small"}, modalities={"text"}), + SeedDatasetMetadata(size={"large"}, modalities={"image"}), + ]) + + Args: + criteria: Explicit list of SeedDatasetMetadata to OR-match against. + strict_match: If True, within-axis matching uses AND instead of OR. + **kwargs: Flat metadata fields passed to SeedDatasetMetadata. + + Raises: + ValueError: If both criteria and flat kwargs are provided. + """ + if criteria is not None and kwargs: + raise ValueError("Cannot pass both 'criteria' and flat metadata kwargs. Use one or the other.") + + if criteria is not None: + self.criteria = criteria + elif kwargs: + self.criteria = [SeedDatasetMetadata(**kwargs)] + else: + self.criteria = [SeedDatasetMetadata()] + + self.strict_match = strict_match + self._validate() + + def _validate(self) -> None: + """Warn about contradictory filter configurations.""" + if not self.has_all_tag: + return + + all_criterion = next(c for c in self.criteria if c.tags and "all" in c.tags) + + # Only lowercase "all" is accepted; reject "All", "ALL", etc. + non_lowercase = {t for t in all_criterion.tags if t != "all" and t.lower() == "all"} + if non_lowercase: + logger.warning( + "Filter has non-lowercase 'all' variants %s. 
Only lowercase 'all' is recognized as the bypass tag.", + non_lowercase, + ) + + if all_criterion.tags and len(all_criterion.tags) > 1: + logger.warning( + "Filter has 'all' combined with other tags %s. " + "'all' bypasses all filtering — other tags will be ignored.", + all_criterion.tags - {"all"}, + ) + if self.strict_match: + logger.warning( + "Filter has 'all' with strict_match=True. 'all' bypasses all filtering — strict_match has no effect." + ) + other_fields = [ + f.name for f in fields(all_criterion) if f.name != "tags" and getattr(all_criterion, f.name) is not None + ] + if other_fields: + logger.warning( + "Filter has 'all' combined with other filter fields %s. " + "'all' bypasses all filtering — other fields will be ignored.", + other_fields, + ) + + @property + def has_all_tag(self) -> bool: + """True if any criterion has the 'all' tag.""" + return any(c.tags and "all" in c.tags for c in self.criteria) diff --git a/tests/unit/datasets/test_seed_dataset_metadata.py b/tests/unit/datasets/test_seed_dataset_metadata.py index f3635b7570..b5cec477c8 100644 --- a/tests/unit/datasets/test_seed_dataset_metadata.py +++ b/tests/unit/datasets/test_seed_dataset_metadata.py @@ -5,6 +5,8 @@ Tests for metadata components related to SeedDatasetProvider. """ +import pytest + from pyrit.datasets.seed_datasets.seed_metadata import ( SeedDatasetFilter, SeedDatasetLoadTime, @@ -13,10 +15,7 @@ class TestMetadataLifecycle: - """ - Test that the metadata object can be created with different - subsets of values. 
- """ + """Test that the metadata object can be created with different subsets of values.""" def test_has_no_values(self): metadata = SeedDatasetMetadata() @@ -24,115 +23,100 @@ def test_has_no_values(self): assert metadata.size is None assert metadata.modalities is None assert metadata.source_type is None - assert metadata.load_time == SeedDatasetLoadTime.UNINITIALIZED + assert metadata.load_time is None assert metadata.harm_categories is None def test_has_some_values(self): - metadata = SeedDatasetMetadata(tags={"safety"}, size="large") + metadata = SeedDatasetMetadata(tags={"safety"}, size={"large"}) assert metadata.tags == {"safety"} - assert metadata.size == "large" + assert metadata.size == {"large"} assert metadata.modalities is None - assert metadata.source_type is None - assert metadata.load_time == SeedDatasetLoadTime.UNINITIALIZED - assert metadata.harm_categories is None def test_has_all_values(self): metadata = SeedDatasetMetadata( tags={"default", "safety"}, - size="medium", - modalities=["text", "image"], - source_type="remote", - load_time=SeedDatasetLoadTime.FAST, - harm_categories=["violence", "illegal"], + size={"medium"}, + modalities={"text", "image"}, + source_type={"remote"}, + load_time={SeedDatasetLoadTime.FAST}, + harm_categories={"violence", "illegal"}, ) assert metadata.tags == {"default", "safety"} - assert metadata.size == "medium" + assert metadata.size == {"medium"} assert len(metadata.modalities) == 2 - assert metadata.source_type == "remote" - assert metadata.load_time == SeedDatasetLoadTime.FAST - assert metadata.harm_categories == ["violence", "illegal"] + assert metadata.source_type == {"remote"} + assert SeedDatasetLoadTime.FAST in metadata.load_time + assert metadata.harm_categories == {"violence", "illegal"} class TestFilterLifecycle: - """ - Test that the filter object can be created with different - subsets of values. 
- """ + """Test that the filter object wraps metadata correctly.""" def test_has_no_values(self): f = SeedDatasetFilter() - assert f.tags is None - assert f.sizes is None - assert f.modalities is None - assert f.source_types is None - assert f.load_times is None - assert f.harm_categories is None + c = f.criteria[0] + assert c.tags is None + assert c.size is None def test_has_some_values(self): - f = SeedDatasetFilter(sizes=["large"]) - assert f.sizes == ["large"] - assert f.tags is None - assert f.modalities is None + f = SeedDatasetFilter(size={"large"}) + assert f.criteria[0].size == {"large"} + assert f.criteria[0].tags is None def test_has_all_values(self): f = SeedDatasetFilter( tags={"default"}, - sizes=["small", "medium"], - modalities=["text"], - source_types=["remote"], - load_times=[SeedDatasetLoadTime.FAST], - harm_categories=["violence"], + size={"small", "medium"}, + modalities={"text"}, + source_type={"remote"}, + load_time={SeedDatasetLoadTime.FAST}, + harm_categories={"violence"}, ) - assert f.tags == {"default"} - assert len(f.sizes) == 2 - assert f.modalities == ["text"] - assert f.source_types == ["remote"] - assert f.load_times == [SeedDatasetLoadTime.FAST] - assert f.harm_categories == ["violence"] + c = f.criteria[0] + assert c.tags == {"default"} + assert len(c.size) == 2 + assert c.modalities == {"text"} + + def test_filter_allows_multiple_sizes(self): + """Filters can have multiple values for singular fields like size.""" + f = SeedDatasetFilter(size={"small", "medium", "large"}) + assert len(f.criteria[0].size) == 3 class TestMetadataProperties: - """ - Test that the metadata fields populate correctly. 
- """ + """Test that the metadata fields populate correctly.""" def test_size_value(self): for size in ["tiny", "small", "medium", "large", "huge"]: - metadata = SeedDatasetMetadata(size=size) - assert metadata.size == size + metadata = SeedDatasetMetadata(size={size}) + assert size in metadata.size def test_load_time_value(self): for lt in SeedDatasetLoadTime: - metadata = SeedDatasetMetadata(load_time=lt) - assert metadata.load_time == lt + metadata = SeedDatasetMetadata(load_time={lt}) + assert lt in metadata.load_time def test_source_value(self): for source_type in ["remote", "local"]: - metadata = SeedDatasetMetadata(source_type=source_type) - assert metadata.source_type == source_type + metadata = SeedDatasetMetadata(source_type={source_type}) + assert source_type in metadata.source_type def test_modality_value(self): for modality in ["text", "image", "video", "audio"]: - metadata = SeedDatasetMetadata(modalities=[modality]) + metadata = SeedDatasetMetadata(modalities={modality}) assert modality in metadata.modalities def test_tags_value(self): metadata = SeedDatasetMetadata(tags={"safety", "default", "custom"}) assert "safety" in metadata.tags - assert "default" in metadata.tags - assert "custom" in metadata.tags def test_harm_categories_value(self): - metadata = SeedDatasetMetadata(harm_categories=["violence", "cybercrime"]) + metadata = SeedDatasetMetadata(harm_categories={"violence", "cybercrime"}) assert "violence" in metadata.harm_categories - assert "cybercrime" in metadata.harm_categories class TestMetadataCoercion: - """ - Test that _coerce_metadata_values correctly normalizes raw YAML - values into the types expected by SeedDatasetMetadata. 
- """ + """Test that _coerce_metadata_values normalizes raw values into sets.""" def test_tags_list_coerced_to_set(self): result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"tags": ["safety", "default"]}) @@ -142,84 +126,112 @@ def test_tags_list_coerced_to_set(self): def test_tags_string_coerced_to_set(self): result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"tags": "safety"}) assert result["tags"] == {"safety"} - assert isinstance(result["tags"], set) def test_tags_normalized_lower_strip(self): result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"tags": [" Safety ", " DEFAULT"]}) assert result["tags"] == {"safety", "default"} - def test_size_coerced_to_lowercase_string(self): + def test_size_coerced_to_set(self): result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"size": " Large "}) - assert result["size"] == "large" + assert result["size"] == {"large"} + assert isinstance(result["size"], set) - def test_source_type_coerced_to_lowercase_string(self): + def test_source_type_coerced_to_set(self): result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"source_type": " Remote "}) - assert result["source_type"] == "remote" + assert result["source_type"] == {"remote"} - def test_load_time_coerced_to_enum(self): + def test_load_time_coerced_to_enum_set(self): result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"load_time": "fast"}) - assert result["load_time"] == SeedDatasetLoadTime.FAST + assert result["load_time"] == {SeedDatasetLoadTime.FAST} + assert isinstance(result["load_time"], set) def test_load_time_normalized_strip_lower(self): result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"load_time": " Slow "}) - assert result["load_time"] == SeedDatasetLoadTime.SLOW + assert result["load_time"] == {SeedDatasetLoadTime.SLOW} - def test_modalities_list_coerced_lowercase(self): + def test_modalities_list_coerced_to_set(self): result = 
SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"modalities": ["Text", " IMAGE "]}) - assert result["modalities"] == ["text", "image"] + assert result["modalities"] == {"text", "image"} - def test_modalities_string_coerced_to_list(self): + def test_modalities_string_coerced_to_set(self): result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"modalities": "text"}) - assert result["modalities"] == ["text"] + assert result["modalities"] == {"text"} - def test_harm_categories_list_coerced_lowercase(self): + def test_harm_categories_list_coerced_to_set(self): result = SeedDatasetMetadata._coerce_metadata_values( raw_metadata={"harm_categories": ["Violence", " Cybercrime "]} ) - assert result["harm_categories"] == ["violence", "cybercrime"] + assert result["harm_categories"] == {"violence", "cybercrime"} - def test_harm_categories_string_coerced_to_list(self): + def test_harm_categories_string_coerced_to_set(self): result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"harm_categories": "violence"}) - assert result["harm_categories"] == ["violence"] + assert result["harm_categories"] == {"violence"} def test_unknown_type_skipped_with_warning(self, caplog): - """Unexpected types are dropped and logged, not passed through.""" result = SeedDatasetMetadata._coerce_metadata_values(raw_metadata={"tags": 12345}) assert "tags" not in result assert "Skipping metadata field" in caplog.text +class TestSingularFieldValidation: + """Test that singular fields (size, source_type) are validated.""" + + def test_singular_size_passes(self): + metadata = SeedDatasetMetadata(size={"large"}) + SeedDatasetMetadata._validate_singular_fields(metadata=metadata) + + def test_singular_source_type_passes(self): + metadata = SeedDatasetMetadata(source_type={"remote"}) + SeedDatasetMetadata._validate_singular_fields(metadata=metadata) + + def test_multiple_sizes_fails(self): + metadata = SeedDatasetMetadata(size={"small", "large"}) + with pytest.raises(ValueError, 
match="size"): + SeedDatasetMetadata._validate_singular_fields(metadata=metadata) + + def test_multiple_source_types_fails(self): + metadata = SeedDatasetMetadata(source_type={"remote", "local"}) + with pytest.raises(ValueError, match="source_type"): + SeedDatasetMetadata._validate_singular_fields(metadata=metadata) + + def test_none_fields_pass(self): + metadata = SeedDatasetMetadata() + SeedDatasetMetadata._validate_singular_fields(metadata=metadata) + + def test_multi_value_non_singular_fields_pass(self): + """Tags, modalities, harm_categories can have multiple values.""" + metadata = SeedDatasetMetadata( + tags={"safety", "default"}, + modalities={"text", "image"}, + harm_categories={"violence", "cybercrime"}, + ) + SeedDatasetMetadata._validate_singular_fields(metadata=metadata) + + class TestFilterProperties: - """ - Test that the filter fields populate correctly. - """ + """Test that the filter fields populate correctly via flat kwargs.""" def test_sizes_values(self): - f = SeedDatasetFilter(sizes=["small", "large"]) - assert "small" in f.sizes - assert "large" in f.sizes + f = SeedDatasetFilter(size={"small", "large"}) + assert "small" in f.criteria[0].size + assert "large" in f.criteria[0].size def test_load_times_values(self): - f = SeedDatasetFilter(load_times=[SeedDatasetLoadTime.FAST, SeedDatasetLoadTime.SLOW]) - assert SeedDatasetLoadTime.FAST in f.load_times - assert SeedDatasetLoadTime.SLOW in f.load_times + f = SeedDatasetFilter(load_time={SeedDatasetLoadTime.FAST, SeedDatasetLoadTime.SLOW}) + assert SeedDatasetLoadTime.FAST in f.criteria[0].load_time def test_sources_values(self): - f = SeedDatasetFilter(source_types=["local", "remote"]) - assert "local" in f.source_types - assert "remote" in f.source_types + f = SeedDatasetFilter(source_type={"local", "remote"}) + assert "local" in f.criteria[0].source_type def test_modalities_values(self): - f = SeedDatasetFilter(modalities=["text", "image"]) - assert "text" in f.modalities - assert "image" in 
f.modalities + f = SeedDatasetFilter(modalities={"text", "image"}) + assert "text" in f.criteria[0].modalities def test_tags_values(self): f = SeedDatasetFilter(tags={"safety", "default"}) - assert "safety" in f.tags - assert "default" in f.tags + assert "safety" in f.criteria[0].tags def test_harm_categories_values(self): - f = SeedDatasetFilter(harm_categories=["violence", "cybercrime"]) - assert "violence" in f.harm_categories - assert "cybercrime" in f.harm_categories + f = SeedDatasetFilter(harm_categories={"violence", "cybercrime"}) + assert "violence" in f.criteria[0].harm_categories diff --git a/tests/unit/datasets/test_seed_dataset_provider.py b/tests/unit/datasets/test_seed_dataset_provider.py index a430020dcc..5b5714c556 100644 --- a/tests/unit/datasets/test_seed_dataset_provider.py +++ b/tests/unit/datasets/test_seed_dataset_provider.py @@ -266,13 +266,13 @@ async def test_parse_metadata_from_class_attrs(self): metadata = await loader._parse_metadata() assert metadata is not None assert metadata.tags == {"default", "safety"} - assert metadata.size == "large" - assert metadata.modalities == ["text"] - assert metadata.harm_categories == ["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"] + assert metadata.size == {"large"} + assert metadata.modalities == {"text"} + assert metadata.harm_categories == {"cybercrime", "illegal", "harmful", "chemical_biological", "harassment"} # source_type is not declared as a class attribute on HarmBench; # load_time inherits the UNINITIALIZED default from SeedDatasetProvider base class assert metadata.source_type is None - assert metadata.load_time == SeedDatasetLoadTime.UNINITIALIZED + assert metadata.load_time == {SeedDatasetLoadTime.UNINITIALIZED} def test_all_tag(self): """Filter with tags={'all'} matches any metadata.""" @@ -292,62 +292,62 @@ def test_tags(self): def test_sizes(self): """Size filter checks membership in the sizes list.""" - metadata = SeedDatasetMetadata(size="large") + metadata = 
SeedDatasetMetadata(size={"large"}) assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(sizes=["large", "huge"]), + filters=SeedDatasetFilter(size={"large", "huge"}), ) assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(sizes=["small"]), + filters=SeedDatasetFilter(size={"small"}), ) def test_modalities(self): """Modality filter uses set intersection.""" - metadata = SeedDatasetMetadata(modalities=["text", "image"]) + metadata = SeedDatasetMetadata(modalities={"text", "image"}) assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(modalities=["text"]), + filters=SeedDatasetFilter(modalities={"text"}), ) assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(modalities=["audio"]), + filters=SeedDatasetFilter(modalities={"audio"}), ) def test_sources(self): """Source filter checks membership.""" - metadata = SeedDatasetMetadata(source_type="remote") + metadata = SeedDatasetMetadata(source_type={"remote"}) assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(source_types=["remote"]), + filters=SeedDatasetFilter(source_type={"remote"}), ) assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(source_types=["local"]), + filters=SeedDatasetFilter(source_type={"local"}), ) def test_ranks(self): """Load time filter checks membership.""" - metadata = SeedDatasetMetadata(load_time=SeedDatasetLoadTime.FAST) + metadata = SeedDatasetMetadata(load_time={SeedDatasetLoadTime.FAST}) assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(load_times=[SeedDatasetLoadTime.FAST]), + filters=SeedDatasetFilter(load_time={SeedDatasetLoadTime.FAST}), ) assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - 
filters=SeedDatasetFilter(load_times=[SeedDatasetLoadTime.SLOW]), + filters=SeedDatasetFilter(load_time={SeedDatasetLoadTime.SLOW}), ) def test_harm_categories(self): """Harm category filter uses set intersection.""" - metadata = SeedDatasetMetadata(harm_categories=["violence", "cybercrime"]) + metadata = SeedDatasetMetadata(harm_categories={"violence", "cybercrime"}) assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(harm_categories=["violence"]), + filters=SeedDatasetFilter(harm_categories={"violence"}), ) assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(harm_categories=["unrelated"]), + filters=SeedDatasetFilter(harm_categories={"unrelated"}), ) def test_empty_filter(self): @@ -392,38 +392,38 @@ def test_nonstrict_tags_partial_overlap_passes(self): def test_strict_harm_categories_all_present_matches(self): """strict_match requires ALL filter harm_categories present in metadata.""" - metadata = SeedDatasetMetadata(harm_categories=["violence", "cybercrime", "illegal"]) - filters = SeedDatasetFilter(harm_categories=["violence", "cybercrime"], strict_match=True) + metadata = SeedDatasetMetadata(harm_categories={"violence", "cybercrime", "illegal"}) + filters = SeedDatasetFilter(harm_categories={"violence", "cybercrime"}, strict_match=True) assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) def test_strict_harm_categories_partial_fails(self): """strict_match rejects if metadata is missing any requested harm category.""" - metadata = SeedDatasetMetadata(harm_categories=["violence"]) - filters = SeedDatasetFilter(harm_categories=["violence", "cybercrime"], strict_match=True) + metadata = SeedDatasetMetadata(harm_categories={"violence"}) + filters = SeedDatasetFilter(harm_categories={"violence", "cybercrime"}, strict_match=True) assert not SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) def 
test_strict_modalities_all_present_matches(self): """strict_match requires ALL filter modalities present in metadata.""" - metadata = SeedDatasetMetadata(modalities=["text", "image", "audio"]) - filters = SeedDatasetFilter(modalities=["text", "image"], strict_match=True) + metadata = SeedDatasetMetadata(modalities={"text", "image", "audio"}) + filters = SeedDatasetFilter(modalities={"text", "image"}, strict_match=True) assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) def test_strict_modalities_partial_fails(self): """strict_match rejects if metadata is missing any requested modality.""" - metadata = SeedDatasetMetadata(modalities=["text"]) - filters = SeedDatasetFilter(modalities=["text", "image"], strict_match=True) + metadata = SeedDatasetMetadata(modalities={"text"}) + filters = SeedDatasetFilter(modalities={"text", "image"}, strict_match=True) assert not SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) def test_strict_size_unchanged(self): """strict_match doesn't change size behavior — still membership check.""" - metadata = SeedDatasetMetadata(size="large") + metadata = SeedDatasetMetadata(size={"large"}) assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(sizes=["large"], strict_match=True), + filters=SeedDatasetFilter(size={"large"}, strict_match=True), ) assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(sizes=["small"], strict_match=True), + filters=SeedDatasetFilter(size={"small"}, strict_match=True), ) def test_strict_cross_axis_and(self): @@ -431,14 +431,14 @@ def test_strict_cross_axis_and(self): metadata = SeedDatasetMetadata( tags={"safety", "default"}, size="large", - harm_categories=["violence", "cybercrime"], + harm_categories={"violence", "cybercrime"}, ) # Both axes satisfied assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, 
filters=SeedDatasetFilter( tags={"safety"}, - harm_categories=["violence"], + harm_categories={"violence"}, strict_match=True, ), ) @@ -447,7 +447,7 @@ def test_strict_cross_axis_and(self): metadata=metadata, filters=SeedDatasetFilter( tags={"safety"}, - harm_categories=["violence", "illegal"], + harm_categories={"violence", "illegal"}, strict_match=True, ), ) @@ -511,7 +511,7 @@ def test_all_with_other_tags_warns(self, caplog): def test_all_with_other_fields_warns(self, caplog): """'all' combined with size/modality/etc logs a warning.""" - SeedDatasetFilter(tags={"all"}, sizes=["large"]) + SeedDatasetFilter(tags={"all"}, size={"large"}) assert "other fields will be ignored" in caplog.text def test_all_alone_no_warning(self, caplog): @@ -525,15 +525,15 @@ def test_all_bypasses_match_filter_entirely(self): metadata = SeedDatasetMetadata( tags={"unrelated"}, size="tiny", - modalities=["audio"], - harm_categories=["nothing"], + modalities={"audio"}, + harm_categories={"nothing"}, ) # Filter that would normally reject everything about this metadata filters = SeedDatasetFilter( tags={"all"}, - sizes=["huge"], - modalities=["text"], - harm_categories=["violence"], + size={"huge"}, + modalities={"text"}, + harm_categories={"violence"}, strict_match=True, ) assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) @@ -609,7 +609,7 @@ async def test_parse_metadata_extracts_fields(self, tmp_path): loader = self._make_loader(yaml_path) metadata = await loader._parse_metadata() assert metadata is not None - assert metadata.harm_categories == ["violence"] + assert metadata.harm_categories == {"violence"} @pytest.mark.asyncio async def test_all_tag(self, tmp_path): @@ -673,7 +673,7 @@ async def test_sizes(self, tmp_path): loader = self._make_loader(yaml_path) metadata = await loader._parse_metadata() assert metadata is not None - filters = SeedDatasetFilter(sizes=["large"]) + filters = SeedDatasetFilter(size={"large"}) assert 
SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) @pytest.mark.asyncio @@ -694,7 +694,7 @@ async def test_modalities(self, tmp_path): loader = self._make_loader(yaml_path) metadata = await loader._parse_metadata() assert metadata is not None - filters = SeedDatasetFilter(modalities=["text"]) + filters = SeedDatasetFilter(modalities={"text"}) assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) @pytest.mark.asyncio @@ -714,7 +714,7 @@ async def test_sources(self, tmp_path): loader = self._make_loader(yaml_path) metadata = await loader._parse_metadata() assert metadata is not None - filters = SeedDatasetFilter(source_types=["remote"]) + filters = SeedDatasetFilter(source_type={"remote"}) assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) @pytest.mark.asyncio @@ -734,7 +734,7 @@ async def test_ranks(self, tmp_path): loader = self._make_loader(yaml_path) metadata = await loader._parse_metadata() assert metadata is not None - filters = SeedDatasetFilter(load_times=[SeedDatasetLoadTime.FAST]) + filters = SeedDatasetFilter(load_time={SeedDatasetLoadTime.FAST}) assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) @pytest.mark.asyncio @@ -756,7 +756,7 @@ async def test_harm_categories(self, tmp_path): loader = self._make_loader(yaml_path) metadata = await loader._parse_metadata() assert metadata is not None - filters = SeedDatasetFilter(harm_categories=["violence"]) + filters = SeedDatasetFilter(harm_categories={"violence"}) assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) @pytest.mark.asyncio @@ -858,11 +858,11 @@ def test_no_yaml_key_shadows_metadata_field_with_wrong_type(self, prompt_file): # Verify coerced types match expectations expected_types = { "tags": (set, type(None)), - "size": (str, type(None)), - "modalities": (list, type(None)), - "source_type": (str, type(None)), - "load_time": 
(SeedDatasetLoadTime, type(None)), - "harm_categories": (list, type(None)), + "size": (set, type(None)), + "modalities": (set, type(None)), + "source_type": (set, type(None)), + "load_time": (set, type(None)), + "harm_categories": (set, type(None)), } for key in overlapping_keys: value = getattr(metadata, key) From dd8a8a2749a717b74044b8292fef0e809a41dfbb Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 19 Mar 2026 16:30:06 +0000 Subject: [PATCH 24/29] . --- .../seed_datasets/seed_dataset_provider.py | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index e37f207666..7738cafcfa 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -140,11 +140,12 @@ async def get_all_dataset_names_async(cls, filters: Optional[SeedDatasetFilter] dataset_names.add(provider.dataset_name) except Exception as e: - raise ValueError(f"Could not get dataset name from {provider_class.__name__}: {e}") from e + raise ValueError( + f"Could not get dataset name from {provider_class.__name__}: {e}") from e return sorted(dataset_names) @classmethod - def _match_filter_to_metadata(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter) -> bool: + def _match_filter_to_metadata(cls, metadata: SeedDatasetMetadata, dataset_filter: SeedDatasetFilter) -> bool: """ Match a dataset's metadata against filter criteria. @@ -160,18 +161,19 @@ def _match_filter_to_metadata(cls, metadata: SeedDatasetMetadata, filters: SeedD Args: metadata: The dataset's metadata. - filters: The user-provided filter. + dataset_filter: The user-provided filter. Returns: Whether the metadata matches any criterion. 
""" # "all" always bypasses - if filters.has_all_tag: + if dataset_filter.has_all_tag: return True return any( - cls._match_single_criterion(metadata=metadata, criterion=c, strict_match=filters.strict_match) - for c in filters.criteria + cls._match_single_criterion( + metadata=metadata, criterion=c, strict_match=dataset_filter.strict_match) + for c in dataset_filter.criteria ) @classmethod @@ -263,9 +265,11 @@ async def fetch_datasets_async( # Validate dataset names if specified if dataset_names is not None: available_names = await cls.get_all_dataset_names_async() - invalid_names = [name for name in dataset_names if name not in available_names] + invalid_names = [ + name for name in dataset_names if name not in available_names] if invalid_names: - raise ValueError(f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") + raise ValueError( + f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") async def fetch_single_dataset( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -291,7 +295,8 @@ async def fetch_single_dataset( # Progress tracking total_count = len(cls._registry) - pbar = tqdm(total=total_count, desc="Loading datasets - this can take a few minutes", unit="dataset") + pbar = tqdm(total=total_count, + desc="Loading datasets - this can take a few minutes", unit="dataset") async def fetch_with_semaphore( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -329,10 +334,12 @@ async def fetch_with_semaphore( logger.info(f"Merging multiple sources for {dataset_name}.") existing_dataset = datasets[dataset_name] - combined_seeds = list(existing_dataset.seeds) + list(dataset.seeds) + combined_seeds = list( + existing_dataset.seeds) + list(dataset.seeds) existing_dataset.seeds = combined_seeds else: datasets[dataset_name] = dataset - logger.info(f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") + logger.info( + f"Successfully fetched 
{len(datasets)} unique datasets from {len(cls._registry)} providers") return list(datasets.values()) From a23075ebbce2ad235bc606ab29998f118ae1fe1d Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 19 Mar 2026 18:13:26 +0000 Subject: [PATCH 25/29] notebook --- doc/cookbooks/1_sending_prompts.ipynb | 522 +++++++++++++++----------- 1 file changed, 307 insertions(+), 215 deletions(-) diff --git a/doc/cookbooks/1_sending_prompts.ipynb b/doc/cookbooks/1_sending_prompts.ipynb index b5cfa70541..d9ab5163b8 100644 --- a/doc/cookbooks/1_sending_prompts.ipynb +++ b/doc/cookbooks/1_sending_prompts.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "0", + "id": "d2f6253d", "metadata": {}, "source": [ "# 1. Sending a Million Prompts\n", @@ -18,24 +18,47 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "1", - "metadata": {}, + "execution_count": 1, + "id": "ee1454e9", + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-19T18:12:19.350618Z", + "iopub.status.busy": "2026-03-19T18:12:19.350069Z", + "iopub.status.idle": "2026-03-19T18:12:25.042805Z", + "shell.execute_reply": "2026-03-19T18:12:25.041039Z" + } + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Found default environment files: ['./.pyrit/.env', './.pyrit/.env.local']\n", - "Loaded environment file: ./.pyrit/.env\n", - "Loaded environment file: ./.pyrit/.env.local\n" + "Found default environment files: ['/home/vscode/.pyrit/.env']\n", + "Loaded environment file: /home/vscode/.pyrit/.env\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r", + "Loading datasets - this can take a few minutes: 0%| | 0/58 [00:00 Date: Thu, 19 Mar 2026 18:48:34 +0000 Subject: [PATCH 26/29] fixes --- .../seed_datasets/seed_dataset_provider.py | 2 +- pyrit/datasets/seed_datasets/seed_metadata.py | 26 +++--- .../datasets/test_seed_dataset_provider.py | 80 +++++++++---------- 3 files changed, 58 insertions(+), 50 deletions(-) diff --git 
a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index 7738cafcfa..5dd09607c9 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -135,7 +135,7 @@ async def get_all_dataset_names_async(cls, filters: Optional[SeedDatasetFilter] continue # Filters detected but no match -> don't add this dataset - if not cls._match_filter_to_metadata(metadata=metadata, filters=filters): + if not cls._match_filter_to_metadata(metadata=metadata, dataset_filter=filters): continue dataset_names.add(provider.dataset_name) diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index 369b8a0c74..d1738a0d7a 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -135,7 +135,8 @@ class SeedDatasetFilter: Passing both flat kwargs and criteria raises ValueError. Special tags: - - "all": Magic bypass. Returns every dataset, ignores all other fields. + - "all": Returns every dataset, ignores all other fields. This tag will + override anything else you pass to the filter object. - "default": Matches datasets with "default" in their tags. With strict_match=True, loses its shortcut and is treated as a normal tag. @@ -185,6 +186,21 @@ def __init__( else: self.criteria = [SeedDatasetMetadata()] + # Normalize tags: strip whitespace and lowercase so "ALL", " All ", etc. 
work + self.criteria = [ + SeedDatasetMetadata( + **{ + f.name: ( + {t.strip().lower() for t in getattr(c, f.name)} + if f.name == "tags" and getattr(c, f.name) is not None + else getattr(c, f.name) + ) + for f in fields(c) + } + ) + for c in self.criteria + ] + self.strict_match = strict_match self._validate() @@ -195,14 +211,6 @@ def _validate(self) -> None: all_criterion = next(c for c in self.criteria if c.tags and "all" in c.tags) - # Only lowercase "all" is accepted; reject "All", "ALL", etc. - non_lowercase = {t for t in all_criterion.tags if t != "all" and t.lower() == "all"} - if non_lowercase: - logger.warning( - "Filter has non-lowercase 'all' variants %s. Only lowercase 'all' is recognized as the bypass tag.", - non_lowercase, - ) - if all_criterion.tags and len(all_criterion.tags) > 1: logger.warning( "Filter has 'all' combined with other tags %s. " diff --git a/tests/unit/datasets/test_seed_dataset_provider.py b/tests/unit/datasets/test_seed_dataset_provider.py index 5b5714c556..9b116288fe 100644 --- a/tests/unit/datasets/test_seed_dataset_provider.py +++ b/tests/unit/datasets/test_seed_dataset_provider.py @@ -278,16 +278,16 @@ def test_all_tag(self): """Filter with tags={'all'} matches any metadata.""" metadata = SeedDatasetMetadata(tags={"safety"}) filters = SeedDatasetFilter(tags={"all"}) - assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) def test_tags(self): """Tag filter uses set intersection.""" metadata = SeedDatasetMetadata(tags={"safety", "default"}) assert SeedDatasetProvider._match_filter_to_metadata( - metadata=metadata, filters=SeedDatasetFilter(tags={"safety"}) + metadata=metadata, dataset_filter=SeedDatasetFilter(tags={"safety"}) ) assert not SeedDatasetProvider._match_filter_to_metadata( - metadata=metadata, filters=SeedDatasetFilter(tags={"unrelated"}) + metadata=metadata, 
dataset_filter=SeedDatasetFilter(tags={"unrelated"}) ) def test_sizes(self): @@ -295,11 +295,11 @@ def test_sizes(self): metadata = SeedDatasetMetadata(size={"large"}) assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(size={"large", "huge"}), + dataset_filter=SeedDatasetFilter(size={"large", "huge"}), ) assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(size={"small"}), + dataset_filter=SeedDatasetFilter(size={"small"}), ) def test_modalities(self): @@ -307,11 +307,11 @@ def test_modalities(self): metadata = SeedDatasetMetadata(modalities={"text", "image"}) assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(modalities={"text"}), + dataset_filter=SeedDatasetFilter(modalities={"text"}), ) assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(modalities={"audio"}), + dataset_filter=SeedDatasetFilter(modalities={"audio"}), ) def test_sources(self): @@ -319,11 +319,11 @@ def test_sources(self): metadata = SeedDatasetMetadata(source_type={"remote"}) assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(source_type={"remote"}), + dataset_filter=SeedDatasetFilter(source_type={"remote"}), ) assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(source_type={"local"}), + dataset_filter=SeedDatasetFilter(source_type={"local"}), ) def test_ranks(self): @@ -331,11 +331,11 @@ def test_ranks(self): metadata = SeedDatasetMetadata(load_time={SeedDatasetLoadTime.FAST}) assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(load_time={SeedDatasetLoadTime.FAST}), + dataset_filter=SeedDatasetFilter(load_time={SeedDatasetLoadTime.FAST}), ) assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - 
filters=SeedDatasetFilter(load_time={SeedDatasetLoadTime.SLOW}), + dataset_filter=SeedDatasetFilter(load_time={SeedDatasetLoadTime.SLOW}), ) def test_harm_categories(self): @@ -343,18 +343,18 @@ def test_harm_categories(self): metadata = SeedDatasetMetadata(harm_categories={"violence", "cybercrime"}) assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(harm_categories={"violence"}), + dataset_filter=SeedDatasetFilter(harm_categories={"violence"}), ) assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(harm_categories={"unrelated"}), + dataset_filter=SeedDatasetFilter(harm_categories={"unrelated"}), ) def test_empty_filter(self): """Empty filter (all None) matches any metadata.""" metadata = SeedDatasetMetadata(tags={"safety"}, size="large") filters = SeedDatasetFilter() - assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) @pytest.mark.asyncio async def test_no_metadata(self): @@ -376,54 +376,54 @@ def test_strict_tags_all_present_matches(self): """strict_match requires ALL filter tags to be present in metadata.""" metadata = SeedDatasetMetadata(tags={"safety", "default", "curated"}) filters = SeedDatasetFilter(tags={"safety", "default"}, strict_match=True) - assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) def test_strict_tags_partial_overlap_fails(self): """strict_match rejects if metadata is missing any requested tag.""" metadata = SeedDatasetMetadata(tags={"safety"}) filters = SeedDatasetFilter(tags={"safety", "default"}, strict_match=True) - assert not SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) + assert not 
SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) def test_nonstrict_tags_partial_overlap_passes(self): """Without strict_match, any tag overlap is sufficient.""" metadata = SeedDatasetMetadata(tags={"safety"}) filters = SeedDatasetFilter(tags={"safety", "default"}, strict_match=False) - assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) def test_strict_harm_categories_all_present_matches(self): """strict_match requires ALL filter harm_categories present in metadata.""" metadata = SeedDatasetMetadata(harm_categories={"violence", "cybercrime", "illegal"}) filters = SeedDatasetFilter(harm_categories={"violence", "cybercrime"}, strict_match=True) - assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) def test_strict_harm_categories_partial_fails(self): """strict_match rejects if metadata is missing any requested harm category.""" metadata = SeedDatasetMetadata(harm_categories={"violence"}) filters = SeedDatasetFilter(harm_categories={"violence", "cybercrime"}, strict_match=True) - assert not SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) + assert not SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) def test_strict_modalities_all_present_matches(self): """strict_match requires ALL filter modalities present in metadata.""" metadata = SeedDatasetMetadata(modalities={"text", "image", "audio"}) filters = SeedDatasetFilter(modalities={"text", "image"}, strict_match=True) - assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) def 
test_strict_modalities_partial_fails(self): """strict_match rejects if metadata is missing any requested modality.""" metadata = SeedDatasetMetadata(modalities={"text"}) filters = SeedDatasetFilter(modalities={"text", "image"}, strict_match=True) - assert not SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) + assert not SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) def test_strict_size_unchanged(self): """strict_match doesn't change size behavior — still membership check.""" metadata = SeedDatasetMetadata(size={"large"}) assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(size={"large"}, strict_match=True), + dataset_filter=SeedDatasetFilter(size={"large"}, strict_match=True), ) assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(size={"small"}, strict_match=True), + dataset_filter=SeedDatasetFilter(size={"small"}, strict_match=True), ) def test_strict_cross_axis_and(self): @@ -436,7 +436,7 @@ def test_strict_cross_axis_and(self): # Both axes satisfied assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter( + dataset_filter=SeedDatasetFilter( tags={"safety"}, harm_categories={"violence"}, strict_match=True, @@ -445,7 +445,7 @@ def test_strict_cross_axis_and(self): # harm_categories axis fails (missing "illegal") assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter( + dataset_filter=SeedDatasetFilter( tags={"safety"}, harm_categories={"violence", "illegal"}, strict_match=True, @@ -456,7 +456,7 @@ def test_strict_all_tag_still_bypasses(self): """tags={'all'} still bypasses everything even with strict_match.""" metadata = SeedDatasetMetadata(tags={"safety"}) filters = SeedDatasetFilter(tags={"all"}, strict_match=True) - assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, 
filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) def test_strict_default_plus_other_tags_requires_both(self): """With strict_match, 'default' is a normal tag — all must be present.""" @@ -464,12 +464,12 @@ def test_strict_default_plus_other_tags_requires_both(self): # Both present → match assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(tags={"default", "safety"}, strict_match=True), + dataset_filter=SeedDatasetFilter(tags={"default", "safety"}, strict_match=True), ) # Missing "curated" → reject assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(tags={"default", "safety", "curated"}, strict_match=True), + dataset_filter=SeedDatasetFilter(tags={"default", "safety", "curated"}, strict_match=True), ) def test_nonstrict_default_is_shortcut(self): @@ -478,7 +478,7 @@ def test_nonstrict_default_is_shortcut(self): metadata = SeedDatasetMetadata(tags={"default"}) assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(tags={"default", "nonexistent"}), + dataset_filter=SeedDatasetFilter(tags={"default", "nonexistent"}), ) def test_strict_default_without_tag_on_dataset_fails(self): @@ -487,12 +487,12 @@ def test_strict_default_without_tag_on_dataset_fails(self): # Without strict, "default" shortcut matches because metadata has "default" tag assert SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(tags={"default", "curated"}), + dataset_filter=SeedDatasetFilter(tags={"default", "curated"}), ) # With strict, ALL filter tags must be in metadata — "curated" is missing assert not SeedDatasetProvider._match_filter_to_metadata( metadata=metadata, - filters=SeedDatasetFilter(tags={"default", "curated"}, strict_match=True), + dataset_filter=SeedDatasetFilter(tags={"default", "curated"}, strict_match=True), ) @@ -536,7 +536,7 @@ 
def test_all_bypasses_match_filter_entirely(self): harm_categories={"violence"}, strict_match=True, ) - assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) @pytest.mark.asyncio async def test_all_includes_datasets_without_metadata(self): @@ -632,7 +632,7 @@ async def test_all_tag(self, tmp_path): metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(tags={"all"}) - assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) @pytest.mark.asyncio async def test_tags(self, tmp_path): @@ -654,7 +654,7 @@ async def test_tags(self, tmp_path): metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(tags={"safety"}) - assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) @pytest.mark.asyncio async def test_sizes(self, tmp_path): @@ -674,7 +674,7 @@ async def test_sizes(self, tmp_path): metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(size={"large"}) - assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) @pytest.mark.asyncio async def test_modalities(self, tmp_path): @@ -695,7 +695,7 @@ async def test_modalities(self, tmp_path): metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(modalities={"text"}) - assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, 
dataset_filter=filters) @pytest.mark.asyncio async def test_sources(self, tmp_path): @@ -715,7 +715,7 @@ async def test_sources(self, tmp_path): metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(source_type={"remote"}) - assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) @pytest.mark.asyncio async def test_ranks(self, tmp_path): @@ -735,7 +735,7 @@ async def test_ranks(self, tmp_path): metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(load_time={SeedDatasetLoadTime.FAST}) - assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) @pytest.mark.asyncio async def test_harm_categories(self, tmp_path): @@ -757,7 +757,7 @@ async def test_harm_categories(self, tmp_path): metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter(harm_categories={"violence"}) - assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) @pytest.mark.asyncio async def test_empty_filter(self, tmp_path): @@ -778,7 +778,7 @@ async def test_empty_filter(self, tmp_path): metadata = await loader._parse_metadata() assert metadata is not None filters = SeedDatasetFilter() - assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, filters=filters) + assert SeedDatasetProvider._match_filter_to_metadata(metadata=metadata, dataset_filter=filters) @pytest.mark.asyncio async def test_no_metadata(self, tmp_path): From 47f9cf6435156007be1d51ef85da67efdc60df71 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 19 Mar 2026 19:03:48 +0000 Subject: [PATCH 27/29] . 
--- doc/cookbooks/1_sending_prompts.ipynb | 309 +++++++++++++------------- 1 file changed, 159 insertions(+), 150 deletions(-) diff --git a/doc/cookbooks/1_sending_prompts.ipynb b/doc/cookbooks/1_sending_prompts.ipynb index d9ab5163b8..58c84bd901 100644 --- a/doc/cookbooks/1_sending_prompts.ipynb +++ b/doc/cookbooks/1_sending_prompts.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "d2f6253d", + "id": "949de8e3", "metadata": {}, "source": [ "# 1. Sending a Million Prompts\n", @@ -19,13 +19,13 @@ { "cell_type": "code", "execution_count": 1, - "id": "ee1454e9", + "id": "89a72cfe", "metadata": { "execution": { - "iopub.execute_input": "2026-03-19T18:12:19.350618Z", - "iopub.status.busy": "2026-03-19T18:12:19.350069Z", - "iopub.status.idle": "2026-03-19T18:12:25.042805Z", - "shell.execute_reply": "2026-03-19T18:12:25.041039Z" + "iopub.execute_input": "2026-03-19T18:50:17.896840Z", + "iopub.status.busy": "2026-03-19T18:50:17.896039Z", + "iopub.status.idle": "2026-03-19T18:50:23.351147Z", + "shell.execute_reply": "2026-03-19T18:50:23.349871Z" } }, "outputs": [ @@ -42,7 +42,7 @@ "output_type": "stream", "text": [ "\r", - "Loading datasets - this can take a few minutes: 0%| | 0/58 [00:00 Date: Thu, 19 Mar 2026 19:16:44 +0000 Subject: [PATCH 28/29] precommit --- doc/cookbooks/1_sending_prompts.ipynb | 101 ++++---------- .../seed_datasets/seed_dataset_provider.py | 21 +-- .../test_seed_dataset_provider_integration.py | 127 ++++++++++++++++++ 3 files changed, 161 insertions(+), 88 deletions(-) diff --git a/doc/cookbooks/1_sending_prompts.ipynb b/doc/cookbooks/1_sending_prompts.ipynb index 58c84bd901..34c4869053 100644 --- a/doc/cookbooks/1_sending_prompts.ipynb +++ b/doc/cookbooks/1_sending_prompts.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "949de8e3", + "id": "0", "metadata": {}, "source": [ "# 1. 
Sending a Million Prompts\n", @@ -18,23 +18,16 @@ }, { "cell_type": "code", - "execution_count": 1, - "id": "89a72cfe", - "metadata": { - "execution": { - "iopub.execute_input": "2026-03-19T18:50:17.896840Z", - "iopub.status.busy": "2026-03-19T18:50:17.896039Z", - "iopub.status.idle": "2026-03-19T18:50:23.351147Z", - "shell.execute_reply": "2026-03-19T18:50:23.349871Z" - } - }, + "execution_count": null, + "id": "1", + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Found default environment files: ['/home/vscode/.pyrit/.env']\n", - "Loaded environment file: /home/vscode/.pyrit/.env\n" + "Found default environment files: ['./.pyrit/.env']\n", + "Loaded environment file: ./.pyrit/.env\n" ] }, { @@ -103,7 +96,7 @@ }, { "cell_type": "markdown", - "id": "19c027b8", + "id": "2", "metadata": {}, "source": [ "## Sending Prompts\n", @@ -120,16 +113,9 @@ }, { "cell_type": "code", - "execution_count": 2, - "id": "82e74135", - "metadata": { - "execution": { - "iopub.execute_input": "2026-03-19T18:50:23.357639Z", - "iopub.status.busy": "2026-03-19T18:50:23.356659Z", - "iopub.status.idle": "2026-03-19T18:50:58.547964Z", - "shell.execute_reply": "2026-03-19T18:50:58.545958Z" - } - }, + "execution_count": null, + "id": "3", + "metadata": {}, "outputs": [ { "name": "stdout", @@ -549,7 +535,7 @@ }, { "cell_type": "markdown", - "id": "44bfa643", + "id": "4", "metadata": {}, "source": [ "## Analyzing and Re-Scoring the Results\n", @@ -561,16 +547,9 @@ }, { "cell_type": "code", - "execution_count": 3, - "id": "c943dda1", - "metadata": { - "execution": { - "iopub.execute_input": "2026-03-19T18:50:58.551439Z", - "iopub.status.busy": "2026-03-19T18:50:58.551153Z", - "iopub.status.idle": "2026-03-19T18:51:03.400451Z", - "shell.execute_reply": "2026-03-19T18:51:03.391606Z" - } - }, + "execution_count": null, + "id": "5", + "metadata": {}, "outputs": [ { "name": "stdout", @@ -622,7 +601,7 @@ }, { "cell_type": "markdown", - "id": "4d52bb64", + "id": 
"6", "metadata": {}, "source": [ "## Exporting Prompts\n", @@ -632,16 +611,9 @@ }, { "cell_type": "code", - "execution_count": 4, - "id": "c3d6f8d3", - "metadata": { - "execution": { - "iopub.execute_input": "2026-03-19T18:51:03.413088Z", - "iopub.status.busy": "2026-03-19T18:51:03.412535Z", - "iopub.status.idle": "2026-03-19T18:51:03.470953Z", - "shell.execute_reply": "2026-03-19T18:51:03.468723Z" - } - }, + "execution_count": null, + "id": "7", + "metadata": {}, "outputs": [ { "data": { @@ -649,7 +621,7 @@ "PosixPath('/workspace/dbdata/all_conversations.json')" ] }, - "execution_count": 4, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -662,7 +634,7 @@ }, { "cell_type": "markdown", - "id": "79f0aa18", + "id": "8", "metadata": {}, "source": [ "Some operators also like to work locally and then upload to a central DB. You can upload your prompts like this." @@ -670,16 +642,9 @@ }, { "cell_type": "code", - "execution_count": 5, - "id": "42d8297f", - "metadata": { - "execution": { - "iopub.execute_input": "2026-03-19T18:51:03.474799Z", - "iopub.status.busy": "2026-03-19T18:51:03.474234Z", - "iopub.status.idle": "2026-03-19T18:51:03.492755Z", - "shell.execute_reply": "2026-03-19T18:51:03.491093Z" - } - }, + "execution_count": null, + "id": "9", + "metadata": {}, "outputs": [], "source": [ "all_message_pieces = memory.get_message_pieces(labels=memory_labels)\n", @@ -692,7 +657,7 @@ }, { "cell_type": "markdown", - "id": "f25abc4a", + "id": "10", "metadata": {}, "source": [ "## Querying Attack Results by Labels\n", @@ -701,16 +666,9 @@ }, { "cell_type": "code", - "execution_count": 6, - "id": "f0d5da8c", - "metadata": { - "execution": { - "iopub.execute_input": "2026-03-19T18:51:03.495966Z", - "iopub.status.busy": "2026-03-19T18:51:03.495643Z", - "iopub.status.idle": "2026-03-19T18:51:03.581095Z", - "shell.execute_reply": "2026-03-19T18:51:03.579277Z" - } - }, + "execution_count": null, + "id": "11", + "metadata": {}, "outputs": [ { 
"name": "stdout", @@ -742,11 +700,6 @@ } ], "metadata": { - "kernelspec": { - "display_name": "pyrit (3.13.5)", - "language": "python", - "name": "python3" - }, "language_info": { "codemirror_mode": { "name": "ipython", diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index 5dd09607c9..7ce5eb8c37 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -140,8 +140,7 @@ async def get_all_dataset_names_async(cls, filters: Optional[SeedDatasetFilter] dataset_names.add(provider.dataset_name) except Exception as e: - raise ValueError( - f"Could not get dataset name from {provider_class.__name__}: {e}") from e + raise ValueError(f"Could not get dataset name from {provider_class.__name__}: {e}") from e return sorted(dataset_names) @classmethod @@ -171,8 +170,7 @@ def _match_filter_to_metadata(cls, metadata: SeedDatasetMetadata, dataset_filter return True return any( - cls._match_single_criterion( - metadata=metadata, criterion=c, strict_match=dataset_filter.strict_match) + cls._match_single_criterion(metadata=metadata, criterion=c, strict_match=dataset_filter.strict_match) for c in dataset_filter.criteria ) @@ -265,11 +263,9 @@ async def fetch_datasets_async( # Validate dataset names if specified if dataset_names is not None: available_names = await cls.get_all_dataset_names_async() - invalid_names = [ - name for name in dataset_names if name not in available_names] + invalid_names = [name for name in dataset_names if name not in available_names] if invalid_names: - raise ValueError( - f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") + raise ValueError(f"Dataset(s) not found: {invalid_names}. 
Available datasets: {available_names}") async def fetch_single_dataset( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -295,8 +291,7 @@ async def fetch_single_dataset( # Progress tracking total_count = len(cls._registry) - pbar = tqdm(total=total_count, - desc="Loading datasets - this can take a few minutes", unit="dataset") + pbar = tqdm(total=total_count, desc="Loading datasets - this can take a few minutes", unit="dataset") async def fetch_with_semaphore( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -334,12 +329,10 @@ async def fetch_with_semaphore( logger.info(f"Merging multiple sources for {dataset_name}.") existing_dataset = datasets[dataset_name] - combined_seeds = list( - existing_dataset.seeds) + list(dataset.seeds) + combined_seeds = list(existing_dataset.seeds) + list(dataset.seeds) existing_dataset.seeds = combined_seeds else: datasets[dataset_name] = dataset - logger.info( - f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") + logger.info(f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") return list(datasets.values()) diff --git a/tests/integration/datasets/test_seed_dataset_provider_integration.py b/tests/integration/datasets/test_seed_dataset_provider_integration.py index 438f213444..6ea7a3f7cd 100644 --- a/tests/integration/datasets/test_seed_dataset_provider_integration.py +++ b/tests/integration/datasets/test_seed_dataset_provider_integration.py @@ -590,3 +590,130 @@ def test_all_tag_with_mixed_metadata_and_bare_datasets(self, tmp_path): ) assert len(all_names) == 2 assert "bare_dataset" in all_names + + +class TestHarmbenchMetadataInScenario: + """ + Regression tests verifying that metadata changes don't break the dataset + loading flow for scenarios that use datasets with explicit metadata. 
+ + RedTeamAgent is the primary scenario that uses harmbench — the only remote + dataset with class-level metadata (tags, size, modalities, harm_categories). + These tests verify the full pipeline: metadata parsing → dataset loading → + memory storage → scenario initialization. + """ + + @pytest.mark.asyncio + async def test_harmbench_metadata_parses_correctly(self): + """HarmBench's class-level metadata is correctly parsed into sets.""" + from pyrit.datasets.seed_datasets.remote.harmbench_dataset import _HarmBenchDataset + + loader = _HarmBenchDataset() + metadata = await loader._parse_metadata() + + assert metadata is not None + assert isinstance(metadata.tags, set) + assert "default" in metadata.tags + assert "safety" in metadata.tags + assert metadata.size == {"large"} + assert metadata.modalities == {"text"} + assert isinstance(metadata.harm_categories, set) + assert "cybercrime" in metadata.harm_categories + + @pytest.mark.asyncio + async def test_harmbench_discoverable_via_filter(self): + """HarmBench can be found via tag and harm_category filters.""" + names_by_safety = await SeedDatasetProvider.get_all_dataset_names_async( + filters=SeedDatasetFilter(tags={"safety"}), + ) + assert "harmbench" in names_by_safety + + names_by_harm = await SeedDatasetProvider.get_all_dataset_names_async( + filters=SeedDatasetFilter(harm_categories={"cybercrime"}), + ) + assert "harmbench" in names_by_harm + + @pytest.mark.asyncio + async def test_harmbench_loads_and_stores_in_memory(self): + """HarmBench can be fetched and stored in memory for scenario use.""" + from pyrit.memory import CentralMemory + from pyrit.setup import initialize_pyrit_async + + await initialize_pyrit_async(memory_db_type="InMemory") + + datasets = await SeedDatasetProvider.fetch_datasets_async( + dataset_names=["harmbench"], + ) + assert len(datasets) == 1 + assert datasets[0].dataset_name == "harmbench" + assert len(datasets[0].seeds) > 0 + + memory = CentralMemory.get_memory_instance() + await 
memory.add_seed_datasets_to_memory_async( + datasets=datasets, + added_by="test", + ) + + # Verify seeds are queryable from memory (this is what scenarios do) + seed_groups = memory.get_seed_groups(dataset_name="harmbench") + assert seed_groups is not None + assert len(list(seed_groups)) > 0 + + @pytest.mark.asyncio + async def test_red_team_agent_initializes_with_harmbench(self): + """ + RedTeamAgent can initialize with harmbench dataset loaded in memory. + + This is the critical regression test: if metadata changes break the + parsing/coercion/filtering pipeline, this test will fail during + scenario initialization when it tries to load seed groups from memory. + """ + from unittest.mock import MagicMock + + from pyrit.executor.attack.core.attack_config import AttackScoringConfig + from pyrit.memory import CentralMemory + from pyrit.prompt_target import TextTarget + from pyrit.scenario.scenarios.foundry.red_team_agent import ( + FoundryStrategy, + RedTeamAgent, + ) + from pyrit.score.true_false.true_false_scorer import TrueFalseScorer + from pyrit.setup import initialize_pyrit_async + + await initialize_pyrit_async(memory_db_type="InMemory") + + # Load harmbench into memory + datasets = await SeedDatasetProvider.fetch_datasets_async( + dataset_names=["harmbench"], + ) + memory = CentralMemory.get_memory_instance() + await memory.add_seed_datasets_to_memory_async( + datasets=datasets, + added_by="test", + ) + + # Mock scorer to avoid Azure dependency + mock_scorer = MagicMock(spec=TrueFalseScorer) + mock_scorer.get_identifier.return_value = {"__type__": "MockScorer"} + + target = TextTarget() + rta = RedTeamAgent( + adversarial_chat=target, + attack_scoring_config=AttackScoringConfig(objective_scorer=mock_scorer), + include_baseline=False, + ) + + # This is the critical call — it loads seed groups from memory + # and builds atomic attacks. If metadata broke the pipeline, + # this would raise ValueError about missing seed_groups. 
+ await rta.initialize_async( + objective_target=target, + max_concurrency=1, + scenario_strategies=[FoundryStrategy.Base64], + ) + + # Verify the scenario got objectives from harmbench + attacks = rta._atomic_attacks + assert len(attacks) > 0 + for attack in attacks: + assert len(attack.objectives) > 0 From 17c7b1cd3215bed052143a06322c03cfaa2c6705 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 19 Mar 2026 19:42:17 +0000 Subject: [PATCH 29/29] . --- pyrit/datasets/seed_datasets/seed_metadata.py | 45 +++++++++++------ .../datasets/test_seed_dataset_metadata.py | 48 +++++++++++++++++++ 2 files changed, 79 insertions(+), 14 deletions(-) diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index d1738a0d7a..bf481229da 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -187,25 +187,42 @@ def __init__( self.criteria = [SeedDatasetMetadata()] # Normalize tags: strip whitespace and lowercase so "ALL", " All ", etc. work - self.criteria = [ - SeedDatasetMetadata( - **{ - f.name: ( - {t.strip().lower() for t in getattr(c, f.name)} - if f.name == "tags" and getattr(c, f.name) is not None - else getattr(c, f.name) - ) - for f in fields(c) - } - ) - for c in self.criteria - ] + def _normalize_criterion(c: SeedDatasetMetadata) -> SeedDatasetMetadata: + normalized = { + f.name: ({t.strip().lower() for t in vals} if f.name == "tags" and vals is not None else vals) + for f, vals in zip(fields(c), (getattr(c, f.name) for f in fields(c)), strict=True) + } + return SeedDatasetMetadata(**normalized) + + self.criteria = [_normalize_criterion(c) for c in self.criteria] self.strict_match = strict_match self._validate() def _validate(self) -> None: - """Warn about contradictory filter configurations.""" + """ + Warn about contradictory filter configurations. 
+ + Raises: + ValueError: If strict_match is True and any criterion has multiple + values for a singular field (size, source_type). + """ + # strict_match with multi-valued singular fields is logically impossible. + # A dataset can't be both "small" AND "large" — these are mutually exclusive. + if self.strict_match: + for criterion in self.criteria: + for field_name in SeedDatasetMetadata.SINGULAR_FIELDS: + value = getattr(criterion, field_name) + if value is not None and len(value) > 1: + raise ValueError( + f"strict_match=True with multiple values for '{field_name}' " + f"({value}) is logically impossible — a dataset can only have " + f"one {field_name}. Mutually exclusive fields: " + f"{SeedDatasetMetadata.SINGULAR_FIELDS}. " + f"Use strict_match=False for OR matching, " + f"or split into separate criteria." + ) + if not self.has_all_tag: return diff --git a/tests/unit/datasets/test_seed_dataset_metadata.py b/tests/unit/datasets/test_seed_dataset_metadata.py index b5cec477c8..a5a1c01084 100644 --- a/tests/unit/datasets/test_seed_dataset_metadata.py +++ b/tests/unit/datasets/test_seed_dataset_metadata.py @@ -208,6 +208,54 @@ def test_multi_value_non_singular_fields_pass(self): SeedDatasetMetadata._validate_singular_fields(metadata=metadata) +class TestStrictMatchSingularFieldValidation: + """ + Test that strict_match rejects multi-valued singular fields. + + A dataset can't be both "small" AND "large" — these are mutually exclusive. + strict_match=True with size={"small", "large"} is logically impossible + and should raise ValueError at filter construction time. 
+ """ + + def test_strict_multi_size_raises(self): + """strict_match with size={'small', 'large'} is impossible.""" + with pytest.raises(ValueError, match="logically impossible"): + SeedDatasetFilter(size={"small", "large"}, strict_match=True) + + def test_strict_multi_source_type_raises(self): + """strict_match with source_type={'remote', 'local'} is impossible.""" + with pytest.raises(ValueError, match="logically impossible"): + SeedDatasetFilter(source_type={"remote", "local"}, strict_match=True) + + def test_strict_single_size_ok(self): + """strict_match with single size value is fine.""" + f = SeedDatasetFilter(size={"large"}, strict_match=True) + assert f.criteria[0].size == {"large"} + + def test_nonstrict_multi_size_ok(self): + """Without strict_match, multiple sizes is OR and perfectly valid.""" + f = SeedDatasetFilter(size={"small", "large"}, strict_match=False) + assert len(f.criteria[0].size) == 2 + + def test_strict_multi_tags_ok(self): + """Tags are NOT singular — strict with multiple tags is valid (AND).""" + f = SeedDatasetFilter(tags={"safety", "default"}, strict_match=True) + assert len(f.criteria[0].tags) == 2 + + def test_strict_multi_harm_categories_ok(self): + """harm_categories are NOT singular — strict with multiple is valid.""" + f = SeedDatasetFilter(harm_categories={"violence", "cybercrime"}, strict_match=True) + assert len(f.criteria[0].harm_categories) == 2 + + def test_strict_criteria_list_multi_size_raises(self): + """strict_match validation also applies to criteria=[] construction.""" + with pytest.raises(ValueError, match="logically impossible"): + SeedDatasetFilter( + criteria=[SeedDatasetMetadata(size={"small", "large"})], + strict_match=True, + ) + + class TestFilterProperties: """Test that the filter fields populate correctly via flat kwargs."""