3 changes: 3 additions & 0 deletions .coveragerc
@@ -12,11 +12,14 @@ omit =
*/torch_concepts/data/datasets/mnist_arithmetic.py
*/torch_concepts/data/datasets/pendulum.py
*/torch_concepts/data/datasets/awa2.py
*/torch_concepts/data/datasets/cub.py

# Excluding torch_concepts/data/datamodules/dataset_file.py
*/torch_concepts/data/datamodules/dsprites_regression.py
*/torch_concepts/data/datamodules/mnist_arithmetic.py
*/torch_concepts/data/datamodules/pendulum.py
*/torch_concepts/data/datamodules/awa2.py
*/torch_concepts/data/datamodules/cub.py

[report]
exclude_lines =
25 changes: 25 additions & 0 deletions conceptarium/conf/dataset/cub.yaml
@@ -0,0 +1,25 @@
defaults:
- _commons
- _self_

_target_: torch_concepts.data.datamodules.cub.CUBDataModule

name: cub

# Side length (px) to resize images to
image_size: 224

# backbone handling and embedding precomputation
backbone: resnet50
precompute_embs: true
force_recompute: false

# Task label - bird species (200 classes)
default_task_names: [class]

# splitter - CUB has official train/test
splitter:
_target_: torch_concepts.data.splitters.native.NativeSplitter

# Concept descriptions (optional; leave null to use raw attribute names)
label_descriptions: null
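
For context, Hydra resolves `_target_` at instantiation time. The sketch below shows roughly how this config maps onto a `CUBDataModule` when loaded directly with OmegaConf; this is a simplification for illustration, since in the real pipeline Hydra also merges the `_commons` defaults at compose time.

```python
# Minimal sketch: instantiate the datamodule this config targets.
# Loading the YAML directly (instead of composing through Hydra) is an
# assumption for illustration; Hydra would also merge `_commons` first.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("conceptarium/conf/dataset/cub.yaml")
cfg.pop("defaults", None)  # `defaults` is resolved by Hydra, not by instantiate()
dm = instantiate(cfg)      # builds CUBDataModule(image_size=224, backbone="resnet50", ...)
dm.setup()
```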
49 changes: 49 additions & 0 deletions conceptarium/conf/dataset/cub_incomplete.yaml
@@ -0,0 +1,49 @@
defaults:
- _commons
- _self_

_target_: torch_concepts.data.datamodules.cub.CUBDataModule

name: cub

# Side length (px) to resize images to
image_size: 224

# backbone handling and embedding precomputation
backbone: resnet50
precompute_embs: true
force_recompute: false

# Task label - bird species (200 classes)
default_task_names: [class]

# splitter - CUB has official train/test
splitter:
_target_: torch_concepts.data.splitters.native.NativeSplitter

# We generated the incomplete CUB dataset following the same procedure as
# Zarlenga et al. (2025), "Avoiding Leakage Poisoning: Concept Interventions
# Under Distribution Shifts" (https://arxiv.org/pdf/2504.17921v1).
# More precisely, we keep only the concepts belonging to the following groups:
# ["has_bill_shape", "has_head_pattern", "has_breast_colour", "has_bill_length", "has_wing_shape", "has_tail_pattern", "has_bill_color"]
concept_subset: [
'has_bill_shape::dagger',
'has_bill_shape::hooked_seabird',
'has_bill_shape::all-purpose',
'has_bill_shape::cone',
'has_head_pattern::eyebrow',
'has_head_pattern::plain',
'has_bill_length::about_the_same_as_head',
'has_bill_length::shorter_than_head',
'has_wing_shape::rounded-wings',
'has_wing_shape::pointed-wings',
'has_tail_pattern::solid',
'has_tail_pattern::striped',
'has_tail_pattern::multi-colored',
'has_bill_color::grey',
'has_bill_color::black',
'has_bill_color::buff',
'class', # task label
]

# Concept descriptions (optional; leave null to use raw attribute names)
label_descriptions: null
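
Functionally, `concept_subset` is name-based selection over the dataset's annotation columns. A hedged illustration of the idea follows; the array and names here are toy stand-ins, not CUBDataset internals.

```python
# Toy illustration of concept_subset-style filtering; `all_names` and
# `annotations` are hypothetical stand-ins, not the library's internals.
import numpy as np

all_names = ["has_bill_shape::dagger", "has_bill_shape::cone", "class"]
annotations = np.random.randint(0, 2, size=(4, len(all_names)))

concept_subset = ["has_bill_shape::cone", "class"]
cols = [all_names.index(name) for name in concept_subset]
filtered = annotations[:, cols]  # keeps only the requested columns, in order
```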
2 changes: 1 addition & 1 deletion conceptarium/conf/sweep.yaml
@@ -9,7 +9,7 @@ hydra:
# standard grid search
params:
seed: 42
- dataset: dag_asia, dag_sachs, dag_insurance
+ dataset: dag_asia, dag_sachs
model: cbm, cem, c2bm
model.train_inference._target_:
torch_concepts.nn.DeterministicInference,
87 changes: 87 additions & 0 deletions tests/data/test_io.py
@@ -14,6 +14,10 @@
save_pickle,
load_pickle,
download_url,
download_url_wget,
zip_is_valid,
wget_available,
DownloadProgressBar,
)


@@ -151,3 +155,86 @@ def test_download_custom_filename(self):
# Verify
assert os.path.exists(path)
assert os.path.basename(path) == custom_name


class TestZipIsValid:
"""Test zip file validation."""

def test_valid_zip(self):
"""zip_is_valid returns True for a well-formed zip."""
with tempfile.TemporaryDirectory() as tmpdir:
zip_path = os.path.join(tmpdir, "good.zip")
with zipfile.ZipFile(zip_path, 'w') as zf:
zf.writestr("hello.txt", "hello world")
assert zip_is_valid(zip_path) is True

def test_invalid_zip_bad_file(self):
"""zip_is_valid returns False for a file that is not a zip."""
with tempfile.TemporaryDirectory() as tmpdir:
bad_path = os.path.join(tmpdir, "bad.zip")
with open(bad_path, 'wb') as f:
f.write(b"this is not a zip file at all")
assert zip_is_valid(bad_path) is False

def test_invalid_zip_truncated(self):
"""zip_is_valid returns False for a truncated/corrupt zip."""
with tempfile.TemporaryDirectory() as tmpdir:
zip_path = os.path.join(tmpdir, "truncated.zip")
with zipfile.ZipFile(zip_path, 'w') as zf:
zf.writestr("data.txt", "some data")
# Corrupt it by truncating
with open(zip_path, 'r+b') as f:
f.truncate(10)
assert zip_is_valid(zip_path) is False


class TestWgetAvailable:
"""Test wget availability detection."""

def test_returns_bool(self):
"""wget_available always returns a bool."""
result = wget_available()
assert isinstance(result, bool)


class TestDownloadUrlWget:
"""Test download_url_wget."""

def test_download_creates_file(self):
"""download_url_wget downloads a small file successfully."""
with tempfile.TemporaryDirectory() as tmpdir:
url = "https://raw.githubusercontent.com/pytorch/pytorch/main/README.md"
dest = os.path.join(tmpdir, "README.md")
download_url_wget(url, dest)
assert os.path.exists(dest)
assert os.path.getsize(dest) > 0

def test_download_resume(self):
"""download_url_wget does not overwrite a pre-existing file of the same name."""
with tempfile.TemporaryDirectory() as tmpdir:
url = "https://raw.githubusercontent.com/pytorch/pytorch/main/README.md"
dest = os.path.join(tmpdir, "README.md")
# First download
download_url_wget(url, dest)
size_first = os.path.getsize(dest)
# Second download (resume / skip)
download_url_wget(url, dest)
size_second = os.path.getsize(dest)
assert size_second >= size_first


class TestDownloadProgressBar:
"""Test DownloadProgressBar.update_to."""

def test_update_to_sets_total(self):
"""update_to sets self.total when tsize is provided."""
with DownloadProgressBar(unit='B', unit_scale=True, miniters=1,
desc="test", disable=True) as bar:
bar.update_to(b=1, bsize=1, tsize=1024)
assert bar.total == 1024

def test_update_to_without_tsize(self):
"""update_to works without tsize (no total set)."""
with DownloadProgressBar(unit='B', unit_scale=True, miniters=1,
desc="test", disable=True) as bar:
bar.update_to(b=2, bsize=512) # should not raise
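
For reference, the helpers exercised above admit compact implementations. The sketch below is consistent with the tests but is an assumption about, not a copy of, the actual code in `torch_concepts.data`:

```python
# Hedged sketch of the tested helpers; signatures inferred from the tests.
import shutil
import zipfile

from tqdm import tqdm


def zip_is_valid(path: str) -> bool:
    """True iff `path` opens as a zip archive and its CRC check passes."""
    try:
        with zipfile.ZipFile(path) as zf:
            return zf.testzip() is None  # testzip() returns the first bad member, or None
    except (zipfile.BadZipFile, OSError):
        return False


def wget_available() -> bool:
    """True iff a `wget` executable is found on PATH."""
    return shutil.which("wget") is not None


class DownloadProgressBar(tqdm):
    """tqdm subclass whose update_to() fits urllib's reporthook protocol."""

    def update_to(self, b=1, bsize=1, tsize=None):
        if tsize is not None:
            self.total = tsize           # what test_update_to_sets_total checks
        self.update(b * bsize - self.n)  # advance by the newly downloaded delta
```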
6 changes: 5 additions & 1 deletion torch_concepts/data/__init__.py
@@ -32,6 +32,7 @@
from .datasets.mnist_arithmetic import MNISTArithmeticDataset
from .datasets.dsprites_regression import DSpritesRegressionDataset
from .datasets.awa2 import AWA2Dataset
from .datasets.cub import CUBDataset

# Re-export datamodules for convenient access
from .datamodules.bnlearn import BnLearnDataModule
@@ -42,6 +43,7 @@
from .datamodules.mnist_arithmetic import MNISTArithmeticDataModule
from .datamodules.dsprites_regression import DSpritesRegressionDataModule
from .datamodules.awa2 import AWA2DataModule
from .datamodules.cub import CUBDataModule

__all__ = [
# Submodules
@@ -66,7 +68,8 @@
"MNISTArithmeticDataset",
"DSpritesRegressionDataset",
"AWA2Dataset",

"CUBDataset",

# DataModules
"BnLearnDataModule",
"ToyDAGDataModule",
@@ -76,4 +79,5 @@
"MNISTArithmeticDataModule",
"DSpritesRegressionDataModule",
"AWA2DataModule",
"CUBDataModule",
]
4 changes: 4 additions & 0 deletions torch_concepts/data/datamodules/awa2.py
@@ -21,6 +21,8 @@ class AWA2DataModule(ConceptDataModule):
Default: ``None`` (auto-creates ``./data/AWA2``).
seed : int, optional
Random seed for train / val / test split. Default: 42.
image_size : int, optional
Side length (px) to resize images to. Default: 224.
val_size : float, optional
Fraction of samples for validation. Default: 0.1.
test_size : float, optional
@@ -67,6 +69,7 @@ def __init__(
self,
root: str = None,
seed: int = 42,
image_size: int = 224,
val_size: float = 0.1,
test_size: float = 0.2,
splitter: Splitter = RandomSplitter(),
@@ -83,6 +86,7 @@ def __init__(
root=root,
concept_subset=concept_subset,
label_descriptions=label_descriptions,
image_size=image_size,
)

super().__init__(
96 changes: 96 additions & 0 deletions torch_concepts/data/datamodules/cub.py
@@ -0,0 +1,96 @@
from ..datasets.cub import CUBDataset

from ..base.datamodule import ConceptDataModule
from ...typing import BackboneType
from ..base.splitter import Splitter
from ..splitters.native import NativeSplitter


class CUBDataModule(ConceptDataModule):
"""DataModule for CUB-200-2011 (Caltech-UCSD Birds).

Handles data loading, splitting, and batching for the CUB-200-2011 dataset
with support for concept-based learning. CUB-200-2011 provides official
train / val / test splits via the Koh et al. pre-processed pickle files,
so :class:`~torch_concepts.data.splitters.NativeSplitter` is used by
default.

.. note::
CUB-200-2011 must be **manually downloaded** before use.
See :class:`~torch_concepts.data.datasets.CUBDataset` for instructions.

Parameters
----------
root : str, optional
Root directory containing ``class_attr_data_10/`` and
``CUB_200_2011/``. Default: ``None`` (auto-creates ``./data/CUB200``).
image_size : int, optional
Side length (px) to resize images to. Default: 224.
splitter : Splitter, optional
Splitting strategy. Default: ``NativeSplitter()`` (uses the official
train / val / test splits from the pickle files).
batch_size : int, optional
Number of samples per batch. Default: 512.
backbone : BackboneType, optional
Backbone model for feature extraction (e.g. ``'resnet50'``).
Default: ``None``.
precompute_embs : bool, optional
Whether to precompute and cache backbone embeddings. Default: ``True``.
force_recompute : bool, optional
Recompute embeddings even if a cache exists. Default: ``False``.
concept_subset : list of str, optional
Subset of concept names to retain. Default: ``None`` (all 113).
label_descriptions : dict, optional
Mapping from concept name to human-readable description.
workers : int, optional
Number of data-loading worker processes. Default: 0.

Examples
--------
>>> from torch_concepts.data import CUBDataModule
>>>
>>> dm = CUBDataModule(
... root="./data/CUB200",
... backbone="resnet50",
... precompute_embs=True,
... batch_size=64,
... )
>>> dm.setup()
>>> train_loader = dm.train_dataloader()

See Also
--------
CUBDataset : The underlying dataset class.
ConceptDataModule : Parent class with common datamodule functionality.
"""

def __init__(
self,
root: str = None,
image_size: int = 224,
splitter: Splitter = NativeSplitter(),
batch_size: int = 512,
backbone: BackboneType = None,
precompute_embs: bool = True,
force_recompute: bool = False,
concept_subset: list | None = None,
label_descriptions: dict | None = None,
workers: int = 0,
**kwargs,
):
dataset = CUBDataset(
root=root,
image_size=image_size,
concept_subset=concept_subset,
label_descriptions=label_descriptions,
)

super().__init__(
dataset=dataset,
batch_size=batch_size,
backbone=backbone,
precompute_embs=precompute_embs,
force_recompute=force_recompute,
workers=workers,
splitter=splitter,
)
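
Mirroring `cub_incomplete.yaml`, the same incomplete-concept setup can be built directly in Python. A short sketch, with values copied from the configs above:

```python
# Sketch: the cub_incomplete.yaml configuration, constructed in Python.
from torch_concepts.data import CUBDataModule

dm = CUBDataModule(
    root="./data/CUB200",
    image_size=224,
    backbone="resnet50",
    precompute_embs=True,
    concept_subset=[       # abridged; full list in cub_incomplete.yaml
        "has_bill_shape::dagger",
        "has_bill_shape::cone",
        "class",           # task label, as in the YAML
    ],
)
dm.setup()
train_loader = dm.train_dataloader()
```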