diff --git a/.coveragerc b/.coveragerc
index c67adeb..fb911b6 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -12,11 +12,14 @@ omit =
    */torch_concepts/data/datasets/mnist_arithmetic.py
    */torch_concepts/data/datasets/pendulum.py
    */torch_concepts/data/datasets/awa2.py
+    */torch_concepts/data/datasets/cub.py
+    # Excluding torch_concepts/data/datamodules/dataset_file.py
    */torch_concepts/data/datamodules/dsprites_regression.py
    */torch_concepts/data/datamodules/mnist_arithmetic.py
    */torch_concepts/data/datamodules/pendulum.py
    */torch_concepts/data/datamodules/awa2.py
+    */torch_concepts/data/datamodules/cub.py
 
 [report]
 exclude_lines =
diff --git a/conceptarium/conf/dataset/cub.yaml b/conceptarium/conf/dataset/cub.yaml
new file mode 100644
index 0000000..613dd37
--- /dev/null
+++ b/conceptarium/conf/dataset/cub.yaml
@@ -0,0 +1,25 @@
+defaults:
+  - _commons
+  - _self_
+
+_target_: torch_concepts.data.datamodules.cub.CUBDataModule
+
+name: cub
+
+# Size (px) to which images are resized
+image_size: 224
+
+# Backbone handling and embedding precomputation
+backbone: resnet50
+precompute_embs: true
+force_recompute: false
+
+# Task label - bird species (200 classes)
+default_task_names: [class]
+
+# Splitter - CUB has official train / test splits
+splitter:
+  _target_: torch_concepts.data.splitters.native.NativeSplitter
+
+# Concept descriptions (optional; leave null to use raw attribute names)
+label_descriptions: null
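For reference, Hydra resolves the ``_target_`` above when an experiment is launched; a minimal sketch of instantiating this config programmatically (the top-level ``config`` name and the ``dataset`` group key are assumptions about the conceptarium entrypoint, not verified here):

    from hydra import compose, initialize
    from hydra.utils import instantiate

    # Compose the dataset config added above and build the datamodule it targets.
    with initialize(version_base=None, config_path="conceptarium/conf"):
        cfg = compose(config_name="config", overrides=["dataset=cub"])

    dm = instantiate(cfg.dataset)  # -> CUBDataModule(backbone="resnet50", ...)
    dm.setup()
    train_loader = dm.train_dataloader()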
diff --git a/conceptarium/conf/dataset/cub_incomplete.yaml b/conceptarium/conf/dataset/cub_incomplete.yaml
new file mode 100644
index 0000000..540e3c2
--- /dev/null
+++ b/conceptarium/conf/dataset/cub_incomplete.yaml
@@ -0,0 +1,49 @@
+defaults:
+  - _commons
+  - _self_
+
+_target_: torch_concepts.data.datamodules.cub.CUBDataModule
+
+name: cub
+
+# Size (px) to which images are resized
+image_size: 224
+
+# Backbone handling and embedding precomputation
+backbone: resnet50
+precompute_embs: true
+force_recompute: false
+
+# Task label - bird species (200 classes)
+default_task_names: [class]
+
+# Splitter - CUB has official train / test splits
+splitter:
+  _target_: torch_concepts.data.splitters.native.NativeSplitter
+
+# We generated the CUB incomplete dataset following the same procedure as in
+# Zarlenga et al. (2025) "Avoiding Leakage Poisoning: Concept Interventions Under Distribution Shifts" (https://arxiv.org/pdf/2504.17921v1).
+# More precisely, we selected the concepts belonging to the following groups:
+# ["has_bill_shape", "has_head_pattern", "has_breast_color", "has_bill_length", "has_wing_shape", "has_tail_pattern", "has_bill_color"]
+concept_subset: [
+  'has_bill_shape::dagger',
+  'has_bill_shape::hooked_seabird',
+  'has_bill_shape::all-purpose',
+  'has_bill_shape::cone',
+  'has_head_pattern::eyebrow',
+  'has_head_pattern::plain',
+  'has_bill_length::about_the_same_as_head',
+  'has_bill_length::shorter_than_head',
+  'has_wing_shape::rounded-wings',
+  'has_wing_shape::pointed-wings',
+  'has_tail_pattern::solid',
+  'has_tail_pattern::striped',
+  'has_tail_pattern::multi-colored',
+  'has_bill_color::grey',
+  'has_bill_color::black',
+  'has_bill_color::buff',
+  'class',  # task label
+]
+
+# Concept descriptions (optional; leave null to use raw attribute names)
+label_descriptions: null
diff --git a/conceptarium/conf/sweep.yaml b/conceptarium/conf/sweep.yaml
index 0843f12..440a4c0 100644
--- a/conceptarium/conf/sweep.yaml
+++ b/conceptarium/conf/sweep.yaml
@@ -9,7 +9,7 @@ hydra:
       # standard grid search
       params:
         seed: 42
-        dataset: dag_asia, dag_sachs, dag_insurance
+        dataset: dag_asia, dag_sachs
         model: cbm, cem, c2bm
         model.train_inference._target_: torch_concepts.nn.DeterministicInference,
diff --git a/tests/data/test_io.py b/tests/data/test_io.py
index ddac399..3a92d08 100644
--- a/tests/data/test_io.py
+++ b/tests/data/test_io.py
@@ -14,6 +14,11 @@
     save_pickle,
     load_pickle,
     download_url,
+    download_url_wget,
+    zip_is_valid,
+    wget_available,
+    DownloadProgressBar,
 )
+import zipfile
 
 
@@ -151,3 +156,86 @@ def test_download_custom_filename(self):
         # Verify
         assert os.path.exists(path)
         assert os.path.basename(path) == custom_name
+
+
+class TestZipIsValid:
+    """Test zip file validation."""
+
+    def test_valid_zip(self):
+        """zip_is_valid returns True for a well-formed zip."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            zip_path = os.path.join(tmpdir, "good.zip")
+            with zipfile.ZipFile(zip_path, 'w') as zf:
+                zf.writestr("hello.txt", "hello world")
+            assert zip_is_valid(zip_path) is True
+
+    def test_invalid_zip_bad_file(self):
+        """zip_is_valid returns False for a file that is not a zip."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            bad_path = os.path.join(tmpdir, "bad.zip")
+            with open(bad_path, 'wb') as f:
+                f.write(b"this is not a zip file at all")
+            assert zip_is_valid(bad_path) is False
+
+    def test_invalid_zip_truncated(self):
+        """zip_is_valid returns False for a truncated/corrupt zip."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            zip_path = os.path.join(tmpdir, "truncated.zip")
+            with zipfile.ZipFile(zip_path, 'w') as zf:
+                zf.writestr("data.txt", "some data")
+            # Corrupt it by truncating
+            with open(zip_path, 'r+b') as f:
+                f.truncate(10)
+            assert zip_is_valid(zip_path) is False
+
+
+class TestWgetAvailable:
+    """Test wget availability detection."""
+
+    def test_returns_bool(self):
+        """wget_available always returns a bool."""
+        result = wget_available()
+        assert isinstance(result, bool)
+
+
+class TestDownloadUrlWget:
+    """Test download_url_wget."""
+
+    def test_download_creates_file(self):
+        """download_url_wget downloads a small file successfully."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            url = "https://raw.githubusercontent.com/pytorch/pytorch/main/README.md"
+            dest = os.path.join(tmpdir, "README.md")
+            download_url_wget(url, dest)
+            assert os.path.exists(dest)
+            assert os.path.getsize(dest) > 0
+
+    def test_download_resume(self):
+        """download_url_wget resumes or skips a pre-existing file
of the same name.""" + with tempfile.TemporaryDirectory() as tmpdir: + url = "https://raw.githubusercontent.com/pytorch/pytorch/main/README.md" + dest = os.path.join(tmpdir, "README.md") + # First download + download_url_wget(url, dest) + size_first = os.path.getsize(dest) + # Second download (resume / skip) + download_url_wget(url, dest) + size_second = os.path.getsize(dest) + assert size_second >= size_first + + +class TestDownloadProgressBar: + """Test DownloadProgressBar.update_to.""" + + def test_update_to_sets_total(self): + """update_to sets self.total when tsize is provided.""" + with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, + desc="test", disable=True) as bar: + bar.update_to(b=1, bsize=1, tsize=1024) + assert bar.total == 1024 + + def test_update_to_without_tsize(self): + """update_to works without tsize (no total set).""" + with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, + desc="test", disable=True) as bar: + bar.update_to(b=2, bsize=512) # should not raise diff --git a/torch_concepts/data/__init__.py b/torch_concepts/data/__init__.py index 51e799e..ea8b82c 100644 --- a/torch_concepts/data/__init__.py +++ b/torch_concepts/data/__init__.py @@ -32,6 +32,7 @@ from .datasets.mnist_arithmetic import MNISTArithmeticDataset from .datasets.dsprites_regression import DSpritesRegressionDataset from .datasets.awa2 import AWA2Dataset +from .datasets.cub import CUBDataset # Re-export datamodules for convenient access from .datamodules.bnlearn import BnLearnDataModule @@ -42,6 +43,7 @@ from .datamodules.mnist_arithmetic import MNISTArithmeticDataModule from .datamodules.dsprites_regression import DSpritesRegressionDataModule from .datamodules.awa2 import AWA2DataModule +from .datamodules.cub import CUBDataModule __all__ = [ # Submodules @@ -66,7 +68,8 @@ "MNISTArithmeticDataset", "DSpritesRegressionDataset", "AWA2Dataset", - + "CUBDataset", + # DataModules "BnLearnDataModule", "ToyDAGDataModule", @@ -76,4 +79,5 @@ "MNISTArithmeticDataModule", "DSpritesRegressionDataModule", "AWA2DataModule", + "CUBDataModule", ] diff --git a/torch_concepts/data/datamodules/awa2.py b/torch_concepts/data/datamodules/awa2.py index eea580f..cb1482e 100644 --- a/torch_concepts/data/datamodules/awa2.py +++ b/torch_concepts/data/datamodules/awa2.py @@ -21,6 +21,8 @@ class AWA2DataModule(ConceptDataModule): Default: ``None`` (auto-creates ``./data/AWA2``). seed : int, optional Random seed for train / val / test split. Default: 42. + image_size : int, optional + Side length (px) to resize images to. Default: 224. val_size : float, optional Fraction of samples for validation. Default: 0.1. test_size : float, optional @@ -67,6 +69,7 @@ def __init__( self, root: str = None, seed: int = 42, + image_size: int = 224, val_size: float = 0.1, test_size: float = 0.2, splitter: Splitter = RandomSplitter(), @@ -83,6 +86,7 @@ def __init__( root=root, concept_subset=concept_subset, label_descriptions=label_descriptions, + image_size=image_size, ) super().__init__( diff --git a/torch_concepts/data/datamodules/cub.py b/torch_concepts/data/datamodules/cub.py new file mode 100644 index 0000000..29d3c00 --- /dev/null +++ b/torch_concepts/data/datamodules/cub.py @@ -0,0 +1,96 @@ +from ..datasets.cub import CUBDataset + +from ..base.datamodule import ConceptDataModule +from ...typing import BackboneType +from ..base.splitter import Splitter +from ..splitters.native import NativeSplitter + + +class CUBDataModule(ConceptDataModule): + """DataModule for CUB-200-2011 (Caltech-UCSD Birds). 
+
+    Handles data loading, splitting, and batching for the CUB-200-2011 dataset
+    with support for concept-based learning. The Koh et al. pre-processed
+    pickle files provide canonical train / val / test splits, so
+    :class:`~torch_concepts.data.splitters.NativeSplitter` is used by
+    default.
+
+    .. note::
+        The dataset is **downloaded automatically** on first use. See
+        :class:`~torch_concepts.data.datasets.CUBDataset` for details.
+
+    Parameters
+    ----------
+    root : str, optional
+        Root directory for the dataset (holds ``class_attr_data_10/`` and the
+        extracted CUB files). Default: ``None`` (auto-creates ``./data/CUB200``).
+    image_size : int, optional
+        Side length (px) to resize images to. Default: 224.
+    splitter : Splitter, optional
+        Splitting strategy. Default: ``NativeSplitter()`` (uses the official
+        train / val / test splits from the pickle files).
+    batch_size : int, optional
+        Number of samples per batch. Default: 512.
+    backbone : BackboneType, optional
+        Backbone model for feature extraction (e.g. ``'resnet50'``).
+        Default: ``None``.
+    precompute_embs : bool, optional
+        Whether to precompute and cache backbone embeddings. Default: ``True``.
+    force_recompute : bool, optional
+        Recompute embeddings even if a cache exists. Default: ``False``.
+    concept_subset : list of str, optional
+        Subset of concept names to retain. Default: ``None`` (all 113).
+    label_descriptions : dict, optional
+        Mapping from concept name to human-readable description.
+    workers : int, optional
+        Number of data-loading worker processes. Default: 0.
+
+    Examples
+    --------
+    >>> from torch_concepts.data import CUBDataModule
+    >>>
+    >>> dm = CUBDataModule(
+    ...     root="./data/CUB200",
+    ...     backbone="resnet50",
+    ...     precompute_embs=True,
+    ...     batch_size=64,
+    ... )
+    >>> dm.setup()
+    >>> train_loader = dm.train_dataloader()
+
+    See Also
+    --------
+    CUBDataset : The underlying dataset class.
+    ConceptDataModule : Parent class with common datamodule functionality.
+    """
+
+    def __init__(
+        self,
+        root: str = None,
+        image_size: int = 224,
+        splitter: Splitter = NativeSplitter(),
+        batch_size: int = 512,
+        backbone: BackboneType = None,
+        precompute_embs: bool = True,
+        force_recompute: bool = False,
+        concept_subset: list | None = None,
+        label_descriptions: dict | None = None,
+        workers: int = 0,
+        **kwargs,
+    ):
+        dataset = CUBDataset(
+            root=root,
+            image_size=image_size,
+            concept_subset=concept_subset,
+            label_descriptions=label_descriptions,
+        )
+
+        super().__init__(
+            dataset=dataset,
+            batch_size=batch_size,
+            backbone=backbone,
+            precompute_embs=precompute_embs,
+            force_recompute=force_recompute,
+            workers=workers,
+            splitter=splitter,
+        )
diff --git a/torch_concepts/data/datasets/awa2.py b/torch_concepts/data/datasets/awa2.py
index 0581ce9..bf137d9 100644
--- a/torch_concepts/data/datasets/awa2.py
+++ b/torch_concepts/data/datasets/awa2.py
@@ -18,12 +18,13 @@
 import torchvision.transforms as transforms
 from PIL import Image
 from typing import List, Mapping, Optional
-import urllib.request
 import zipfile
 from pathlib import Path
 
 from torch_concepts import Annotations, AxisAnnotation
 from torch_concepts.data.base import ConceptDataset
+from torch_concepts.data.io import download_url_wget, zip_is_valid
+
 
 logger = logging.getLogger(__name__)
@@ -233,6 +234,8 @@ class AWA2Dataset(ConceptDataset):
     root : str, optional
         Root directory where the dataset is stored. Defaults to
         ``./data/AWA2``.
+    image_size : int, optional
+        Side length (px) to resize images to. Default: 224.
concept_subset : list of str, optional Subset of concept names to retain. ``None`` keeps all 86. label_descriptions : dict, optional @@ -242,15 +245,14 @@ class AWA2Dataset(ConceptDataset): def __init__( self, root: str = None, - seed: int = 42, - train_size: float = 0.6, - val_size: float = 0.2, + image_size: int = 224, concept_subset: Optional[list] = None, label_descriptions: Optional[Mapping] = None, ): if root is None: root = os.path.join(os.getcwd(), 'data', 'AWA2') self.root = root + self.image_size = image_size self.label_descriptions = label_descriptions filenames, concepts, annotations, graph = self.load() @@ -297,59 +299,6 @@ def processed_filenames(self) -> List[str]: 'annotations.pt', ] - @staticmethod - def _zip_is_valid(path: str) -> bool: - """Return True if *path* is a structurally valid zip with correct CRCs.""" - try: - with zipfile.ZipFile(path) as z: - bad = z.testzip() # returns first bad filename or None - return bad is None - except zipfile.BadZipFile: - return False - - @staticmethod - def _wget_available() -> bool: - import shutil as _shutil - return _shutil.which("wget") is not None - - def _download_file(self, url: str, dest: str) -> None: - """Download *url* to *dest*. - - Uses ``wget --continue`` when available (handles large files and - network interruptions much better than urllib). Falls back to a - pure-Python streaming download with ``Range`` resume support. - """ - if self._wget_available(): - import subprocess - print(f"\nDownloading {os.path.basename(dest)} via wget ...") - subprocess.run( - [ - "wget", - "--continue", # resume partial downloads - "--tries=10", # retry up to 10 times on error - "--retry-connrefused", - "--waitretry=5", # wait 5 s between retries - "--show-progress", - "-O", dest, - url, - ], - check=True, - ) - else: - downloaded = os.path.getsize(dest) if os.path.exists(dest) else 0 - req = urllib.request.Request(url, headers={"Range": f"bytes={downloaded}-"}) - with urllib.request.urlopen(req) as r: - total = downloaded + int(r.headers.get("Content-Length", 0)) - print(f"\nDownloading {os.path.basename(dest)} ({total / 1e9:.2f} GB)") - with open(dest, "ab") as f: - while chunk := r.read(1024 * 1024): - f.write(chunk) - downloaded += len(chunk) - pct = downloaded / total * 100 - bar = "█" * int(pct // 2) + "░" * (50 - int(pct // 2)) - print(f"\r [{bar}] {pct:.1f}% {downloaded/1e9:.2f}/{total/1e9:.2f} GB", end="") - print() - def download(self): """Download raw AwA2 data from official sources. 
@@ -364,9 +313,9 @@ def download(self):
         for url in URLS:
             dest = os.path.join(self.root, url.split("/")[-1])
             for attempt in range(1, _MAX_RETRIES + 1):
-                self._download_file(url, dest)
+                download_url_wget(url, dest)
                 print(f"  Verifying {os.path.basename(dest)} (attempt {attempt}/{_MAX_RETRIES}) ...")
-                if self._zip_is_valid(dest):
+                if zip_is_valid(dest):
                     break
                 print(f"  CRC check failed — deleting corrupted file and retrying ...")
                 os.remove(dest)
@@ -498,7 +447,7 @@ def __getitem__(self, item: int) -> dict:
         img_path = self.input_data[item]
         x = Image.open(img_path)
         x = x.convert('RGB')  # Ensure 3 channels
-        x = transforms.Resize((224, 224))(x)  # Resize to 224x224
+        x = transforms.Resize((self.image_size, self.image_size))(x)  # Resize to image_size x image_size
         x = transforms.ToTensor()(x)  # Convert to tensor and scale to [0, 1]
         c = self.concepts[item]
         return {'inputs': {'x': x}, 'concepts': {'c': c}}
diff --git a/torch_concepts/data/datasets/cub.py b/torch_concepts/data/datasets/cub.py
index 1b6636d..84c1322 100644
--- a/torch_concepts/data/datasets/cub.py
+++ b/torch_concepts/data/datasets/cub.py
@@ -1,47 +1,38 @@
 """
-CUB-200 Dataset Loader
-** THIS DATASET NEEDS TO BE DOWNLOADED BEFORE BEING ABLE TO USE THE LOADER **
+CUB-200-2011 (Caltech-UCSD Birds) Dataset
-
-################################################################################
-## DOWNLOAD INSTRUCTIONS
-################################################################################
-
-**** OPTION #1 *****
-The simplest way to get the CUB dataset, is to download the pre-processed CUB
-dataset by Koh et al. [CBM Paper]. This can be downloaded from their
-public colab notebook at: https://worksheets.codalab.org/worksheets/0x362911581fcd4e048ddfd84f47203fd2.
-You will need to download the original CUB dataset from that notebook (found
-here: https://worksheets.codalab.org/bundles/0xd013a7ba2e88481bbc07e787f73109f5)
-and the preprocessed "CUB_preprocessed" dataset (which can be directly accessed
-here: https://worksheets.codalab.org/bundles/0x5b9d528d2101418b87212db92fea6683)
-
-**** OPTION #2 *****
-Follow the download the preprocess instructions found in Koh et al.'s original
-repository here: https://github.com/yewsiang/ConceptBottleneck.
-Specifically, here: https://github.com/yewsiang/ConceptBottleneck/blob/master/CUB/
-
-################################################################################
-
-[IMPORTANT] After downloading the files, they need to follow the following
-structure:
-
-
-This loader has been adapted/inspired by that of found in Koh et al.'s
-repository (https://github.com/yewsiang/ConceptBottleneck/blob/master/CUB/cub_loader.py)
-as well as in Espinosa Zarlenga and Barbiero's et al.'s repository
-(https://github.com/mateoespinosa/cem).
+
+Adapted from:
+  - Koh et al.'s Concept Bottleneck Models repository:
+    https://github.com/yewsiang/ConceptBottleneck/blob/master/CUB/cub_loader.py
+  - Espinosa Zarlenga, Barbiero, et al.'s CEM repository:
+    https://github.com/mateoespinosa/cem/blob/main/cem/data/CUB200/cub_loader.py
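+
+The dataset is downloaded automatically on first use: the Koh et al.
+pre-processed pickle files (train/val/test splits and attribute labels) are
+fetched into ``class_attr_data_10/`` and the official CUB_200_2011.tgz image
+archive is downloaded from data.caltech.edu and extracted into the root.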
""" - -import numpy as np import os +import logging +from pathlib import Path +import tarfile +from anyio import Path import pickle +import numpy as np +import pandas as pd import torch import torchvision.transforms as transforms from collections import defaultdict from PIL import Image -from torch.utils.data import Dataset +from typing import List, Mapping, Optional +import zipfile +import shutil + +from torch_concepts import Annotations, AxisAnnotation +from torch_concepts.data.base import ConceptDataset +from torch_concepts.data.io import download_url + +logger = logging.getLogger(__name__) ######################################################## ## GENERAL DATASET GLOBAL VARIABLES @@ -49,13 +35,11 @@ N_CLASSES = 200 -# CAN BE OVERWRITTEN WITH AN ENV VARIABLE CUB_DIR -CUB_DIR = os.environ.get("CUB_DIR", './CUB200/') - - -######################################################### -## CONCEPT INFORMATION REGARDING CUB -######################################################### +URLS = [ + # NOTE: we retrieve the .pkl split files from the CEM repository since I cannot find the m in the CBM repo. + "https://raw.githubusercontent.com/mateoespinosa/cem/main/cem/data/CUB200/class_attr_data_10", + "https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz?download=1", +] # CUB Class names @@ -701,188 +685,255 @@ )): group = concept_name[:concept_name.find("::")] CONCEPT_GROUP_MAP[group].append(i) +CONCEPT_GROUP_MAP = dict(CONCEPT_GROUP_MAP) +# Ordered names for the 112 selected concepts (matches order in pkl files) +SELECTED_CONCEPT_NAMES: List[str] = [CONCEPT_SEMANTICS[i] for i in SELECTED_CONCEPTS] -# Definitions from CUB (certainties.txt) -# 1 not visible -# 2 guessing -# 3 probably -# 4 definitely -# Unc map represents a mapping from the discrete score to a "mental probability" -DEFAULT_UNC_MAP = [ - {0: 0.5, 1: 0.5, 2: 0.5, 3:0.75, 4:1.0}, - {0: 0.5, 1: 0.5, 2: 0.5, 3:0.75, 4:1.0}, -] - -########################################################## -## Helper Functions -########################################################## - +class CUBDataset(ConceptDataset): + """Dataset class for CUB-200-2011 (Caltech-UCSD Birds). -def discrete_to_continuous_unc(unc_val, attr_label, unc_map): - ''' - Yield a continuous prob representing discrete conf val - Inspired by CBM data processing + CUB-200-2011 contains 11,788 bird images across 200 species classes, + annotated with 112 binary semantic attributes selected by Koh et al. + [CBM Paper] from the full set of 312 CUB attributes. - The selected probability should account for whether the concept is on or off - E.g., if a human is "probably" sure the concept is off - flip the prob in unc_map - ''' - unc_val = unc_val.item() - attr_label = attr_label.item() - return float(unc_map[int(attr_label)][unc_val]) + Official train / val / test splits from the pre-processed pickle files + are preserved; use :class:`~torch_concepts.data.splitters.NativeSplitter` + in the corresponding datamodule. + The concept vector per sample contains: -########################################################## -## Data Loaders -########################################################## + - columns 0-111: 112 binary semantic attributes (cardinality 1 each) + - column 112: bird species index 0-199 (cardinality 200) -class CUBDataset(Dataset): - """ - TODO + Parameters + ---------- + root : str, optional + Root directory that contains ``class_attr_data_10/`` and + ``CUB_200_2011/``. Defaults to ``./data/CUB200``. 
+ image_size : int, optional + Side length (px) images are resized to. Defaults to 224. + concept_subset : list of str, optional + Subset of concept names to retain. ``None`` keeps all 113. + label_descriptions : dict, optional + Mapping from concept name to human-readable description. """ def __init__( self, - split='train', - uncertain_concept_labels=False, - root=CUB_DIR, - path_transform=None, - sample_transform=None, - concept_transform=None, - label_transform=None, - uncertainty_based_random_labels=False, - unc_map=DEFAULT_UNC_MAP, - selected_concepts=None, - training_augment=True, + root: str = None, + image_size: int = 224, + concept_subset: Optional[list] = None, + label_descriptions: Optional[Mapping] = None, ): - """ - TODO: Define different arguments - """ - if not (os.path.exists(root) and os.path.isdir(root)): - raise ValueError( - f'Provided CUB data directory "{root}" is not a valid or ' - f'an existing directory.' - ) - assert split in ['train', 'val', 'test'], ( - f"CUB split must be in ['train', 'val', 'test'] but got '{split}'" - ) - self.split = split - base_dir = os.path.join(root, 'class_attr_data_10') - self.pkl_file_path = os.path.join(base_dir, f'{split}.pkl') - self.name = 'CUB' - - self.data = [] - with open(self.pkl_file_path, 'rb') as f: - self.data.extend(pickle.load(f)) - image_size = 299 - if (split == 'train') and training_augment: - self.sample_transform = transforms.Compose([ - transforms.ColorJitter(brightness=32/255, saturation=(0.5, 1.5)), - transforms.RandomResizedCrop(image_size), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), #implicitly divides by 255 - transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [2, 2, 2]), - sample_transform or (lambda x: x), - ]) - else: - self.sample_transform = transforms.Compose([ - transforms.CenterCrop(image_size), - transforms.ToTensor(), #implicitly divides by 255 - transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [2, 2, 2]), - sample_transform or (lambda x: x), - ]) - self.concept_transform = concept_transform or (lambda x: x) - self.label_transform = label_transform or (lambda x: x) - self.uncertain_concept_labels = uncertain_concept_labels + if root is None: + root = os.path.join(os.getcwd(), 'data', 'CUB200') self.root = root - self.path_transform = path_transform - self.uncertainty_based_random_labels = uncertainty_based_random_labels - self.unc_map = unc_map - if selected_concepts is None: - selected_concepts = list(range(len(SELECTED_CONCEPTS))) - self.selected_concepts = selected_concepts - self.concept_names = self.concept_attr_names = list( - np.array( - CONCEPT_SEMANTICS - )[CONCEPT_SEMANTICS][selected_concepts] + self.image_size = image_size + self.label_descriptions = label_descriptions + + filenames, concepts, annotations, graph = self.load() + + super().__init__( + input_data=filenames, + concepts=concepts, + annotations=annotations, + graph=graph, + concept_names_subset=concept_subset, + name='CUBDataset', ) - self.task_names = self.task_attr_names = CLASS_NAMES - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - img_data = self.data[idx] - img_path = img_data['img_path'] - if self.path_transform is None: - # This is needed if the dataset is downloaded from the original - # CBM paper's repository/experiment code - img_path = img_path.replace( - '/juice/scr/scr102/scr/thaonguyen/CUB_supervision/datasets/', - self.root + + # ------------------------------------------------------------------ + # ConceptDataset interface + # 
------------------------------------------------------------------
+
+    @property
+    def raw_filenames(self) -> List[str]:
+        return [
+            "attributes",
+            "images",
+            "parts",
+            "attributes.txt",
+            "bounding_boxes.txt",
+            "classes.txt",
+            "image_class_labels.txt",
+            "images.txt",
+            "train_test_split.txt",  # official CUB train/test split (unused; we follow Koh et al.'s splits)
+            "class_attr_data_10/train.pkl",  # splits with all classes, from Koh et al.'s pre-processing
+            "class_attr_data_10/val.pkl",
+            "class_attr_data_10/test.pkl",
+        ]
+
+    @property
+    def processed_filenames(self) -> List[str]:
+        return [
+            'filenames.txt',
+            'concepts.pt',
+            'annotations.pt',
+            'split_mapping.h5',
+        ]
+
+    def download(self) -> None:
+        """Download the CUB data if it is not already present."""
+        if not os.path.exists(self.root):
+            os.makedirs(self.root)
+
+        # Store the Koh et al. pre-processed pickle files in a subfolder "class_attr_data_10"
+        class_attr_dir = os.path.join(self.root, "class_attr_data_10")
+        if not os.path.exists(class_attr_dir):
+            os.makedirs(class_attr_dir)
+        for split_name in ('train', 'val', 'test'):
+            url = f"{URLS[0]}/{split_name}.pkl"
+            download_url(url, class_attr_dir)
+
+        tgz_path = download_url(URLS[1], self.root)
+
+        with tarfile.open(tgz_path, "r:gz") as tar:
+            tar.extractall(path=self.root)
+        os.unlink(tgz_path)
+
+        # Move all the files outside of the nested "CUB_200_2011" folder to the root
+        extracted_folder = os.path.join(self.root, "CUB_200_2011")
+        for item in os.listdir(extracted_folder):
+            src = os.path.join(extracted_folder, item)
+            dst = os.path.join(self.root, item)
+            if os.path.exists(dst):
+                if os.path.isdir(dst):
+                    shutil.rmtree(dst)
+                else:
+                    os.remove(dst)
+            shutil.move(src, dst)
+        os.rmdir(extracted_folder)
+
+    def _remap_image_path(self, img_path: str) -> str:
+        """Remap the absolute path stored in a pkl entry to the local root.
+
+        The Koh et al. pkl files embed absolute paths from their cluster
+        (``/juice/scr/.../datasets/``). We extract the ``CUB_200_2011/``
+        subtree and join it with the local root.
+        """
+        marker = 'CUB_200_2011'
+        idx = img_path.find(marker)
+        if idx >= 0:
+            relative = img_path[idx:]  # e.g. "CUB_200_2011/images/.../file.jpg"
+            # Eliminate CUB_200_2011 from path
+            relative = relative[len(marker) + 1:]  # e.g. 
"images/.../file.jpg" + return os.path.abspath(os.path.join(self.root, relative)) + # Fallback: replace the known cluster prefix + return img_path.replace( + '/juice/scr/scr102/scr/thaonguyen/CUB_supervision/datasets/', + self.root + os.sep, + ) + + def build(self): + """Process raw CUB pickle files and save cached dataset artefacts.""" + self.maybe_download() + + logger.info(f"Building CUB dataset from {self.root} ...") + + all_paths: List[str] = [] + all_attrs: List[List[int]] = [] + all_classes: List[int] = [] + split_labels: List[str] = [] + + for split_name in ('train', 'val', 'test'): + pkl_path = os.path.join( + self.root, 'class_attr_data_10', f'{split_name}.pkl' ) - try: - img = Image.open(img_path).convert('RGB') - except: - img_path_split = img_path.split('/') - img_path = '/'.join( - img_path_split[:2] + [self.split] + img_path_split[2:] - ) - img = Image.open(img_path).convert('RGB') - else: - img = Image.open(self.path_transform(img_path)).convert('RGB') + with open(pkl_path, 'rb') as fh: + entries = pickle.load(fh) + + for entry in entries: + img_path = self._remap_image_path(entry['img_path']) + all_paths.append(img_path) + all_attrs.append(entry['attribute_label']) # 112-dim list + all_classes.append(int(entry['class_label'])) + split_labels.append(split_name) + + n = len(all_paths) + logger.info(f"Loaded {n} samples (train/val/test)") + + # Build concept tensor: 112 binary attrs + class index + attr_array = np.array(all_attrs, dtype=np.float32) # (n, 112) + class_array = np.array(all_classes, dtype=np.float32).reshape(-1, 1) # (n, 1) + all_concepts_np = np.concatenate([attr_array, class_array], axis=1) # (n, 113) + concepts_tensor = torch.tensor(all_concepts_np, dtype=torch.float32) + + # Build Annotations + concept_names = SELECTED_CONCEPT_NAMES + ['class'] + binary_states = [['0'] for _ in SELECTED_CONCEPT_NAMES] + states = binary_states + [CLASS_NAMES] + cardinalities = [1] * len(SELECTED_CONCEPT_NAMES) + [N_CLASSES] + concept_metadata = {name: {'type': 'discrete'} for name in concept_names} + + annotations = Annotations({ + 1: AxisAnnotation( + labels=concept_names, + states=states, + cardinalities=cardinalities, + metadata=concept_metadata, + ) + }) + + # Build split mapping (native train/val/test) + split_series = pd.Series(split_labels, name='split') + + # Save artefacts + os.makedirs(self.root, exist_ok=True) + logger.info(f"Saving CUB dataset artefacts to {self.root}") + + with open(self.processed_paths[0], 'w') as fh: + fh.write('\n'.join(all_paths)) + + torch.save(concepts_tensor, self.processed_paths[1]) + torch.save(annotations, self.processed_paths[2]) + split_series.to_hdf(self.processed_paths[3], key='split_mapping') + + logger.info(f"CUB dataset saved ({n} samples)") - class_label = self.label_transform(img_data['class_label']) - img = self.sample_transform(img) + def load_raw(self): + """Load processed artefacts from disk.""" + self.maybe_build() - if self.uncertain_concept_labels: - attr_label = img_data['uncertain_attribute_label'] + logger.info(f"Loading CUB dataset from {self.root}") + + with open(self.processed_paths[0], 'r') as fh: + filenames = fh.read().strip().split('\n') + + concepts = torch.load(self.processed_paths[1], weights_only=False) + annotations = torch.load(self.processed_paths[2], weights_only=False) + graph = None + + return filenames, concepts, annotations, graph + + def load(self): + return self.load_raw() + + def __getitem__(self, item: int) -> dict: + if self.embs_precomputed: + x = self.input_data[item] else: - attr_label = 
img_data['attribute_label']
-            attr_label = self.concept_transform(
-                np.array(attr_label)[self.selected_concepts]
-            )
+            img_path = self.input_data[item]
+            x = Image.open(img_path).convert('RGB')
+            x = transforms.Resize((self.image_size, self.image_size))(x)
+            x = transforms.ToTensor()(x)
+        c = self.concepts[item]
+        return {'inputs': {'x': x}, 'concepts': {'c': c}}
+
+    # ------------------------------------------------------------------
+    # Properties — override base class which assumes input_data is a Tensor
+    # ------------------------------------------------------------------
+
+    @property
+    def n_samples(self) -> int:
+        return len(self.input_data)
+
+    @property
+    def n_features(self) -> tuple:
+        return tuple(self[0]['inputs']['x'].shape)
+
+    @property
+    def shape(self) -> tuple:
+        return (self.n_samples, *self.n_features)
 
-        # We may want to randomly sample concept labels based on their provided
-        # annotator uncertainty
-        if self.uncertainty_based_random_labels:
-            discrete_unc_label = np.array(
-                img_data['attribute_certainty']
-            )[self.selected_concepts]
-            instance_attr_label = np.array(img_data['attribute_label'])
-            competencies = []
-            for (discrete_unc_val, hard_concept_val) in zip(
-                discrete_unc_label,
-                instance_attr_label,
-            ):
-                competencies.append(
-                    discrete_to_continuous_unc(
-                        discrete_unc_val,
-                        hard_concept_val,
-                        self.unc_map,
-                    )
-                )
-            attr_label = np.random.binomial(1, competencies)
-
-        return img, torch.FloatTensor(attr_label), class_label
-
-    def concept_weights(self):
-        """
-        Calculate class imbalance ratio for binary attribute labels
-        """
-        imbalance_ratio = []
-        with open(self.pkl_file_path, 'rb') as f:
-            data = pickle.load(f)
-        n = len(data)
-        n_attr = len(data[0]['attribute_label'])
-        n_ones = [0] * n_attr
-        total = [n] * n_attr
-        for d in data:
-            labels = d['attribute_label']
-            for i in range(n_attr):
-                n_ones[i] += labels[i]
-        for j in range(len(n_ones)):
-            imbalance_ratio.append(total[j]/n_ones[j] - 1)
-        return np.array(imbalance_ratio)[self.selected_concepts]
\ No newline at end of file
diff --git a/torch_concepts/data/io.py b/torch_concepts/data/io.py
index 16a875a..57b58a0 100644
--- a/torch_concepts/data/io.py
+++ b/torch_concepts/data/io.py
@@ -104,9 +104,9 @@ def update_to(self, b=1, bsize=1, tsize=None):
 
 
 def download_url(url: str,
-                  folder: str,
-                  filename: Optional[str] = None,
-                  verbose: bool = True):
+                 folder: str,
+                 filename: Optional[str] = None,
+                 verbose: bool = True):
     r"""Downloads the content of an URL to a specific folder.
 
     Args:
@@ -137,3 +137,63 @@
                              disable=not verbose) as t:
         urllib.request.urlretrieve(url, filename=path, reporthook=t.update_to)
     return path
+
+
+def zip_is_valid(path: str) -> bool:
+    """Return True if *path* is a structurally valid zip with correct CRCs."""
+    import zipfile  # local import; mirrors the shutil import in wget_available below
+    try:
+        with zipfile.ZipFile(path) as z:
+            bad = z.testzip()  # returns first bad filename or None
+            return bad is None
+    except zipfile.BadZipFile:
+        return False
+
+
+def wget_available() -> bool:
+    """Return True if the ``wget`` executable is on the system PATH."""
+    import shutil as _shutil
+    return _shutil.which("wget") is not None
+
+
+def download_url_wget(url: str, dest: str) -> None:
+    """Download *url* to *dest*.
+
+    Uses ``wget --continue`` when available (handles large files and
+    network interruptions much better than urllib). Falls back to a
+    pure-Python streaming download with ``Range`` resume support.
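+
+    Unlike :func:`download_url`, ``dest`` is the full target file path rather
+    than a destination folder; any partial file already at ``dest`` is kept
+    and the download resumes from its current size.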
+ """ + if wget_available(): + import subprocess + print(f"\nDownloading {os.path.basename(dest)} via wget ...") + subprocess.run( + [ + "wget", + "--continue", # resume partial downloads + "--tries=10", # retry up to 10 times on error + "--retry-connrefused", + "--waitretry=5", # wait 5 s between retries + "--show-progress", + "-O", dest, + url, + ], + check=True, + ) + else: + downloaded = os.path.getsize(dest) if os.path.exists(dest) else 0 + req = urllib.request.Request(url, headers={"Range": f"bytes={downloaded}-"}) + with urllib.request.urlopen(req) as r: + total = downloaded + int(r.headers.get("Content-Length", 0)) + print(f"\nDownloading {os.path.basename(dest)} ({total / 1e9:.2f} GB)") + with open(dest, "ab") as f: + while chunk := r.read(1024 * 1024): + f.write(chunk) + downloaded += len(chunk) + pct = downloaded / total * 100 + bar = "█" * int(pct // 2) + "░" * (50 - int(pct // 2)) + print(f"\r [{bar}] {pct:.1f}% {downloaded/1e9:.2f}/{total/1e9:.2f} GB", end="") + print() \ No newline at end of file
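The io.py helpers above compose into the verified-download pattern used by AWA2Dataset.download(); a standalone sketch (the URL, destination, and retry count here are illustrative, not taken from the codebase):

    import os
    from torch_concepts.data.io import download_url_wget, zip_is_valid

    MAX_RETRIES = 3  # illustrative; awa2.py defines its own _MAX_RETRIES
    url = "https://example.org/archive.zip"  # hypothetical URL
    dest = os.path.join("data", os.path.basename(url))
    os.makedirs("data", exist_ok=True)

    for attempt in range(1, MAX_RETRIES + 1):
        download_url_wget(url, dest)  # resumes a partial download if present
        if zip_is_valid(dest):        # CRC-checks every archive member
            break
        os.remove(dest)               # corrupted archive: delete and retry
    else:
        raise RuntimeError(f"Could not fetch a valid archive from {url}")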
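Similarly, the 113-column concept layout documented in CUBDataset (112 binary attributes plus the species index) can be sanity-checked once the data is available locally; a sketch, assuming the default full concept set and that the CUB archive download has already completed:

    from torch_concepts.data import CUBDataset

    ds = CUBDataset(root="./data/CUB200")
    c = ds[0]['concepts']['c']
    assert c.shape[-1] == 113                    # 112 attributes + 1 class column
    assert set(c[:112].tolist()) <= {0.0, 1.0}   # attributes are binary
    assert 0 <= int(c[112]) < 200                # species index in [0, 200)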