Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ on:
- '*'
- '!push-action/*'
paths:
- snakebids/**
- src/**
- tests/**
- scripts/**

jobs:
Expand Down
25 changes: 19 additions & 6 deletions src/snakebids/core/_querying.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

CompiledFilter: TypeAlias = "Mapping[str, Sequence[str | Query]]"

_VALID_FILTER_METHODS = frozenset(("get", "match", "search"))


class PostFilter:
"""Filters to apply after indexing, typically derived from the CLI.
Expand All @@ -28,6 +30,13 @@ def __init__(self):
self.inclusions: dict[str, Sequence[str] | str] = {}
self.exclusions: dict[str, Sequence[str] | str] = {}

def __eq__(self, other: object):
if not isinstance(other, self.__class__):
return False
return (
self.inclusions == other.inclusions and self.exclusions == other.exclusions
)

def add_filter(
self,
key: str,
Expand Down Expand Up @@ -63,13 +72,16 @@ def add_filter(
if exclusions is not None:
self.exclusions[key] = self._format_exclusions(exclusions)

def _format_exclusions(self, exclusions: Iterable[str] | str):
def _format_exclusions(self, exclusions: Iterable[str] | str) -> list[str]:
# if multiple items to exclude, combine with with item1|item2|...
hit = None
exclude_string = "|".join(
re.escape(label) for label in itx.always_iterable(exclusions)
(hit := re.escape(label)) for label in itx.always_iterable(exclusions)
)
if hit is None:
return []
# regex to exclude subjects
return [f"^((?!({exclude_string})$).*)$"]
return [f"(?!({exclude_string})$)"]


@attrs.define(slots=False)
Expand All @@ -94,7 +106,7 @@ class UnifiedFilter:
@classmethod
def from_filter_dict(
cls,
filters: Mapping[str, str | bool | Sequence[str | bool]],
filters: FilterMap,
postfilter: PostFilter | None = None,
) -> Self:
"""Patch together a UnifiedFilter based on a basic filter dict.
Expand Down Expand Up @@ -246,6 +258,7 @@ def get_matching_files(
raise PybidsError(msg) from err

if search is not None:
# intersection preserving order
return [p for p in get if p in search]
return get

Expand Down Expand Up @@ -338,11 +351,11 @@ def wrap(filt: str, method: str):
filt = raw_filter
filt_type = "get"
else:
if filt_type not in {"match", "search", "get"}:
if filt_type not in _VALID_FILTER_METHODS:
raise _InvalidKeyError(key, raw_filter)
if TYPE_CHECKING:
assert isinstance(raw_filter, dict)
filt = cast("str | bool | Sequence[str | bool]", raw_filter[filt_type])
filt = raw_filter[filt_type] # pyright: ignore[reportTypedDictNotRequiredAccess]

# these two must not be simultaneously true or false
if with_regex == (filt_type == "get"):
Expand Down
255 changes: 85 additions & 170 deletions tests/test_generate_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import functools as ft
import itertools as it
import keyword
import logging
import os
import re
import shutil
Expand All @@ -14,12 +15,12 @@
from collections import defaultdict
from collections.abc import Iterable
from pathlib import Path, PosixPath
from typing import Any, Literal, NamedTuple, TypedDict, TypeVar
from typing import Any, NamedTuple, TypeVar

import attrs
import more_itertools as itx
import pytest
from bids import BIDSLayout
from bids.layout import BIDSLayout
from hypothesis import HealthCheck, assume, example, given, settings
from hypothesis import strategies as st
from pytest_mock import MockerFixture
Expand All @@ -40,7 +41,7 @@
from snakebids.exceptions import ConfigError, PybidsError, RunError
from snakebids.paths._presets import bids
from snakebids.snakemake_compat import expand as sb_expand
from snakebids.types import InputsConfig
from snakebids.types import InputConfig, InputsConfig
from snakebids.utils.containers import MultiSelectDict
from snakebids.utils.utils import DEPRECATION_FLAG, BidsEntity, BidsParseError
from tests import strategies as sb_st
Expand Down Expand Up @@ -1687,197 +1688,111 @@ def test_generate_inputs(dataset: BidsDataset, tmpdir: Path):
assert reindexed.layout is not None


@st.composite
def dataset_with_subject(draw: st.DrawFn):
entities = draw(sb_st.bids_entity_lists(blacklist_entities=["subject"]))
entities += ["subject"]
return BidsDataset.from_iterable(
[
draw(
sb_st.bids_components(
whitelist_entities=entities,
min_entities=len(entities),
max_entities=len(entities),
restrict_patterns=True,
unique=True,
)
)
]
)
class _FakeBIDSLayout:
def __init__(self, *args: Any, **kwargs: Any):
pass

def get(self, regex_search: bool, **kwargs: Any) -> list[str]:
return []

class TestParticipantFiltering:
MODE = Literal["include", "exclude"]

def get_filter_params(self, mode: MODE, filters: list[str] | str):
class FiltParams(TypedDict, total=False):
participant_label: list[str] | str
exclude_participant_label: list[str] | str
def __eq__(self, other: object):
return isinstance(other, self.__class__)

if mode == "include":
return FiltParams({"participant_label": filters})
if mode == "exclude":
return FiltParams({"exclude_participant_label": filters})
msg = f"Invalid mode specification: {mode}"
raise ValueError(msg)

@given(
data=st.data(),
dataset=dataset_with_subject(),
)
@settings(
deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture]
class TestParticipantFiltering:
@pytest.mark.parametrize(
("include", "exclude"),
[
("01", None),
(None, "01"),
("01", "02"),
(["01", "02"], "02"),
(None, ["02", "03"]),
(["01", "02"], ["02", "03"]),
],
)
def test_exclude_and_participant_label_filter_correctly(
self, data: st.DataObject, dataset: BidsDataset, tmpdir: Path
):
root = tempfile.mkdtemp(dir=tmpdir)
rooted = BidsDataset.from_iterable(
attrs.evolve(comp, path=os.path.join(root, comp.path))
for comp in dataset.values()
)
sampler = st.sampled_from(itx.first(rooted.values()).entities["subject"])
excluded = data.draw(st.lists(sampler, unique=True) | sampler | st.none())
included = data.draw(st.lists(sampler, unique=True) | sampler | st.none())
reindexed = reindex_dataset(
root, rooted, exclude_participant_label=excluded, participant_label=included
)
reindexed_subjects = set(itx.first(reindexed.values()).entities["subject"])
expected_subjects = set(itx.first(rooted.values()).entities["subject"])
if included is not None:
expected_subjects &= set(itx.always_iterable(included))
if excluded is not None:
expected_subjects -= set(itx.always_iterable(excluded))

assert reindexed_subjects == expected_subjects

@pytest.mark.parametrize("mode", ["include", "exclude"])
@given(
dataset=sb_st.datasets_one_comp(blacklist_entities=["subject"], unique=True),
participant_filter=st.lists(st.text(min_size=1)) | st.text(min_size=1),
)
@settings(
deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture]
)
def test_participant_label_doesnt_filter_comps_without_subject(
self,
mode: MODE,
dataset: BidsDataset,
participant_filter: list[str] | str,
include: str | list[str] | None,
exclude: str | list[str] | None,
tmpdir: Path,
):
root = tempfile.mkdtemp(dir=tmpdir)
rooted = BidsDataset.from_iterable(
attrs.evolve(comp, path=os.path.join(root, comp.path))
for comp in dataset.values()
)
reindexed = reindex_dataset(
root, rooted, **self.get_filter_params(mode, participant_filter)
zip_lists = {
"subject": ["01", "02", "03"],
"acq": ["x", "y", "z"],
}
component = BidsComponent(
name="0", path=str(tmpdir / get_bids_path(zip_lists)), zip_lists=zip_lists
)
assert reindexed == rooted
dataset = BidsDataset.from_iterable([component])

@pytest.mark.parametrize("mode", ["include", "exclude"])
@given(
dataset=dataset_with_subject(),
participant_filter=st.lists(st.text(min_size=1)) | st.text(min_size=1),
)
@settings(
deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture]
)
def test_participant_label_doesnt_filter_comps_when_subject_has_filter(
self,
mode: MODE,
dataset: BidsDataset,
participant_filter: list[str] | str,
tmpdir: Path,
):
root = tempfile.mkdtemp(dir=tmpdir)
rooted = BidsDataset.from_iterable(
attrs.evolve(comp, path=os.path.join(root, comp.path))
for comp in dataset.values()
)
create_dataset(Path("/"), rooted)
reindexed = generate_inputs(
root,
create_snakebids_config(rooted),
**self.get_filter_params(mode, participant_filter),
reindexed = reindex_dataset(
str(tmpdir),
dataset,
participant_label=include,
exclude_participant_label=exclude,
)
assert reindexed == rooted

@pytest.mark.parametrize("mode", ["include", "exclude"])
@given(
dataset=dataset_with_subject(),
participant_filter=st.lists(st.text(min_size=1)) | st.text(min_size=1),
data=st.data(),
)
@settings(
deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture]
)
def test_participant_label_doesnt_filter_comps_when_subject_has_filter_no_wcard(
expected = set(component.entities["subject"])
expected -= set(itx.always_iterable(exclude))
if include is not None:
expected &= set(itx.always_iterable(include))
assert set(itx.first(reindexed.values()).entities["subject"]) == expected

@pytest.mark.parametrize("include", ["include", ["include"], None])
@pytest.mark.parametrize("exclude", ["exclude", ["exclude"], None])
@pytest.mark.parametrize("filters", ["filters", None])
def test_participant_flags_and_filters_merged(
self,
mode: MODE,
dataset: BidsDataset,
participant_filter: list[str] | str,
data: st.DataObject,
tmpdir: Path,
include: str | list[str] | None,
exclude: str | list[str] | None,
filters: str | None,
mocker: MockerFixture,
caplog: pytest.LogCaptureFixture,
):
root = tempfile.mkdtemp(dir=tmpdir)
rooted = BidsDataset.from_iterable(
attrs.evolve(comp, path=os.path.join(root, comp.path))
for comp in dataset.values()
)
subject = data.draw(
st.sampled_from(itx.first(rooted.values()).entities["subject"])
)
create_dataset(Path("/"), rooted)
config = create_snakebids_config(rooted)
for comp in config.values():
comp["filters"] = dict(comp.get("filters", {}))
comp["filters"]["subject"] = subject
reindexed = generate_inputs(
root,
create_snakebids_config(rooted),
**self.get_filter_params(mode, participant_filter),
)
assert reindexed == rooted
mocker.patch.object(input_generation, "BIDSLayout", _FakeBIDSLayout)
patch = mocker.patch.object(input_generation, "get_matching_files")
component: InputConfig = {"filters": {"subject": filters}} if filters else {}
with caplog.at_level(logging.ERROR, "snakebids.core.input_generation"):
generate_inputs(
"",
{"": component},
participant_label=include,
exclude_participant_label=exclude,
)
pf = PostFilter()
pf.add_filter("subject", include, exclude)
uf = UnifiedFilter(component, pf)
patch.assert_called_once_with(_FakeBIDSLayout(), uf)

@given(
data=st.data(),
dataset=dataset_with_subject().filter(
lambda ds: set(itx.first(ds.values()).wildcards) != {"subject", "extension"}
),
)
@settings(
deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture]
)
def test_exclude_participant_does_not_make_all_other_filters_regex(
self, data: st.DataObject, dataset: BidsDataset, tmpdir: Path
self, tmpdir: Path
):
root = tempfile.mkdtemp(dir=tmpdir)
rooted = BidsDataset.from_iterable(
attrs.evolve(comp, path=os.path.join(root, comp.path))
for comp in dataset.values()
# Construct a small deterministic BIDS dataset with a subject entity and
# at least one other entity so we can mutate that other entity's values.
# Single-component dataset with a valid BIDS-style path and entities.
mut_entity = "acq"
zip_lists = {
"subject": ["01", "02"],
mut_entity: ["x", "y"],
}
component = BidsComponent(
name="0", path=str(tmpdir / get_bids_path(zip_lists)), zip_lists=zip_lists
)
dataset = BidsDataset.from_iterable([component])

# Create an extra set of paths by modifying one of the existing components to
# put foo after a set of entity values. If that filter gets changed to a regex,
# all of the suffixed decoys will get picked up by pybids
ziplist = dict(itx.first(rooted.values()).zip_lists)
mut_entity = itx.first(
filter(lambda e: e not in {"subject", "extension"}, ziplist)
)
ziplist[mut_entity] = ["foo" + v for v in ziplist[mut_entity]]
for path in sb_expand(itx.first(rooted.values()).path, zip, **ziplist):
# put 'foo' after a set of entity values. If that filter gets changed to a
# regex, all of the suffixed decoys will get picked up by pybids.
zip_lists[mut_entity] = ["foo" + v for v in zip_lists[mut_entity]]
for path in sb_expand(itx.first(dataset.values()).path, zip, **zip_lists):
p = Path(path)
p.parent.mkdir(parents=True, exist_ok=True)
p.touch()

sampler = st.sampled_from(itx.first(rooted.values()).entities["subject"])
label = data.draw(st.lists(sampler, unique=True) | sampler)
reindexed = reindex_dataset(root, rooted, exclude_participant_label=label)
reindexed_subjects = set(itx.first(reindexed.values()).entities["subject"])
original_subjects = set(itx.first(rooted.values()).entities["subject"])
assert reindexed_subjects == original_subjects - set(itx.always_iterable(label))
reindexed = reindex_dataset(
str(tmpdir), dataset, exclude_participant_label="01"
)
assert itx.first(reindexed.values()) == component.filter(subject="02")


# The content of the dataset is irrelevant to this test, so one example suffices
Expand Down
Loading
Loading