From 9729d304fd8df8f27b09a5d6ec0967a35cb02f67 Mon Sep 17 00:00:00 2001 From: edogab33 Date: Tue, 25 Nov 2025 18:26:42 +0100 Subject: [PATCH 01/16] Update .gitignore --- .gitignore | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index ec98efae..3bc2a3ab 100644 --- a/.gitignore +++ b/.gitignore @@ -85,4 +85,8 @@ data/ !tests/data/ # conceptarium logs -outputs/ \ No newline at end of file +outputs/ + +CUB200/ + +.DS_Store \ No newline at end of file From 1f0a162bff984c2aec66cfe81fb8008656d618f0 Mon Sep 17 00:00:00 2001 From: edogab33 Date: Tue, 25 Nov 2025 18:30:09 +0100 Subject: [PATCH 02/16] Fix cub -- no embeddings --- torch_concepts/data/datasets/cub.py | 704 ++++++---------------------- 1 file changed, 149 insertions(+), 555 deletions(-) diff --git a/torch_concepts/data/datasets/cub.py b/torch_concepts/data/datasets/cub.py index 1b6636d3..3574a0bb 100644 --- a/torch_concepts/data/datasets/cub.py +++ b/torch_concepts/data/datasets/cub.py @@ -1,381 +1,15 @@ -""" -CUB-200 Dataset Loader -** THIS DATASET NEEDS TO BE DOWNLOADED BEFORE BEING ABLE TO USE THE LOADER ** - -################################################################################ -## DOWNLOAD INSTRUCTIONS -################################################################################ - -**** OPTION #1 ***** -The simplest way to get the CUB dataset, is to download the pre-processed CUB -dataset by Koh et al. [CBM Paper]. This can be downloaded from their -public colab notebook at: https://worksheets.codalab.org/worksheets/0x362911581fcd4e048ddfd84f47203fd2. -You will need to download the original CUB dataset from that notebook (found -here: https://worksheets.codalab.org/bundles/0xd013a7ba2e88481bbc07e787f73109f5) -and the preprocessed "CUB_preprocessed" dataset (which can be directly accessed -here: https://worksheets.codalab.org/bundles/0x5b9d528d2101418b87212db92fea6683) - -**** OPTION #2 ***** -Follow the download the preprocess instructions found in Koh et al.'s original -repository here: https://github.com/yewsiang/ConceptBottleneck. -Specifically, here: https://github.com/yewsiang/ConceptBottleneck/blob/master/CUB/ - -################################################################################ - -[IMPORTANT] After downloading the files, they need to follow the following -structure: - - -This loader has been adapted/inspired by that of found in Koh et al.'s -repository (https://github.com/yewsiang/ConceptBottleneck/blob/master/CUB/cub_loader.py) -as well as in Espinosa Zarlenga and Barbiero's et al.'s repository -(https://github.com/mateoespinosa/cem). -""" - - -import numpy as np import os -import pickle +import tarfile import torch -import torchvision.transforms as transforms - -from collections import defaultdict +import pandas as pd +import numpy as np +from typing import List, Optional from PIL import Image -from torch.utils.data import Dataset - -######################################################## -## GENERAL DATASET GLOBAL VARIABLES -######################################################## - -N_CLASSES = 200 - -# CAN BE OVERWRITTEN WITH AN ENV VARIABLE CUB_DIR -CUB_DIR = os.environ.get("CUB_DIR", './CUB200/') - - -######################################################### -## CONCEPT INFORMATION REGARDING CUB -######################################################### - -# CUB Class names - -CLASS_NAMES = [ - "Black_footed_Albatross", - "Laysan_Albatross", - "Sooty_Albatross", - "Groove_billed_Ani", - "Crested_Auklet", - "Least_Auklet", - "Parakeet_Auklet", - "Rhinoceros_Auklet", - "Brewer_Blackbird", - "Red_winged_Blackbird", - "Rusty_Blackbird", - "Yellow_headed_Blackbird", - "Bobolink", - "Indigo_Bunting", - "Lazuli_Bunting", - "Painted_Bunting", - "Cardinal", - "Spotted_Catbird", - "Gray_Catbird", - "Yellow_breasted_Chat", - "Eastern_Towhee", - "Chuck_will_Widow", - "Brandt_Cormorant", - "Red_faced_Cormorant", - "Pelagic_Cormorant", - "Bronzed_Cowbird", - "Shiny_Cowbird", - "Brown_Creeper", - "American_Crow", - "Fish_Crow", - "Black_billed_Cuckoo", - "Mangrove_Cuckoo", - "Yellow_billed_Cuckoo", - "Gray_crowned_Rosy_Finch", - "Purple_Finch", - "Northern_Flicker", - "Acadian_Flycatcher", - "Great_Crested_Flycatcher", - "Least_Flycatcher", - "Olive_sided_Flycatcher", - "Scissor_tailed_Flycatcher", - "Vermilion_Flycatcher", - "Yellow_bellied_Flycatcher", - "Frigatebird", - "Northern_Fulmar", - "Gadwall", - "American_Goldfinch", - "European_Goldfinch", - "Boat_tailed_Grackle", - "Eared_Grebe", - "Horned_Grebe", - "Pied_billed_Grebe", - "Western_Grebe", - "Blue_Grosbeak", - "Evening_Grosbeak", - "Pine_Grosbeak", - "Rose_breasted_Grosbeak", - "Pigeon_Guillemot", - "California_Gull", - "Glaucous_winged_Gull", - "Heermann_Gull", - "Herring_Gull", - "Ivory_Gull", - "Ring_billed_Gull", - "Slaty_backed_Gull", - "Western_Gull", - "Anna_Hummingbird", - "Ruby_throated_Hummingbird", - "Rufous_Hummingbird", - "Green_Violetear", - "Long_tailed_Jaeger", - "Pomarine_Jaeger", - "Blue_Jay", - "Florida_Jay", - "Green_Jay", - "Dark_eyed_Junco", - "Tropical_Kingbird", - "Gray_Kingbird", - "Belted_Kingfisher", - "Green_Kingfisher", - "Pied_Kingfisher", - "Ringed_Kingfisher", - "White_breasted_Kingfisher", - "Red_legged_Kittiwake", - "Horned_Lark", - "Pacific_Loon", - "Mallard", - "Western_Meadowlark", - "Hooded_Merganser", - "Red_breasted_Merganser", - "Mockingbird", - "Nighthawk", - "Clark_Nutcracker", - "White_breasted_Nuthatch", - "Baltimore_Oriole", - "Hooded_Oriole", - "Orchard_Oriole", - "Scott_Oriole", - "Ovenbird", - "Brown_Pelican", - "White_Pelican", - "Western_Wood_Pewee", - "Sayornis", - "American_Pipit", - "Whip_poor_Will", - "Horned_Puffin", - "Common_Raven", - "White_necked_Raven", - "American_Redstart", - "Geococcyx", - "Loggerhead_Shrike", - "Great_Grey_Shrike", - "Baird_Sparrow", - "Black_throated_Sparrow", - "Brewer_Sparrow", - "Chipping_Sparrow", - "Clay_colored_Sparrow", - "House_Sparrow", - "Field_Sparrow", - "Fox_Sparrow", - "Grasshopper_Sparrow", - "Harris_Sparrow", - "Henslow_Sparrow", - "Le_Conte_Sparrow", - "Lincoln_Sparrow", - "Nelson_Sharp_tailed_Sparrow", - "Savannah_Sparrow", - "Seaside_Sparrow", - "Song_Sparrow", - "Tree_Sparrow", - "Vesper_Sparrow", - "White_crowned_Sparrow", - "White_throated_Sparrow", - "Cape_Glossy_Starling", - "Bank_Swallow", - "Barn_Swallow", - "Cliff_Swallow", - "Tree_Swallow", - "Scarlet_Tanager", - "Summer_Tanager", - "Artic_Tern", - "Black_Tern", - "Caspian_Tern", - "Common_Tern", - "Elegant_Tern", - "Forsters_Tern", - "Least_Tern", - "Green_tailed_Towhee", - "Brown_Thrasher", - "Sage_Thrasher", - "Black_capped_Vireo", - "Blue_headed_Vireo", - "Philadelphia_Vireo", - "Red_eyed_Vireo", - "Warbling_Vireo", - "White_eyed_Vireo", - "Yellow_throated_Vireo", - "Bay_breasted_Warbler", - "Black_and_white_Warbler", - "Black_throated_Blue_Warbler", - "Blue_winged_Warbler", - "Canada_Warbler", - "Cape_May_Warbler", - "Cerulean_Warbler", - "Chestnut_sided_Warbler", - "Golden_winged_Warbler", - "Hooded_Warbler", - "Kentucky_Warbler", - "Magnolia_Warbler", - "Mourning_Warbler", - "Myrtle_Warbler", - "Nashville_Warbler", - "Orange_crowned_Warbler", - "Palm_Warbler", - "Pine_Warbler", - "Prairie_Warbler", - "Prothonotary_Warbler", - "Swainson_Warbler", - "Tennessee_Warbler", - "Wilson_Warbler", - "Worm_eating_Warbler", - "Yellow_Warbler", - "Northern_Waterthrush", - "Louisiana_Waterthrush", - "Bohemian_Waxwing", - "Cedar_Waxwing", - "American_Three_toed_Woodpecker", - "Pileated_Woodpecker", - "Red_bellied_Woodpecker", - "Red_cockaded_Woodpecker", - "Red_headed_Woodpecker", - "Downy_Woodpecker", - "Bewick_Wren", - "Cactus_Wren", - "Carolina_Wren", - "House_Wren", - "Marsh_Wren", - "Rock_Wren", - "Winter_Wren", - "Common_Yellowthroat", -] -# Set of CUB attributes selected by Koh et al. [CBM Paper] -SELECTED_CONCEPTS = [ - 1, - 4, - 6, - 7, - 10, - 14, - 15, - 20, - 21, - 23, - 25, - 29, - 30, - 35, - 36, - 38, - 40, - 44, - 45, - 50, - 51, - 53, - 54, - 56, - 57, - 59, - 63, - 64, - 69, - 70, - 72, - 75, - 80, - 84, - 90, - 91, - 93, - 99, - 101, - 106, - 110, - 111, - 116, - 117, - 119, - 125, - 126, - 131, - 132, - 134, - 145, - 149, - 151, - 152, - 153, - 157, - 158, - 163, - 164, - 168, - 172, - 178, - 179, - 181, - 183, - 187, - 188, - 193, - 194, - 196, - 198, - 202, - 203, - 208, - 209, - 211, - 212, - 213, - 218, - 220, - 221, - 225, - 235, - 236, - 238, - 239, - 240, - 242, - 243, - 244, - 249, - 253, - 254, - 259, - 260, - 262, - 268, - 274, - 277, - 283, - 289, - 292, - 293, - 294, - 298, - 299, - 304, - 305, - 308, - 309, - 310, - 311, -] +import torchvision.transforms as T +from torch_concepts import Annotations +from torch_concepts.annotations import AxisAnnotation +from torch_concepts.data.base import ConceptDataset +from torch_concepts.data.io import download_url # Names of all CUB attributes CONCEPT_SEMANTICS = [ @@ -693,196 +327,156 @@ "has_wing_pattern::multi-colored", ] -# Generate a mapping containing all concept groups in CUB generated -# using a simple prefix tree -CONCEPT_GROUP_MAP = defaultdict(list) -for i, concept_name in enumerate(list( - np.array(CONCEPT_SEMANTICS)[SELECTED_CONCEPTS] -)): - group = concept_name[:concept_name.find("::")] - CONCEPT_GROUP_MAP[group].append(i) - - - -# Definitions from CUB (certainties.txt) -# 1 not visible -# 2 guessing -# 3 probably -# 4 definitely -# Unc map represents a mapping from the discrete score to a "mental probability" -DEFAULT_UNC_MAP = [ - {0: 0.5, 1: 0.5, 2: 0.5, 3:0.75, 4:1.0}, - {0: 0.5, 1: 0.5, 2: 0.5, 3:0.75, 4:1.0}, -] - -########################################################## -## Helper Functions -########################################################## - - -def discrete_to_continuous_unc(unc_val, attr_label, unc_map): - ''' - Yield a continuous prob representing discrete conf val - Inspired by CBM data processing - - The selected probability should account for whether the concept is on or off - E.g., if a human is "probably" sure the concept is off - flip the prob in unc_map - ''' - unc_val = unc_val.item() - attr_label = attr_label.item() - return float(unc_map[int(attr_label)][unc_val]) - - -########################################################## -## Data Loaders -########################################################## +CUB_DIR = os.environ.get("CUB_DIR", './CUB200/') -class CUBDataset(Dataset): +class CUB(ConceptDataset): """ - TODO + The CUB dataset is a dataset of bird images with annotated attributes. + Each image is associated with a set of concept labels (attributes) and + task labels (bird species). + + Attributes: + concept_attr_names: The names of the concept labels (attributes). + task_attr_names: The names of the task labels (bird species). + root: The root directory where the dataset is stored. + split: The dataset split to use ('train' or 'test'). + uncertain_concept_labels: Whether to treat uncertain concept labels as + positive. + path_transform: A function to transform the image paths. """ + name = "cub" + n_concepts = 312 + n_tasks = 200 + + concept_attr_names: List[str] = [] + task_attr_names: List[str] = [] def __init__( self, - split='train', - uncertain_concept_labels=False, - root=CUB_DIR, - path_transform=None, - sample_transform=None, - concept_transform=None, - label_transform=None, - uncertainty_based_random_labels=False, - unc_map=DEFAULT_UNC_MAP, - selected_concepts=None, - training_augment=True, - ): - """ - TODO: Define different arguments - """ - if not (os.path.exists(root) and os.path.isdir(root)): - raise ValueError( - f'Provided CUB data directory "{root}" is not a valid or ' - f'an existing directory.' - ) - assert split in ['train', 'val', 'test'], ( - f"CUB split must be in ['train', 'val', 'test'] but got '{split}'" - ) - self.split = split - base_dir = os.path.join(root, 'class_attr_data_10') - self.pkl_file_path = os.path.join(base_dir, f'{split}.pkl') - self.name = 'CUB' - - self.data = [] - with open(self.pkl_file_path, 'rb') as f: - self.data.extend(pickle.load(f)) - image_size = 299 - if (split == 'train') and training_augment: - self.sample_transform = transforms.Compose([ - transforms.ColorJitter(brightness=32/255, saturation=(0.5, 1.5)), - transforms.RandomResizedCrop(image_size), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), #implicitly divides by 255 - transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [2, 2, 2]), - sample_transform or (lambda x: x), - ]) - else: - self.sample_transform = transforms.Compose([ - transforms.CenterCrop(image_size), - transforms.ToTensor(), #implicitly divides by 255 - transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [2, 2, 2]), - sample_transform or (lambda x: x), - ]) - self.concept_transform = concept_transform or (lambda x: x) - self.label_transform = label_transform or (lambda x: x) - self.uncertain_concept_labels = uncertain_concept_labels + name : str = "cub", + precision : int = 32, + input_data : np.ndarray | pd.DataFrame | torch.Tensor = None, + concepts : np.ndarray | pd.DataFrame | torch.Tensor = None, + annotations : Annotations | None = None, + graph : pd.DataFrame | None = None, + concept_names_subset : List[str] | None = None, + root : str = CUB_DIR, + image_transform: Optional[object] = None, + ) -> None: self.root = root - self.path_transform = path_transform - self.uncertainty_based_random_labels = uncertainty_based_random_labels - self.unc_map = unc_map - if selected_concepts is None: - selected_concepts = list(range(len(SELECTED_CONCEPTS))) - self.selected_concepts = selected_concepts - self.concept_names = self.concept_attr_names = list( - np.array( - CONCEPT_SEMANTICS - )[CONCEPT_SEMANTICS][selected_concepts] + self.image_transform = image_transform or T.ToTensor() + + input_data, concepts, annotations, graph, image_paths = self.load() + + super().__init__( + name=name, + precision=precision, + input_data=input_data, + concepts=concepts, + annotations=annotations, + graph=graph, + concept_names_subset=concept_names_subset, ) - self.task_names = self.task_attr_names = CLASS_NAMES - - def __len__(self): - return len(self.data) + self.image_paths = image_paths + + @property + def raw_filenames(self) -> List[str]: + """List of raw filenames that need to be present in the raw directory + for the dataset to be considered present.""" + return [ + "CUB_200_2011/images.txt", + "CUB_200_2011/image_class_labels.txt", + "CUB_200_2011/train_test_split.txt", + "CUB_200_2011/bounding_boxes.txt", + "CUB_200_2011/classes.txt", + "CUB_200_2011/attributes/image_attribute_labels.txt", + "CUB_200_2011/attributes/class_attribute_labels_continuous.txt", + "CUB_200_2011/attributes/certainties.txt", + ] + + @property + def processed_filenames(self) -> List[str]: + """List of processed filenames that will be created during build step.""" + return [ + "cub_inputs.pt", + "cub_concepts.pt", + "cub_annotations.pt", + "cub_graph.h5", + ] + + def download(self) -> None: + """Downloads the CUB dataset if it is not already present.""" + if not os.path.exists(self.root): + os.makedirs(self.root) + + url = "https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz?download=1" + tgz_path = download_url(url, self.root) + + with tarfile.open(tgz_path, "r:gz") as tar: + tar.extractall(path=self.root) + os.unlink(tgz_path) + + def build(self): + self.maybe_download() + + images = pd.read_csv(self.raw_paths[0], sep=r"\s+", header=None, names=['image_id', 'path']) + image_paths = images.set_index('image_id')['path'] + image_paths = image_paths.apply(lambda p: os.path.join(self.root, "CUB_200_2011", "images", p)) + + # attribute names: use canonical order from CONCEPT_SEMANTICS (matches attr_id 1..312) + concept_names = CONCEPT_SEMANTICS + + # image_attribute_labels.txt has 6 columns; we only need is_present (col 3) + attr_labels = pd.read_csv( + self.raw_paths[5], + header=None, + names=['image_id', 'attr_id', 'is_present', 'certainty', 'time_ms', 'extra'], + usecols=[0, 1, 2], + delim_whitespace=True, + engine="python", + ) + concepts_df = attr_labels.pivot(index='image_id', columns='attr_id', values='is_present').fillna(0) + concepts_df = concepts_df.loc[image_paths.index] + concepts_tensor = torch.tensor(concepts_df.values, dtype=torch.float32) + + concept_metadata = {name: {'type': 'discrete'} for name in concept_names} + cardinalities = tuple(1 for _ in concept_names) # binary concepts + annotations = Annotations({ + 1: AxisAnnotation(labels=concept_names, + cardinalities=cardinalities, + metadata=concept_metadata) + }) + + torch.save(list(image_paths.values), self.processed_paths[0]) + torch.save(concepts_tensor, self.processed_paths[1]) + torch.save(annotations, self.processed_paths[2]) + + def load_raw(self): + self.maybe_build() + # PyTorch 2.6 switches torch.load default to weights_only=True; set False to load metadata objects + image_paths = torch.load(self.processed_paths[0], weights_only=False) + concepts = torch.load(self.processed_paths[1], weights_only=False) + annotations = torch.load(self.processed_paths[2], weights_only=False) + return image_paths, concepts, annotations, None + + def load(self): + image_paths, concepts, annotations, graph = self.load_raw() + input_indices = torch.arange(len(image_paths), dtype=torch.long) + return input_indices, concepts, annotations, graph, image_paths def __getitem__(self, idx): - img_data = self.data[idx] - img_path = img_data['img_path'] - if self.path_transform is None: - # This is needed if the dataset is downloaded from the original - # CBM paper's repository/experiment code - img_path = img_path.replace( - '/juice/scr/scr102/scr/thaonguyen/CUB_supervision/datasets/', - self.root - ) - try: - img = Image.open(img_path).convert('RGB') - except: - img_path_split = img_path.split('/') - img_path = '/'.join( - img_path_split[:2] + [self.split] + img_path_split[2:] - ) - img = Image.open(img_path).convert('RGB') - else: - img = Image.open(self.path_transform(img_path)).convert('RGB') - - class_label = self.label_transform(img_data['class_label']) - img = self.sample_transform(img) - - if self.uncertain_concept_labels: - attr_label = img_data['uncertain_attribute_label'] - else: - attr_label = img_data['attribute_label'] - attr_label = self.concept_transform( - np.array(attr_label)[self.selected_concepts] - ) + img_path = self.image_paths[idx] + image = Image.open(img_path).convert("RGB") + if self.image_transform is not None: + image = self.image_transform(image) - # We may want to randomly sample concept labels based on their provided - # annotator uncertainty - if self.uncertainty_based_random_labels: - discrete_unc_label = np.array( - img_data['attribute_certainty'] - )[self.selected_concepts] - instance_attr_label = np.array(img_data['attribute_label']) - competencies = [] - for (discrete_unc_val, hard_concept_val) in zip( - discrete_unc_label, - instance_attr_label, - ): - competencies.append( - discrete_to_continuous_unc( - discrete_unc_val, - hard_concept_val, - self.unc_map, - ) - ) - attr_label = np.random.binomial(1, competencies) + concepts = self.concepts[idx] + sample = { + 'inputs': {'x': image}, + 'concepts': {'c': concepts}, + } + return sample - return img, torch.FloatTensor(attr_label), class_label - def concept_weights(self): - """ - Calculate class imbalance ratio for binary attribute labels - """ - imbalance_ratio = [] - with open(self.pkl_file_path, 'rb') as f: - data = pickle.load(f) - n = len(data) - n_attr = len(data[0]['attribute_label']) - n_ones = [0] * n_attr - total = [n] * n_attr - for d in data: - labels = d['attribute_label'] - for i in range(n_attr): - n_ones[i] += labels[i] - for j in range(len(n_ones)): - imbalance_ratio.append(total[j]/n_ones[j] - 1) - return np.array(imbalance_ratio)[self.selected_concepts] \ No newline at end of file +# test +cub = CUB() From 364356a53a2e7ab4f12c93949cd5bfacf6f08079 Mon Sep 17 00:00:00 2001 From: edogab33 Date: Thu, 27 Nov 2025 12:51:11 +0100 Subject: [PATCH 03/16] Fix backbone + add embeddings computation in cub --- torch_concepts/data/backbone.py | 19 +++-- torch_concepts/data/base/dataset.py | 2 +- torch_concepts/data/datasets/cub.py | 114 +++++++++++++++++----------- 3 files changed, 84 insertions(+), 51 deletions(-) diff --git a/torch_concepts/data/backbone.py b/torch_concepts/data/backbone.py index 86ec3c5f..d73bb528 100644 --- a/torch_concepts/data/backbone.py +++ b/torch_concepts/data/backbone.py @@ -12,6 +12,18 @@ logger = logging.getLogger(__name__) +def _collate_inputs(batch): + """Collate only the input images, ignoring other fields.""" + first = batch[0] + if isinstance(first, dict): + if 'inputs' in first and isinstance(first['inputs'], dict) and 'x' in first['inputs']: + xs = [b['inputs']['x'] for b in batch] + else: + raise KeyError("Batch items must contain 'inputs'['x'].") + else: + xs = batch + return torch.stack(xs, dim=0) + def compute_backbone_embs( dataset, backbone: nn.Module, @@ -64,6 +76,7 @@ def compute_backbone_embs( batch_size=batch_size, shuffle=False, # Important: maintain order num_workers=workers, + collate_fn=_collate_inputs, ) embeddings_list = [] @@ -73,11 +86,7 @@ def compute_backbone_embs( with torch.no_grad(): iterator = tqdm(dataloader, desc="Extracting embeddings") if verbose else dataloader for batch in iterator: - # Handle both {'x': tensor} and {'inputs': {'x': tensor}} structures - if 'inputs' in batch: - x = batch['inputs']['x'].to(device) - else: - x = batch['x'].to(device) + x = batch.to(device) # batch already collated to only inputs embeddings = backbone(x) # Forward pass through backbone embeddings_list.append(embeddings.cpu()) # Move back to CPU and store diff --git a/torch_concepts/data/base/dataset.py b/torch_concepts/data/base/dataset.py index d67b6f85..43f9ed82 100644 --- a/torch_concepts/data/base/dataset.py +++ b/torch_concepts/data/base/dataset.py @@ -42,7 +42,7 @@ class ConceptDataset(Dataset): Args: input_data: Input features as numpy array, pandas DataFrame, or Tensor. concepts: Concept annotations as numpy array, pandas DataFrame, or Tensor. - annotations: Optional Annotations object with concept metadata. + annotations: Optional Annotations object with concept metadata. (TODO: this can't be optional, since we need concept names in set_concepts(.)) graph: Optional concept graph as pandas DataFrame or tensor. concept_names_subset: Optional list to select subset of concepts. precision: Numerical precision (16, 32, or 64, default: 32). diff --git a/torch_concepts/data/datasets/cub.py b/torch_concepts/data/datasets/cub.py index 3574a0bb..f673981e 100644 --- a/torch_concepts/data/datasets/cub.py +++ b/torch_concepts/data/datasets/cub.py @@ -3,13 +3,15 @@ import torch import pandas as pd import numpy as np -from typing import List, Optional -from PIL import Image +from typing import List, Dict +from PIL import Image, ImageFile import torchvision.transforms as T from torch_concepts import Annotations from torch_concepts.annotations import AxisAnnotation from torch_concepts.data.base import ConceptDataset from torch_concepts.data.io import download_url +from torch_concepts.data.backbone import compute_backbone_embs +from torchvision.models import resnet18 # Names of all CUB attributes CONCEPT_SEMANTICS = [ @@ -328,16 +330,15 @@ ] CUB_DIR = os.environ.get("CUB_DIR", './CUB200/') +ImageFile.LOAD_TRUNCATED_IMAGES = True -class CUB(ConceptDataset): +class CUBDataset(ConceptDataset): """ The CUB dataset is a dataset of bird images with annotated attributes. Each image is associated with a set of concept labels (attributes) and task labels (bird species). Attributes: - concept_attr_names: The names of the concept labels (attributes). - task_attr_names: The names of the task labels (bird species). root: The root directory where the dataset is stored. split: The dataset split to use ('train' or 'test'). uncertain_concept_labels: Whether to treat uncertain concept labels as @@ -348,36 +349,29 @@ class CUB(ConceptDataset): n_concepts = 312 n_tasks = 200 - concept_attr_names: List[str] = [] - task_attr_names: List[str] = [] - def __init__( self, - name : str = "cub", precision : int = 32, - input_data : np.ndarray | pd.DataFrame | torch.Tensor = None, concepts : np.ndarray | pd.DataFrame | torch.Tensor = None, annotations : Annotations | None = None, - graph : pd.DataFrame | None = None, concept_names_subset : List[str] | None = None, root : str = CUB_DIR, - image_transform: Optional[object] = None, + image_transform: object | None = None, ) -> None: self.root = root - self.image_transform = image_transform or T.ToTensor() - - input_data, concepts, annotations, graph, image_paths = self.load() - + # ensure images have consistent size for batching + self.image_transform = image_transform or T.Compose([T.Resize((256, 256)), T.ToTensor()]) + + embeddings, concepts, annotations, graph = self.load() + super().__init__( - name=name, precision=precision, - input_data=input_data, + input_data=embeddings, concepts=concepts, annotations=annotations, graph=graph, concept_names_subset=concept_names_subset, ) - self.image_paths = image_paths @property def raw_filenames(self) -> List[str]: @@ -398,10 +392,9 @@ def raw_filenames(self) -> List[str]: def processed_filenames(self) -> List[str]: """List of processed filenames that will be created during build step.""" return [ - "cub_inputs.pt", "cub_concepts.pt", "cub_annotations.pt", - "cub_graph.h5", + "cub_embeddings.pt", ] def download(self) -> None: @@ -415,18 +408,23 @@ def download(self) -> None: with tarfile.open(tgz_path, "r:gz") as tar: tar.extractall(path=self.root) os.unlink(tgz_path) - + def build(self): self.maybe_download() + + # workaround to get self.n_samples() work in ConceptDataset. We will overwrite later in super().__init__() + # create a torch tensor with shape (n_samples, whatever) and set self.input_data to it temporarily + temp_input_data = torch.zeros((11788, 10)) # CUB has 11788 samples + self.input_data = temp_input_data - images = pd.read_csv(self.raw_paths[0], sep=r"\s+", header=None, names=['image_id', 'path']) - image_paths = images.set_index('image_id')['path'] - image_paths = image_paths.apply(lambda p: os.path.join(self.root, "CUB_200_2011", "images", p)) - - # attribute names: use canonical order from CONCEPT_SEMANTICS (matches attr_id 1..312) + images = pd.read_csv( + self.raw_paths[0], + sep=r"\s+", + header=None, + names=["image_id", "path"], + ) concept_names = CONCEPT_SEMANTICS - # image_attribute_labels.txt has 6 columns; we only need is_present (col 3) attr_labels = pd.read_csv( self.raw_paths[5], header=None, @@ -436,7 +434,7 @@ def build(self): engine="python", ) concepts_df = attr_labels.pivot(index='image_id', columns='attr_id', values='is_present').fillna(0) - concepts_df = concepts_df.loc[image_paths.index] + concepts_df = concepts_df.loc[images["image_id"]] concepts_tensor = torch.tensor(concepts_df.values, dtype=torch.float32) concept_metadata = {name: {'type': 'discrete'} for name in concept_names} @@ -447,30 +445,55 @@ def build(self): metadata=concept_metadata) }) - torch.save(list(image_paths.values), self.processed_paths[0]) - torch.save(concepts_tensor, self.processed_paths[1]) - torch.save(annotations, self.processed_paths[2]) + torch.save(concepts_tensor, self.processed_paths[0]) + torch.save(annotations, self.processed_paths[1]) + + annotations = torch.load(self.processed_paths[1], weights_only=False) + self._annotations = annotations + self.maybe_reduce_annotations(annotations, None) + concepts = torch.load(self.processed_paths[0], weights_only=False) + # temporary placeholder so set_concepts has a length reference + self.input_data = torch.zeros((concepts.shape[0], 1)) + self.precision = 32 # set precision before calling set_concepts + self.set_concepts(concepts) + + # Compute embeddings using a pretrained model (e.g., ResNet) as backbone from torch_concepts.data.backbone + backbone = torch.nn.Sequential(*list(resnet18(pretrained=True).children())[:-1]) + embeddings = compute_backbone_embs( + self, + backbone, + batch_size=64, + workers=4, + verbose=True + ) + + torch.save(embeddings, self.processed_paths[2]) def load_raw(self): self.maybe_build() - # PyTorch 2.6 switches torch.load default to weights_only=True; set False to load metadata objects - image_paths = torch.load(self.processed_paths[0], weights_only=False) - concepts = torch.load(self.processed_paths[1], weights_only=False) - annotations = torch.load(self.processed_paths[2], weights_only=False) - return image_paths, concepts, annotations, None + concepts = torch.load(self.processed_paths[0], weights_only=False) + annotations = torch.load(self.processed_paths[1], weights_only=False) + embeddings = torch.load(self.processed_paths[2], weights_only=False) + return embeddings, concepts, annotations, None def load(self): - image_paths, concepts, annotations, graph = self.load_raw() - input_indices = torch.arange(len(image_paths), dtype=torch.long) - return input_indices, concepts, annotations, graph, image_paths + embeddings, concepts, annotations, graph = self.load_raw() + return embeddings, concepts, annotations, graph - def __getitem__(self, idx): - img_path = self.image_paths[idx] + def __getitem__(self, idx) -> Dict[str, Dict[str, torch.Tensor]]: + img_rel_path = pd.read_csv( # TODO: optimize by reading this once in __init__ + self.raw_paths[0], + header=None, + names=['image_id', 'img_path'], + delim_whitespace=True, + engine="python", + ).set_index('image_id').loc[idx + 1, 'img_path'] # idx +1 because image_id starts from 1 + img_path = os.path.join(self.root, "CUB_200_2011/images", img_rel_path) image = Image.open(img_path).convert("RGB") if self.image_transform is not None: image = self.image_transform(image) - concepts = self.concepts[idx] + concepts = self.concepts[idx].clone() sample = { 'inputs': {'x': image}, 'concepts': {'c': concepts}, @@ -478,5 +501,6 @@ def __getitem__(self, idx): return sample -# test -cub = CUB() +if __name__ == "__main__": + dataset = CUBDataset() + print(f"Dataset loaded with {dataset.n_samples} samples.") From 1b6ed3522752ece0c4b7ecced1b174bd71d0b133 Mon Sep 17 00:00:00 2001 From: edogab33 Date: Thu, 27 Nov 2025 12:51:41 +0100 Subject: [PATCH 04/16] Remove test in cub --- torch_concepts/data/datasets/cub.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/torch_concepts/data/datasets/cub.py b/torch_concepts/data/datasets/cub.py index f673981e..465a981c 100644 --- a/torch_concepts/data/datasets/cub.py +++ b/torch_concepts/data/datasets/cub.py @@ -499,8 +499,3 @@ def __getitem__(self, idx) -> Dict[str, Dict[str, torch.Tensor]]: 'concepts': {'c': concepts}, } return sample - - -if __name__ == "__main__": - dataset = CUBDataset() - print(f"Dataset loaded with {dataset.n_samples} samples.") From 6db134554bd1b2b118648510442af3a0add65ca3 Mon Sep 17 00:00:00 2001 From: edogab33 Date: Thu, 27 Nov 2025 18:28:19 +0100 Subject: [PATCH 05/16] Add typing annotations --- torch_concepts/data/datasets/cub.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torch_concepts/data/datasets/cub.py b/torch_concepts/data/datasets/cub.py index 465a981c..eeeb5225 100644 --- a/torch_concepts/data/datasets/cub.py +++ b/torch_concepts/data/datasets/cub.py @@ -3,7 +3,7 @@ import torch import pandas as pd import numpy as np -from typing import List, Dict +from typing import List, Dict, Tuple from PIL import Image, ImageFile import torchvision.transforms as T from torch_concepts import Annotations @@ -409,7 +409,7 @@ def download(self) -> None: tar.extractall(path=self.root) os.unlink(tgz_path) - def build(self): + def build(self) -> None: self.maybe_download() # workaround to get self.n_samples() work in ConceptDataset. We will overwrite later in super().__init__() @@ -469,18 +469,18 @@ def build(self): torch.save(embeddings, self.processed_paths[2]) - def load_raw(self): + def load_raw(self) -> Tuple[torch.Tensor, pd.DataFrame, Annotations, None]: self.maybe_build() concepts = torch.load(self.processed_paths[0], weights_only=False) annotations = torch.load(self.processed_paths[1], weights_only=False) embeddings = torch.load(self.processed_paths[2], weights_only=False) return embeddings, concepts, annotations, None - def load(self): + def load(self) -> Tuple[torch.Tensor, pd.DataFrame, Annotations, None]: embeddings, concepts, annotations, graph = self.load_raw() return embeddings, concepts, annotations, graph - def __getitem__(self, idx) -> Dict[str, Dict[str, torch.Tensor]]: + def __getitem__(self, idx: int) -> Dict[str, Dict[str, torch.Tensor]]: img_rel_path = pd.read_csv( # TODO: optimize by reading this once in __init__ self.raw_paths[0], header=None, From 26091a1c081f2640d526855fb02446c018047fb1 Mon Sep 17 00:00:00 2001 From: edogab33 Date: Thu, 27 Nov 2025 18:32:44 +0100 Subject: [PATCH 06/16] Fix style --- torch_concepts/data/datasets/cub.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torch_concepts/data/datasets/cub.py b/torch_concepts/data/datasets/cub.py index eeeb5225..78105182 100644 --- a/torch_concepts/data/datasets/cub.py +++ b/torch_concepts/data/datasets/cub.py @@ -351,11 +351,11 @@ class CUBDataset(ConceptDataset): def __init__( self, - precision : int = 32, - concepts : np.ndarray | pd.DataFrame | torch.Tensor = None, - annotations : Annotations | None = None, - concept_names_subset : List[str] | None = None, - root : str = CUB_DIR, + precision: int = 32, + concepts: np.ndarray | pd.DataFrame | torch.Tensor = None, + annotations: Annotations | None = None, + concept_names_subset: List[str] | None = None, + root: str = CUB_DIR, image_transform: object | None = None, ) -> None: self.root = root From e53a66244e4cdaa7d6421d45e995bd0d2ed03081 Mon Sep 17 00:00:00 2001 From: edogab33 Date: Fri, 28 Nov 2025 15:43:34 +0100 Subject: [PATCH 07/16] Remove input_data from cub and superclass --- torch_concepts/data/base/dataset.py | 7 +------ torch_concepts/data/datasets/cub.py | 1 - 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/torch_concepts/data/base/dataset.py b/torch_concepts/data/base/dataset.py index 43f9ed82..5ed23da8 100644 --- a/torch_concepts/data/base/dataset.py +++ b/torch_concepts/data/base/dataset.py @@ -63,7 +63,7 @@ class ConceptDataset(Dataset): """ def __init__( self, - input_data: Union[np.ndarray, pd.DataFrame, Tensor], + input_data: Union[np.ndarray, pd.DataFrame, Tensor, None], concepts: Union[np.ndarray, pd.DataFrame, Tensor], annotations: Optional[Annotations] = None, graph: Optional[pd.DataFrame] = None, @@ -127,11 +127,6 @@ def __init__( self.maybe_reduce_annotations(annotations, concept_names_subset) - # Set dataset's input data X - # TODO: input is assumed to be a one of "np.ndarray, pd.DataFrame, Tensor" for now - # allow more complex data structures in the future with a custom parser - self.input_data: Tensor = parse_tensor(input_data, 'input', self.precision) - # Store concept data C self.concepts = None if concepts is not None: diff --git a/torch_concepts/data/datasets/cub.py b/torch_concepts/data/datasets/cub.py index 78105182..7cb37f7d 100644 --- a/torch_concepts/data/datasets/cub.py +++ b/torch_concepts/data/datasets/cub.py @@ -366,7 +366,6 @@ def __init__( super().__init__( precision=precision, - input_data=embeddings, concepts=concepts, annotations=annotations, graph=graph, From 4e3fda89c76a1a8304fca52f5fbcc0bb9c762377 Mon Sep 17 00:00:00 2001 From: edogab33 Date: Fri, 28 Nov 2025 19:10:09 +0100 Subject: [PATCH 08/16] Remove embedding computation in CUB's build(.) --- torch_concepts/data/datasets/cub.py | 47 +++++++---------------------- 1 file changed, 11 insertions(+), 36 deletions(-) diff --git a/torch_concepts/data/datasets/cub.py b/torch_concepts/data/datasets/cub.py index 7cb37f7d..2e60902f 100644 --- a/torch_concepts/data/datasets/cub.py +++ b/torch_concepts/data/datasets/cub.py @@ -408,22 +408,17 @@ def download(self) -> None: tar.extractall(path=self.root) os.unlink(tgz_path) - def build(self) -> None: + def build(self): self.maybe_download() - - # workaround to get self.n_samples() work in ConceptDataset. We will overwrite later in super().__init__() - # create a torch tensor with shape (n_samples, whatever) and set self.input_data to it temporarily - temp_input_data = torch.zeros((11788, 10)) # CUB has 11788 samples - self.input_data = temp_input_data - images = pd.read_csv( - self.raw_paths[0], - sep=r"\s+", - header=None, - names=["image_id", "path"], - ) + images = pd.read_csv(self.raw_paths[0], sep=r"\s+", header=None, names=['image_id', 'path']) + image_paths = images.set_index('image_id')['path'] + image_paths = image_paths.apply(lambda p: os.path.join(self.root, "CUB_200_2011", "images", p)) + + # attribute names: use canonical order from CONCEPT_SEMANTICS (matches attr_id 1..312) concept_names = CONCEPT_SEMANTICS + # image_attribute_labels.txt has 6 columns; we only need is_present (col 3) attr_labels = pd.read_csv( self.raw_paths[5], header=None, @@ -433,7 +428,7 @@ def build(self) -> None: engine="python", ) concepts_df = attr_labels.pivot(index='image_id', columns='attr_id', values='is_present').fillna(0) - concepts_df = concepts_df.loc[images["image_id"]] + concepts_df = concepts_df.loc[image_paths.index] concepts_tensor = torch.tensor(concepts_df.values, dtype=torch.float32) concept_metadata = {name: {'type': 'discrete'} for name in concept_names} @@ -444,30 +439,10 @@ def build(self) -> None: metadata=concept_metadata) }) - torch.save(concepts_tensor, self.processed_paths[0]) - torch.save(annotations, self.processed_paths[1]) - - annotations = torch.load(self.processed_paths[1], weights_only=False) - self._annotations = annotations - self.maybe_reduce_annotations(annotations, None) - concepts = torch.load(self.processed_paths[0], weights_only=False) - # temporary placeholder so set_concepts has a length reference - self.input_data = torch.zeros((concepts.shape[0], 1)) - self.precision = 32 # set precision before calling set_concepts - self.set_concepts(concepts) - - # Compute embeddings using a pretrained model (e.g., ResNet) as backbone from torch_concepts.data.backbone - backbone = torch.nn.Sequential(*list(resnet18(pretrained=True).children())[:-1]) - embeddings = compute_backbone_embs( - self, - backbone, - batch_size=64, - workers=4, - verbose=True - ) + torch.save(list(image_paths.values), self.processed_paths[0]) + torch.save(concepts_tensor, self.processed_paths[1]) + torch.save(annotations, self.processed_paths[2]) - torch.save(embeddings, self.processed_paths[2]) - def load_raw(self) -> Tuple[torch.Tensor, pd.DataFrame, Annotations, None]: self.maybe_build() concepts = torch.load(self.processed_paths[0], weights_only=False) From c9fb562132ca91f2b3668af28713a7134d63cc17 Mon Sep 17 00:00:00 2001 From: francescoTheSantis Date: Fri, 24 Apr 2026 13:30:11 +0000 Subject: [PATCH 09/16] cub implementation --- .coveragerc | 3 + conceptarium/conf/dataset/cub.yaml | 25 + conceptarium/conf/dataset/cub_incomplete.yaml | 48 ++ conceptarium/conf/sweep.yaml | 12 +- torch_concepts/data/__init__.py | 6 +- torch_concepts/data/datamodules/awa2.py | 4 + torch_concepts/data/datamodules/cub.py | 96 ++++ torch_concepts/data/datasets/awa2.py | 69 +-- torch_concepts/data/datasets/bnlearn.py | 4 +- torch_concepts/data/datasets/cub.py | 461 ++++++++++-------- torch_concepts/data/io.py | 62 ++- 11 files changed, 511 insertions(+), 279 deletions(-) create mode 100644 conceptarium/conf/dataset/cub.yaml create mode 100644 conceptarium/conf/dataset/cub_incomplete.yaml create mode 100644 torch_concepts/data/datamodules/cub.py diff --git a/.coveragerc b/.coveragerc index c67adeb7..fb911b64 100644 --- a/.coveragerc +++ b/.coveragerc @@ -12,11 +12,14 @@ omit = */torch_concepts/data/datasets/mnist_arithmetic.py */torch_concepts/data/datasets/pendulum.py */torch_concepts/data/datasets/awa2.py + */torch_concepts/data/datasets/cub.py + # Exluding torch_concepts/data/datamodules/dataset_file.py */torch_concepts/data/datamodules/dsprites_regression.py */torch_concepts/data/datamodules/mnist_arithmetic.py */torch_concepts/data/datamodules/pendulum.py */torch_concepts/data/datamodules/awa2.py + */torch_concepts/data/datamodules/cub.py [report] exclude_lines = diff --git a/conceptarium/conf/dataset/cub.yaml b/conceptarium/conf/dataset/cub.yaml new file mode 100644 index 00000000..613dd37d --- /dev/null +++ b/conceptarium/conf/dataset/cub.yaml @@ -0,0 +1,25 @@ +defaults: + - _commons + - _self_ + +_target_: torch_concepts.data.datamodules.cub.CUBDataModule + +name: cub + +# Image resize size +image_size: 224 + +# backbone handling and embedding precomputation +backbone: resnet50 +precompute_embs: true +force_recompute: false + +# Task label - bird species (200 classes) +default_task_names: [class] + +# splitter - CUB has official train/test +splitter: + _target_: torch_concepts.data.splitters.native.NativeSplitter + +# Concept descriptions (optional; leave null to use raw attribute names) +label_descriptions: null diff --git a/conceptarium/conf/dataset/cub_incomplete.yaml b/conceptarium/conf/dataset/cub_incomplete.yaml new file mode 100644 index 00000000..b90e6685 --- /dev/null +++ b/conceptarium/conf/dataset/cub_incomplete.yaml @@ -0,0 +1,48 @@ +defaults: + - _commons + - _self_ + +_target_: torch_concepts.data.datamodules.cub.CUBDataModule + +name: cub + +# Image resize size +image_size: 224 + +# backbone handling and embedding precomputation +backbone: resnet50 +precompute_embs: true +force_recompute: false + +# Task label - bird species (200 classes) +default_task_names: [class] + +# splitter - CUB has official train/test +splitter: + _target_: torch_concepts.data.splitters.native.NativeSplitter + +# We generated the CUB incomplete dataset following the same procedure as in +# Zarlenga et al. (2024) "Avoiding Leakage Poisoning: Concept Interventions Under Distribution Shifts" (https://arxiv.org/pdf/2504.17921v1). +# More precisely, selecting the concepts belonging to the following groups: +# [“has_bill_shape”, “has_head_pattern”, “has_breast_colour”, “has_bill_length”, “has_wing_shape”, “has_tail_pattern”, “has_bill_color”] +concept_subset: [ + 'has_bill_shape::dagger', + 'has_bill_shape::hooked_seabird', + 'has_bill_shape::all-purpose', + 'has_bill_shape::cone', + 'has_head_pattern::eyebrow', + 'has_head_pattern::plain', + 'has_bill_length::about_the_same_as_head', + 'has_bill_length::shorter_than_head', + 'has_wing_shape::rounded-wings', + 'has_wing_shape::pointed-wings', + 'has_tail_pattern::solid', + 'has_tail_pattern::striped', + 'has_tail_pattern::multi-colored', + 'has_bill_color::grey', + 'has_bill_color::black', + 'has_bill_color::buff' +] + +# Concept descriptions (optional; leave null to use raw attribute names) +label_descriptions: null diff --git a/conceptarium/conf/sweep.yaml b/conceptarium/conf/sweep.yaml index 0843f122..b939c5dc 100644 --- a/conceptarium/conf/sweep.yaml +++ b/conceptarium/conf/sweep.yaml @@ -9,12 +9,10 @@ hydra: # standard grid search params: seed: 42 - dataset: dag_asia, dag_sachs, dag_insurance - model: cbm, cem, c2bm + dataset: cub, cub_incomplete, awa2, awa2_incomplete + model: cbm, cem model.train_inference._target_: - torch_concepts.nn.DeterministicInference, - torch_concepts.nn.IndependentInference, - torch_concepts.nn.AncestralSamplingInference + torch_concepts.nn.DeterministicInference # --- inference params # +model.train_inference.detach: false, true # +model.train_inference.p: 0.4 @@ -24,7 +22,7 @@ hydra: # loss.task_weight: 1 dataset: - batch_size: 2048 + batch_size: 128 # concept_subset: ['Attractive', 'Bald', 'Big_Nose', 'Black_Hair'] # for celeba model: @@ -45,7 +43,7 @@ trainer: # logger: wandb # log_model: true # whether to save checkpoint on wandb save_top_k: 1. # whether to save checkpoint locally - max_epochs: 200 + max_epochs: 2 patience: 20 matmul_precision: medium diff --git a/torch_concepts/data/__init__.py b/torch_concepts/data/__init__.py index 51e799e4..ea8b82c9 100644 --- a/torch_concepts/data/__init__.py +++ b/torch_concepts/data/__init__.py @@ -32,6 +32,7 @@ from .datasets.mnist_arithmetic import MNISTArithmeticDataset from .datasets.dsprites_regression import DSpritesRegressionDataset from .datasets.awa2 import AWA2Dataset +from .datasets.cub import CUBDataset # Re-export datamodules for convenient access from .datamodules.bnlearn import BnLearnDataModule @@ -42,6 +43,7 @@ from .datamodules.mnist_arithmetic import MNISTArithmeticDataModule from .datamodules.dsprites_regression import DSpritesRegressionDataModule from .datamodules.awa2 import AWA2DataModule +from .datamodules.cub import CUBDataModule __all__ = [ # Submodules @@ -66,7 +68,8 @@ "MNISTArithmeticDataset", "DSpritesRegressionDataset", "AWA2Dataset", - + "CUBDataset", + # DataModules "BnLearnDataModule", "ToyDAGDataModule", @@ -76,4 +79,5 @@ "MNISTArithmeticDataModule", "DSpritesRegressionDataModule", "AWA2DataModule", + "CUBDataModule", ] diff --git a/torch_concepts/data/datamodules/awa2.py b/torch_concepts/data/datamodules/awa2.py index eea580fc..cb1482e3 100644 --- a/torch_concepts/data/datamodules/awa2.py +++ b/torch_concepts/data/datamodules/awa2.py @@ -21,6 +21,8 @@ class AWA2DataModule(ConceptDataModule): Default: ``None`` (auto-creates ``./data/AWA2``). seed : int, optional Random seed for train / val / test split. Default: 42. + image_size : int, optional + Side length (px) to resize images to. Default: 224. val_size : float, optional Fraction of samples for validation. Default: 0.1. test_size : float, optional @@ -67,6 +69,7 @@ def __init__( self, root: str = None, seed: int = 42, + image_size: int = 224, val_size: float = 0.1, test_size: float = 0.2, splitter: Splitter = RandomSplitter(), @@ -83,6 +86,7 @@ def __init__( root=root, concept_subset=concept_subset, label_descriptions=label_descriptions, + image_size=image_size, ) super().__init__( diff --git a/torch_concepts/data/datamodules/cub.py b/torch_concepts/data/datamodules/cub.py new file mode 100644 index 00000000..29d3c008 --- /dev/null +++ b/torch_concepts/data/datamodules/cub.py @@ -0,0 +1,96 @@ +from ..datasets.cub import CUBDataset + +from ..base.datamodule import ConceptDataModule +from ...typing import BackboneType +from ..base.splitter import Splitter +from ..splitters.native import NativeSplitter + + +class CUBDataModule(ConceptDataModule): + """DataModule for CUB-200-2011 (Caltech-UCSD Birds). + + Handles data loading, splitting, and batching for the CUB-200-2011 dataset + with support for concept-based learning. CUB-200-2011 provides official + train / val / test splits via the Koh et al. pre-processed pickle files, + so :class:`~torch_concepts.data.splitters.NativeSplitter` is used by + default. + + .. note:: + CUB-200-2011 must be **manually downloaded** before use. + See :class:`~torch_concepts.data.datasets.CUBDataset` for instructions. + + Parameters + ---------- + root : str, optional + Root directory containing ``class_attr_data_10/`` and + ``CUB_200_2011/``. Default: ``None`` (auto-creates ``./data/CUB200``). + image_size : int, optional + Side length (px) to resize images to. Default: 224. + splitter : Splitter, optional + Splitting strategy. Default: ``NativeSplitter()`` (uses the official + train / val / test splits from the pickle files). + batch_size : int, optional + Number of samples per batch. Default: 512. + backbone : BackboneType, optional + Backbone model for feature extraction (e.g. ``'resnet50'``). + Default: ``None``. + precompute_embs : bool, optional + Whether to precompute and cache backbone embeddings. Default: ``True``. + force_recompute : bool, optional + Recompute embeddings even if a cache exists. Default: ``False``. + concept_subset : list of str, optional + Subset of concept names to retain. Default: ``None`` (all 113). + label_descriptions : dict, optional + Mapping from concept name to human-readable description. + workers : int, optional + Number of data-loading worker processes. Default: 0. + + Examples + -------- + >>> from torch_concepts.data import CUBDataModule + >>> + >>> dm = CUBDataModule( + ... root="./data/CUB200", + ... backbone="resnet50", + ... precompute_embs=True, + ... batch_size=64, + ... ) + >>> dm.setup() + >>> train_loader = dm.train_dataloader() + + See Also + -------- + CUBDataset : The underlying dataset class. + ConceptDataModule : Parent class with common datamodule functionality. + """ + + def __init__( + self, + root: str = None, + image_size: int = 224, + splitter: Splitter = NativeSplitter(), + batch_size: int = 512, + backbone: BackboneType = None, + precompute_embs: bool = True, + force_recompute: bool = False, + concept_subset: list | None = None, + label_descriptions: dict | None = None, + workers: int = 0, + **kwargs, + ): + dataset = CUBDataset( + root=root, + image_size=image_size, + concept_subset=concept_subset, + label_descriptions=label_descriptions, + ) + + super().__init__( + dataset=dataset, + batch_size=batch_size, + backbone=backbone, + precompute_embs=precompute_embs, + force_recompute=force_recompute, + workers=workers, + splitter=splitter, + ) diff --git a/torch_concepts/data/datasets/awa2.py b/torch_concepts/data/datasets/awa2.py index 0581ce95..6b66bdda 100644 --- a/torch_concepts/data/datasets/awa2.py +++ b/torch_concepts/data/datasets/awa2.py @@ -18,12 +18,13 @@ import torchvision.transforms as transforms from PIL import Image from typing import List, Mapping, Optional -import urllib.request import zipfile from pathlib import Path from torch_concepts import Annotations, AxisAnnotation from torch_concepts.data.base import ConceptDataset +from torch_concepts.data.io import download_file, zip_is_valid + logger = logging.getLogger(__name__) @@ -233,6 +234,8 @@ class AWA2Dataset(ConceptDataset): root : str, optional Root directory where the dataset is stored. Defaults to ``./data/AWA2``. + image_size : int, optional + Side length (px) to resize images to. Default: 224. concept_subset : list of str, optional Subset of concept names to retain. ``None`` keeps all 86. label_descriptions : dict, optional @@ -242,15 +245,14 @@ class AWA2Dataset(ConceptDataset): def __init__( self, root: str = None, - seed: int = 42, - train_size: float = 0.6, - val_size: float = 0.2, + image_size: int = 224, concept_subset: Optional[list] = None, label_descriptions: Optional[Mapping] = None, ): if root is None: root = os.path.join(os.getcwd(), 'data', 'AWA2') self.root = root + self.image_size = image_size self.label_descriptions = label_descriptions filenames, concepts, annotations, graph = self.load() @@ -297,59 +299,6 @@ def processed_filenames(self) -> List[str]: 'annotations.pt', ] - @staticmethod - def _zip_is_valid(path: str) -> bool: - """Return True if *path* is a structurally valid zip with correct CRCs.""" - try: - with zipfile.ZipFile(path) as z: - bad = z.testzip() # returns first bad filename or None - return bad is None - except zipfile.BadZipFile: - return False - - @staticmethod - def _wget_available() -> bool: - import shutil as _shutil - return _shutil.which("wget") is not None - - def _download_file(self, url: str, dest: str) -> None: - """Download *url* to *dest*. - - Uses ``wget --continue`` when available (handles large files and - network interruptions much better than urllib). Falls back to a - pure-Python streaming download with ``Range`` resume support. - """ - if self._wget_available(): - import subprocess - print(f"\nDownloading {os.path.basename(dest)} via wget ...") - subprocess.run( - [ - "wget", - "--continue", # resume partial downloads - "--tries=10", # retry up to 10 times on error - "--retry-connrefused", - "--waitretry=5", # wait 5 s between retries - "--show-progress", - "-O", dest, - url, - ], - check=True, - ) - else: - downloaded = os.path.getsize(dest) if os.path.exists(dest) else 0 - req = urllib.request.Request(url, headers={"Range": f"bytes={downloaded}-"}) - with urllib.request.urlopen(req) as r: - total = downloaded + int(r.headers.get("Content-Length", 0)) - print(f"\nDownloading {os.path.basename(dest)} ({total / 1e9:.2f} GB)") - with open(dest, "ab") as f: - while chunk := r.read(1024 * 1024): - f.write(chunk) - downloaded += len(chunk) - pct = downloaded / total * 100 - bar = "█" * int(pct // 2) + "░" * (50 - int(pct // 2)) - print(f"\r [{bar}] {pct:.1f}% {downloaded/1e9:.2f}/{total/1e9:.2f} GB", end="") - print() - def download(self): """Download raw AwA2 data from official sources. @@ -364,9 +313,9 @@ def download(self): for url in URLS: dest = os.path.join(self.root, url.split("/")[-1]) for attempt in range(1, _MAX_RETRIES + 1): - self._download_file(url, dest) + download_file(url, dest) print(f" Verifying {os.path.basename(dest)} (attempt {attempt}/{_MAX_RETRIES}) ...") - if self._zip_is_valid(dest): + if zip_is_valid(dest): break print(f" CRC check failed — deleting corrupted file and retrying ...") os.remove(dest) @@ -498,7 +447,7 @@ def __getitem__(self, item: int) -> dict: img_path = self.input_data[item] x = Image.open(img_path) x = x.convert('RGB') # Ensure 3 channels - x = transforms.Resize((224, 224))(x) # Resize to 224x224 + x = transforms.Resize((self.image_size, self.image_size))(x) # Resize to 224x224 x = transforms.ToTensor()(x) # Convert to tensor and scale to [0, 1] c = self.concepts[item] return {'inputs': {'x': x}, 'concepts': {'c': c}} diff --git a/torch_concepts/data/datasets/bnlearn.py b/torch_concepts/data/datasets/bnlearn.py index fc1ab7a6..1fe3b245 100644 --- a/torch_concepts/data/datasets/bnlearn.py +++ b/torch_concepts/data/datasets/bnlearn.py @@ -14,7 +14,7 @@ from ..base import ConceptDataset from ..preprocessing.autoencoder import extract_embs_from_autoencoder -from ..io import download_url +from ..io import download_urllib BUILTIN_DAGS = ['asia', 'alarm', 'andes', 'sachs', 'water'] @@ -87,7 +87,7 @@ def download(self): pass else: url = f'https://www.bnlearn.com/bnrepository/{self.name}/{self.name}.bif.gz' - gz_path = download_url(url, self.root_dir) + gz_path = download_urllib(url, self.root_dir) bif_path = self.raw_paths[0] # Decompress .gz file diff --git a/torch_concepts/data/datasets/cub.py b/torch_concepts/data/datasets/cub.py index 1b6636d3..5bcd9ec9 100644 --- a/torch_concepts/data/datasets/cub.py +++ b/torch_concepts/data/datasets/cub.py @@ -1,47 +1,33 @@ """ -CUB-200 Dataset Loader -** THIS DATASET NEEDS TO BE DOWNLOADED BEFORE BEING ABLE TO USE THE LOADER ** +CUB-200-2011 (Caltech-UCSD Birds) Dataset -################################################################################ -## DOWNLOAD INSTRUCTIONS -################################################################################ - -**** OPTION #1 ***** -The simplest way to get the CUB dataset, is to download the pre-processed CUB -dataset by Koh et al. [CBM Paper]. This can be downloaded from their -public colab notebook at: https://worksheets.codalab.org/worksheets/0x362911581fcd4e048ddfd84f47203fd2. -You will need to download the original CUB dataset from that notebook (found -here: https://worksheets.codalab.org/bundles/0xd013a7ba2e88481bbc07e787f73109f5) -and the preprocessed "CUB_preprocessed" dataset (which can be directly accessed -here: https://worksheets.codalab.org/bundles/0x5b9d528d2101418b87212db92fea6683) - -**** OPTION #2 ***** -Follow the download the preprocess instructions found in Koh et al.'s original -repository here: https://github.com/yewsiang/ConceptBottleneck. -Specifically, here: https://github.com/yewsiang/ConceptBottleneck/blob/master/CUB/ - -################################################################################ - -[IMPORTANT] After downloading the files, they need to follow the following -structure: - - -This loader has been adapted/inspired by that of found in Koh et al.'s -repository (https://github.com/yewsiang/ConceptBottleneck/blob/master/CUB/cub_loader.py) -as well as in Espinosa Zarlenga and Barbiero's et al.'s repository -(https://github.com/mateoespinosa/cem). +Adapted from: + - Koh et al.'s paper Concept Bottleneck Models + - Espinosa Zarlenga and Barbiero et al.'s repository https://github.com/mateoespinosa/cem/blob/main/cem/data/CUB200/cub_loader.py. """ - -import numpy as np import os +import logging +from pathlib import Path +import tarfile +from anyio import Path import pickle +import numpy as np +import pandas as pd import torch import torchvision.transforms as transforms from collections import defaultdict from PIL import Image -from torch.utils.data import Dataset +from typing import List, Mapping, Optional +import zipfile +import shutil + +from torch_concepts import Annotations, AxisAnnotation +from torch_concepts.data.base import ConceptDataset +from torch_concepts.data.io import download_urllib + +logger = logging.getLogger(__name__) ######################################################## ## GENERAL DATASET GLOBAL VARIABLES @@ -49,13 +35,11 @@ N_CLASSES = 200 -# CAN BE OVERWRITTEN WITH AN ENV VARIABLE CUB_DIR -CUB_DIR = os.environ.get("CUB_DIR", './CUB200/') - - -######################################################### -## CONCEPT INFORMATION REGARDING CUB -######################################################### +URLS = [ + # NOTE: we retrieve the .pkl split files from the CEM repository since I cannot find the m in the CBM repo. + "https://raw.githubusercontent.com/mateoespinosa/cem/main/cem/data/CUB200/class_attr_data_10", + "https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz?download=1", +] # CUB Class names @@ -701,188 +685,255 @@ )): group = concept_name[:concept_name.find("::")] CONCEPT_GROUP_MAP[group].append(i) +CONCEPT_GROUP_MAP = dict(CONCEPT_GROUP_MAP) +# Ordered names for the 112 selected concepts (matches order in pkl files) +SELECTED_CONCEPT_NAMES: List[str] = [CONCEPT_SEMANTICS[i] for i in SELECTED_CONCEPTS] -# Definitions from CUB (certainties.txt) -# 1 not visible -# 2 guessing -# 3 probably -# 4 definitely -# Unc map represents a mapping from the discrete score to a "mental probability" -DEFAULT_UNC_MAP = [ - {0: 0.5, 1: 0.5, 2: 0.5, 3:0.75, 4:1.0}, - {0: 0.5, 1: 0.5, 2: 0.5, 3:0.75, 4:1.0}, -] - -########################################################## -## Helper Functions -########################################################## - +class CUBDataset(ConceptDataset): + """Dataset class for CUB-200-2011 (Caltech-UCSD Birds). -def discrete_to_continuous_unc(unc_val, attr_label, unc_map): - ''' - Yield a continuous prob representing discrete conf val - Inspired by CBM data processing + CUB-200-2011 contains 11,788 bird images across 200 species classes, + annotated with 112 binary semantic attributes selected by Koh et al. + [CBM Paper] from the full set of 312 CUB attributes. - The selected probability should account for whether the concept is on or off - E.g., if a human is "probably" sure the concept is off - flip the prob in unc_map - ''' - unc_val = unc_val.item() - attr_label = attr_label.item() - return float(unc_map[int(attr_label)][unc_val]) + Official train / val / test splits from the pre-processed pickle files + are preserved; use :class:`~torch_concepts.data.splitters.NativeSplitter` + in the corresponding datamodule. + The concept vector per sample contains: -########################################################## -## Data Loaders -########################################################## + - columns 0-111: 112 binary semantic attributes (cardinality 1 each) + - column 112: bird species index 0-199 (cardinality 200) -class CUBDataset(Dataset): - """ - TODO + Parameters + ---------- + root : str, optional + Root directory that contains ``class_attr_data_10/`` and + ``CUB_200_2011/``. Defaults to ``./data/CUB200``. + image_size : int, optional + Side length (px) images are resized to. Defaults to 224. + concept_subset : list of str, optional + Subset of concept names to retain. ``None`` keeps all 113. + label_descriptions : dict, optional + Mapping from concept name to human-readable description. """ def __init__( self, - split='train', - uncertain_concept_labels=False, - root=CUB_DIR, - path_transform=None, - sample_transform=None, - concept_transform=None, - label_transform=None, - uncertainty_based_random_labels=False, - unc_map=DEFAULT_UNC_MAP, - selected_concepts=None, - training_augment=True, + root: str = None, + image_size: int = 224, + concept_subset: Optional[list] = None, + label_descriptions: Optional[Mapping] = None, ): - """ - TODO: Define different arguments - """ - if not (os.path.exists(root) and os.path.isdir(root)): - raise ValueError( - f'Provided CUB data directory "{root}" is not a valid or ' - f'an existing directory.' - ) - assert split in ['train', 'val', 'test'], ( - f"CUB split must be in ['train', 'val', 'test'] but got '{split}'" - ) - self.split = split - base_dir = os.path.join(root, 'class_attr_data_10') - self.pkl_file_path = os.path.join(base_dir, f'{split}.pkl') - self.name = 'CUB' - - self.data = [] - with open(self.pkl_file_path, 'rb') as f: - self.data.extend(pickle.load(f)) - image_size = 299 - if (split == 'train') and training_augment: - self.sample_transform = transforms.Compose([ - transforms.ColorJitter(brightness=32/255, saturation=(0.5, 1.5)), - transforms.RandomResizedCrop(image_size), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), #implicitly divides by 255 - transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [2, 2, 2]), - sample_transform or (lambda x: x), - ]) - else: - self.sample_transform = transforms.Compose([ - transforms.CenterCrop(image_size), - transforms.ToTensor(), #implicitly divides by 255 - transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [2, 2, 2]), - sample_transform or (lambda x: x), - ]) - self.concept_transform = concept_transform or (lambda x: x) - self.label_transform = label_transform or (lambda x: x) - self.uncertain_concept_labels = uncertain_concept_labels + if root is None: + root = os.path.join(os.getcwd(), 'data', 'CUB200') self.root = root - self.path_transform = path_transform - self.uncertainty_based_random_labels = uncertainty_based_random_labels - self.unc_map = unc_map - if selected_concepts is None: - selected_concepts = list(range(len(SELECTED_CONCEPTS))) - self.selected_concepts = selected_concepts - self.concept_names = self.concept_attr_names = list( - np.array( - CONCEPT_SEMANTICS - )[CONCEPT_SEMANTICS][selected_concepts] + self.image_size = image_size + self.label_descriptions = label_descriptions + + filenames, concepts, annotations, graph = self.load() + + super().__init__( + input_data=filenames, + concepts=concepts, + annotations=annotations, + graph=graph, + concept_names_subset=concept_subset, + name='CUBDataset', ) - self.task_names = self.task_attr_names = CLASS_NAMES - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - img_data = self.data[idx] - img_path = img_data['img_path'] - if self.path_transform is None: - # This is needed if the dataset is downloaded from the original - # CBM paper's repository/experiment code - img_path = img_path.replace( - '/juice/scr/scr102/scr/thaonguyen/CUB_supervision/datasets/', - self.root + + # ------------------------------------------------------------------ + # ConceptDataset interface + # ------------------------------------------------------------------ + + @property + def raw_filenames(self) -> List[str]: + return [ + "attributes", + "images", + "parts", + "attributes.txt", + "bounding_boxes.txt", + "classes.txt", + "image_class_labels.txt", + "images.txt", + "train_test_split.txt", # split with left out classes (not used in our setting) + "class_attr_data_10/train.pkl", # splits with all classes, from Koh et al.'s pre-processing + "class_attr_data_10/val.pkl", + "class_attr_data_10/test.pkl", + ] + + @property + def processed_filenames(self) -> List[str]: + return [ + 'filenames.txt', + 'concepts.pt', + 'annotations.pt', + 'split_mapping.h5', + ] + + def download(self) -> None: + """Downloads the CUB dataset if it is not already present.""" + if not os.path.exists(self.root): + os.makedirs(self.root) + + # store the Koh et al. pre-processed pickle files in a subfolder "class_attr_data_10" + class_attr_dir = os.path.join(self.root, "class_attr_data_10") + if not os.path.exists(class_attr_dir): + os.makedirs(class_attr_dir) + for split_name in ('train', 'val', 'test'): + url = f"{URLS[0]}/{split_name}.pkl" + download_urllib(url, class_attr_dir) + + tgz_path = download_urllib(URLS[1], self.root) + + with tarfile.open(tgz_path, "r:gz") as tar: + tar.extractall(path=self.root) + os.unlink(tgz_path) + + # Move all the files outside of the nested "CUB_200_2011" folder to the root + extracted_folder = os.path.join(self.root, "CUB_200_2011") + for item in os.listdir(extracted_folder): + src = os.path.join(extracted_folder, item) + dst = os.path.join(self.root, item) + if os.path.exists(dst): + if os.path.isdir(dst): + shutil.rmtree(dst) + else: + os.remove(dst) + shutil.move(src, dst) + os.rmdir(extracted_folder) + + def _remap_image_path(self, img_path: str) -> str: + """Remap the absolute path stored in a pkl entry to the local root. + + The Koh et al. pkl files embed absolute paths from their cluster + (``/juice/scr/.../datasets/``). We extract the ``CUB_200_2011/`` + subtree and join it with the local root. + """ + marker = 'CUB_200_2011' + idx = img_path.find(marker) + if idx >= 0: + relative = img_path[idx:] # e.g. "CUB_200_2011/images/.../file.jpg" + # Eliminate CUB_200_2011 from path + relative = relative[len(marker) + 1:] # e.g. "images/.../file.jpg" + return os.path.abspath(os.path.join(self.root, relative)) + # Fallback: replace the known cluster prefix + return img_path.replace( + '/juice/scr/scr102/scr/thaonguyen/CUB_supervision/datasets/', + self.root + os.sep, + ) + + def build(self): + """Process raw CUB pickle files and save cached dataset artefacts.""" + self.maybe_download() + + logger.info(f"Building CUB dataset from {self.root} ...") + + all_paths: List[str] = [] + all_attrs: List[List[int]] = [] + all_classes: List[int] = [] + split_labels: List[str] = [] + + for split_name in ('train', 'val', 'test'): + pkl_path = os.path.join( + self.root, 'class_attr_data_10', f'{split_name}.pkl' ) - try: - img = Image.open(img_path).convert('RGB') - except: - img_path_split = img_path.split('/') - img_path = '/'.join( - img_path_split[:2] + [self.split] + img_path_split[2:] - ) - img = Image.open(img_path).convert('RGB') - else: - img = Image.open(self.path_transform(img_path)).convert('RGB') + with open(pkl_path, 'rb') as fh: + entries = pickle.load(fh) + + for entry in entries: + img_path = self._remap_image_path(entry['img_path']) + all_paths.append(img_path) + all_attrs.append(entry['attribute_label']) # 112-dim list + all_classes.append(int(entry['class_label'])) + split_labels.append(split_name) + + n = len(all_paths) + logger.info(f"Loaded {n} samples (train/val/test)") + + # Build concept tensor: 112 binary attrs + class index + attr_array = np.array(all_attrs, dtype=np.float32) # (n, 112) + class_array = np.array(all_classes, dtype=np.float32).reshape(-1, 1) # (n, 1) + all_concepts_np = np.concatenate([attr_array, class_array], axis=1) # (n, 113) + concepts_tensor = torch.tensor(all_concepts_np, dtype=torch.float32) + + # Build Annotations + concept_names = SELECTED_CONCEPT_NAMES + ['class'] + binary_states = [['0'] for _ in SELECTED_CONCEPT_NAMES] + states = binary_states + [CLASS_NAMES] + cardinalities = [1] * len(SELECTED_CONCEPT_NAMES) + [N_CLASSES] + concept_metadata = {name: {'type': 'discrete'} for name in concept_names} + + annotations = Annotations({ + 1: AxisAnnotation( + labels=concept_names, + states=states, + cardinalities=cardinalities, + metadata=concept_metadata, + ) + }) + + # Build split mapping (native train/val/test) + split_series = pd.Series(split_labels, name='split') + + # Save artefacts + os.makedirs(self.root, exist_ok=True) + logger.info(f"Saving CUB dataset artefacts to {self.root}") + + with open(self.processed_paths[0], 'w') as fh: + fh.write('\n'.join(all_paths)) + + torch.save(concepts_tensor, self.processed_paths[1]) + torch.save(annotations, self.processed_paths[2]) + split_series.to_hdf(self.processed_paths[3], key='split_mapping') + + logger.info(f"CUB dataset saved ({n} samples)") - class_label = self.label_transform(img_data['class_label']) - img = self.sample_transform(img) + def load_raw(self): + """Load processed artefacts from disk.""" + self.maybe_build() - if self.uncertain_concept_labels: - attr_label = img_data['uncertain_attribute_label'] + logger.info(f"Loading CUB dataset from {self.root}") + + with open(self.processed_paths[0], 'r') as fh: + filenames = fh.read().strip().split('\n') + + concepts = torch.load(self.processed_paths[1], weights_only=False) + annotations = torch.load(self.processed_paths[2], weights_only=False) + graph = None + + return filenames, concepts, annotations, graph + + def load(self): + return self.load_raw() + + def __getitem__(self, item: int) -> dict: + if self.embs_precomputed: + x = self.input_data[item] else: - attr_label = img_data['attribute_label'] - attr_label = self.concept_transform( - np.array(attr_label)[self.selected_concepts] - ) + img_path = self.input_data[item] + x = Image.open(img_path).convert('RGB') + x = transforms.Resize((self.image_size, self.image_size))(x) + x = transforms.ToTensor()(x) + c = self.concepts[item] + return {'inputs': {'x': x}, 'concepts': {'c': c}} + + # ------------------------------------------------------------------ + # Properties — override base class which assumes input_data is a Tensor + # ------------------------------------------------------------------ + + @property + def n_samples(self) -> int: + return len(self.input_data) + + @property + def n_features(self) -> tuple: + return tuple(self[0]['inputs']['x'].shape) + + @property + def shape(self) -> tuple: + return (self.n_samples, *self.n_features) - # We may want to randomly sample concept labels based on their provided - # annotator uncertainty - if self.uncertainty_based_random_labels: - discrete_unc_label = np.array( - img_data['attribute_certainty'] - )[self.selected_concepts] - instance_attr_label = np.array(img_data['attribute_label']) - competencies = [] - for (discrete_unc_val, hard_concept_val) in zip( - discrete_unc_label, - instance_attr_label, - ): - competencies.append( - discrete_to_continuous_unc( - discrete_unc_val, - hard_concept_val, - self.unc_map, - ) - ) - attr_label = np.random.binomial(1, competencies) - - return img, torch.FloatTensor(attr_label), class_label - - def concept_weights(self): - """ - Calculate class imbalance ratio for binary attribute labels - """ - imbalance_ratio = [] - with open(self.pkl_file_path, 'rb') as f: - data = pickle.load(f) - n = len(data) - n_attr = len(data[0]['attribute_label']) - n_ones = [0] * n_attr - total = [n] * n_attr - for d in data: - labels = d['attribute_label'] - for i in range(n_attr): - n_ones[i] += labels[i] - for j in range(len(n_ones)): - imbalance_ratio.append(total[j]/n_ones[j] - 1) - return np.array(imbalance_ratio)[self.selected_concepts] \ No newline at end of file diff --git a/torch_concepts/data/io.py b/torch_concepts/data/io.py index 16a875a5..63cc1cbe 100644 --- a/torch_concepts/data/io.py +++ b/torch_concepts/data/io.py @@ -103,10 +103,10 @@ def update_to(self, b=1, bsize=1, tsize=None): self.update(b * bsize - self.n) -def download_url(url: str, - folder: str, - filename: Optional[str] = None, - verbose: bool = True): +def download_urllib(url: str, + folder: str, + filename: Optional[str] = None, + verbose: bool = True): r"""Downloads the content of an URL to a specific folder. Args: @@ -137,3 +137,57 @@ def download_url(url: str, disable=not verbose) as t: urllib.request.urlretrieve(url, filename=path, reporthook=t.update_to) return path + + +def zip_is_valid(path: str) -> bool: + """Return True if *path* is a structurally valid zip with correct CRCs.""" + try: + with zipfile.ZipFile(path) as z: + bad = z.testzip() # returns first bad filename or None + return bad is None + except zipfile.BadZipFile: + return False + + +def wget_available() -> bool: + import shutil as _shutil + return _shutil.which("wget") is not None + + +def download_file(url: str, dest: str) -> None: + """Download *url* to *dest*. + + Uses ``wget --continue`` when available (handles large files and + network interruptions much better than urllib). Falls back to a + pure-Python streaming download with ``Range`` resume support. + """ + if wget_available(): + import subprocess + print(f"\nDownloading {os.path.basename(dest)} via wget ...") + subprocess.run( + [ + "wget", + "--continue", # resume partial downloads + "--tries=10", # retry up to 10 times on error + "--retry-connrefused", + "--waitretry=5", # wait 5 s between retries + "--show-progress", + "-O", dest, + url, + ], + check=True, + ) + else: + downloaded = os.path.getsize(dest) if os.path.exists(dest) else 0 + req = urllib.request.Request(url, headers={"Range": f"bytes={downloaded}-"}) + with urllib.request.urlopen(req) as r: + total = downloaded + int(r.headers.get("Content-Length", 0)) + print(f"\nDownloading {os.path.basename(dest)} ({total / 1e9:.2f} GB)") + with open(dest, "ab") as f: + while chunk := r.read(1024 * 1024): + f.write(chunk) + downloaded += len(chunk) + pct = downloaded / total * 100 + bar = "█" * int(pct // 2) + "░" * (50 - int(pct // 2)) + print(f"\r [{bar}] {pct:.1f}% {downloaded/1e9:.2f}/{total/1e9:.2f} GB", end="") + print() \ No newline at end of file From ac947b7ab696c049fb97af4c4ce4cbbca36c617c Mon Sep 17 00:00:00 2001 From: francescoTheSantis Date: Fri, 24 Apr 2026 14:46:34 +0000 Subject: [PATCH 10/16] minor fix sweep --- conceptarium/conf/sweep.yaml | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/conceptarium/conf/sweep.yaml b/conceptarium/conf/sweep.yaml index b939c5dc..0843f122 100644 --- a/conceptarium/conf/sweep.yaml +++ b/conceptarium/conf/sweep.yaml @@ -9,10 +9,12 @@ hydra: # standard grid search params: seed: 42 - dataset: cub, cub_incomplete, awa2, awa2_incomplete - model: cbm, cem + dataset: dag_asia, dag_sachs, dag_insurance + model: cbm, cem, c2bm model.train_inference._target_: - torch_concepts.nn.DeterministicInference + torch_concepts.nn.DeterministicInference, + torch_concepts.nn.IndependentInference, + torch_concepts.nn.AncestralSamplingInference # --- inference params # +model.train_inference.detach: false, true # +model.train_inference.p: 0.4 @@ -22,7 +24,7 @@ hydra: # loss.task_weight: 1 dataset: - batch_size: 128 + batch_size: 2048 # concept_subset: ['Attractive', 'Bald', 'Big_Nose', 'Black_Hair'] # for celeba model: @@ -43,7 +45,7 @@ trainer: # logger: wandb # log_model: true # whether to save checkpoint on wandb save_top_k: 1. # whether to save checkpoint locally - max_epochs: 2 + max_epochs: 200 patience: 20 matmul_precision: medium From b0c3765ea52825f35fd1c2cd6f75332425b50034 Mon Sep 17 00:00:00 2001 From: francescoTheSantis Date: Fri, 24 Apr 2026 15:03:13 +0000 Subject: [PATCH 11/16] change sweep --- conceptarium/conf/sweep.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conceptarium/conf/sweep.yaml b/conceptarium/conf/sweep.yaml index 0843f122..440a4c03 100644 --- a/conceptarium/conf/sweep.yaml +++ b/conceptarium/conf/sweep.yaml @@ -9,7 +9,7 @@ hydra: # standard grid search params: seed: 42 - dataset: dag_asia, dag_sachs, dag_insurance + dataset: dag_asia, dag_sachs model: cbm, cem, c2bm model.train_inference._target_: torch_concepts.nn.DeterministicInference, From ac2a689bf1fe27cf72f41fa3051b2e0926ca0f8e Mon Sep 17 00:00:00 2001 From: francescoTheSantis Date: Fri, 24 Apr 2026 15:12:20 +0000 Subject: [PATCH 12/16] fix env and sweep --- conceptarium/conf/sweep.yaml | 2 +- conceptarium/environment.yaml | 11 ++++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/conceptarium/conf/sweep.yaml b/conceptarium/conf/sweep.yaml index 440a4c03..0843f122 100644 --- a/conceptarium/conf/sweep.yaml +++ b/conceptarium/conf/sweep.yaml @@ -9,7 +9,7 @@ hydra: # standard grid search params: seed: 42 - dataset: dag_asia, dag_sachs + dataset: dag_asia, dag_sachs, dag_insurance model: cbm, cem, c2bm model.train_inference._target_: torch_concepts.nn.DeterministicInference, diff --git a/conceptarium/environment.yaml b/conceptarium/environment.yaml index 6f3c43ff..0a4434d0 100644 --- a/conceptarium/environment.yaml +++ b/conceptarium/environment.yaml @@ -1,14 +1,12 @@ name: conceptarium channels: - - pytorch - - nvidia - conda-forge - - defaults + - nodefaults dependencies: - python=3.12.* - - - pytorch:pytorch - - pytorch:pytorch-cuda + + # Conda-forge's metapackage for PyTorch with CUDA support + - pytorch-gpu - torchvision>=0.17.1 - torchmetrics>=0.7 @@ -21,7 +19,6 @@ dependencies: - tqdm - scikit-learn - scipy - - tqdm - openpyxl - pip From bc29a3819bf711f88ff53c7ef846e960ca235c21 Mon Sep 17 00:00:00 2001 From: francescoTheSantis Date: Fri, 24 Apr 2026 15:20:43 +0000 Subject: [PATCH 13/16] upgrade actions/checkout and setup-python for Node 24 Co-authored-by: Copilot --- .github/workflows/coverage.yml | 4 ++-- conceptarium/conf/sweep.yaml | 10 ++++------ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 2097b97b..3baca858 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -11,10 +11,10 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.10' diff --git a/conceptarium/conf/sweep.yaml b/conceptarium/conf/sweep.yaml index 0843f122..94d69655 100644 --- a/conceptarium/conf/sweep.yaml +++ b/conceptarium/conf/sweep.yaml @@ -9,12 +9,10 @@ hydra: # standard grid search params: seed: 42 - dataset: dag_asia, dag_sachs, dag_insurance - model: cbm, cem, c2bm + dataset: cub, cub_incomplete, awa2, awa2_incomplete + model: cbm, cem model.train_inference._target_: - torch_concepts.nn.DeterministicInference, - torch_concepts.nn.IndependentInference, - torch_concepts.nn.AncestralSamplingInference + torch_concepts.nn.DeterministicInference # --- inference params # +model.train_inference.detach: false, true # +model.train_inference.p: 0.4 @@ -45,7 +43,7 @@ trainer: # logger: wandb # log_model: true # whether to save checkpoint on wandb save_top_k: 1. # whether to save checkpoint locally - max_epochs: 200 + max_epochs: 20 patience: 20 matmul_precision: medium From 48bb834024adb8d42cd4457c62ee1651703b096c Mon Sep 17 00:00:00 2001 From: francescoTheSantis Date: Fri, 24 Apr 2026 16:37:03 +0000 Subject: [PATCH 14/16] rename download function + revert action upgrade + add task name into cub_incomplete concept subset Co-authored-by: Copilot --- .github/workflows/coverage.yml | 4 ++-- conceptarium/conf/dataset/cub_incomplete.yaml | 3 ++- torch_concepts/data/datasets/awa2.py | 4 ++-- torch_concepts/data/datasets/bnlearn.py | 4 ++-- torch_concepts/data/datasets/cub.py | 6 +++--- torch_concepts/data/io.py | 4 ++-- 6 files changed, 13 insertions(+), 12 deletions(-) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 3baca858..2097b97b 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -11,10 +11,10 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v3 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v4 with: python-version: '3.10' diff --git a/conceptarium/conf/dataset/cub_incomplete.yaml b/conceptarium/conf/dataset/cub_incomplete.yaml index b90e6685..540e3c2e 100644 --- a/conceptarium/conf/dataset/cub_incomplete.yaml +++ b/conceptarium/conf/dataset/cub_incomplete.yaml @@ -41,7 +41,8 @@ concept_subset: [ 'has_tail_pattern::multi-colored', 'has_bill_color::grey', 'has_bill_color::black', - 'has_bill_color::buff' + 'has_bill_color::buff', + 'class', # task label ] # Concept descriptions (optional; leave null to use raw attribute names) diff --git a/torch_concepts/data/datasets/awa2.py b/torch_concepts/data/datasets/awa2.py index 6b66bdda..bf137d9c 100644 --- a/torch_concepts/data/datasets/awa2.py +++ b/torch_concepts/data/datasets/awa2.py @@ -23,7 +23,7 @@ from torch_concepts import Annotations, AxisAnnotation from torch_concepts.data.base import ConceptDataset -from torch_concepts.data.io import download_file, zip_is_valid +from torch_concepts.data.io import download_url_wget, zip_is_valid logger = logging.getLogger(__name__) @@ -313,7 +313,7 @@ def download(self): for url in URLS: dest = os.path.join(self.root, url.split("/")[-1]) for attempt in range(1, _MAX_RETRIES + 1): - download_file(url, dest) + download_url_wget(url, dest) print(f" Verifying {os.path.basename(dest)} (attempt {attempt}/{_MAX_RETRIES}) ...") if zip_is_valid(dest): break diff --git a/torch_concepts/data/datasets/bnlearn.py b/torch_concepts/data/datasets/bnlearn.py index 1fe3b245..fc1ab7a6 100644 --- a/torch_concepts/data/datasets/bnlearn.py +++ b/torch_concepts/data/datasets/bnlearn.py @@ -14,7 +14,7 @@ from ..base import ConceptDataset from ..preprocessing.autoencoder import extract_embs_from_autoencoder -from ..io import download_urllib +from ..io import download_url BUILTIN_DAGS = ['asia', 'alarm', 'andes', 'sachs', 'water'] @@ -87,7 +87,7 @@ def download(self): pass else: url = f'https://www.bnlearn.com/bnrepository/{self.name}/{self.name}.bif.gz' - gz_path = download_urllib(url, self.root_dir) + gz_path = download_url(url, self.root_dir) bif_path = self.raw_paths[0] # Decompress .gz file diff --git a/torch_concepts/data/datasets/cub.py b/torch_concepts/data/datasets/cub.py index 5bcd9ec9..84c1322a 100644 --- a/torch_concepts/data/datasets/cub.py +++ b/torch_concepts/data/datasets/cub.py @@ -25,7 +25,7 @@ from torch_concepts import Annotations, AxisAnnotation from torch_concepts.data.base import ConceptDataset -from torch_concepts.data.io import download_urllib +from torch_concepts.data.io import download_url logger = logging.getLogger(__name__) @@ -785,9 +785,9 @@ def download(self) -> None: os.makedirs(class_attr_dir) for split_name in ('train', 'val', 'test'): url = f"{URLS[0]}/{split_name}.pkl" - download_urllib(url, class_attr_dir) + download_url(url, class_attr_dir) - tgz_path = download_urllib(URLS[1], self.root) + tgz_path = download_url(URLS[1], self.root) with tarfile.open(tgz_path, "r:gz") as tar: tar.extractall(path=self.root) diff --git a/torch_concepts/data/io.py b/torch_concepts/data/io.py index 63cc1cbe..57b58a07 100644 --- a/torch_concepts/data/io.py +++ b/torch_concepts/data/io.py @@ -103,7 +103,7 @@ def update_to(self, b=1, bsize=1, tsize=None): self.update(b * bsize - self.n) -def download_urllib(url: str, +def download_url(url: str, folder: str, filename: Optional[str] = None, verbose: bool = True): @@ -154,7 +154,7 @@ def wget_available() -> bool: return _shutil.which("wget") is not None -def download_file(url: str, dest: str) -> None: +def download_url_wget(url: str, dest: str) -> None: """Download *url* to *dest*. Uses ``wget --continue`` when available (handles large files and From a779ca63829934885f3f6eb952682e7edd0eb08b Mon Sep 17 00:00:00 2001 From: francescoTheSantis Date: Fri, 24 Apr 2026 16:37:53 +0000 Subject: [PATCH 15/16] return to default sweep --- conceptarium/conf/sweep.yaml | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/conceptarium/conf/sweep.yaml b/conceptarium/conf/sweep.yaml index 94d69655..440a4c03 100644 --- a/conceptarium/conf/sweep.yaml +++ b/conceptarium/conf/sweep.yaml @@ -9,10 +9,12 @@ hydra: # standard grid search params: seed: 42 - dataset: cub, cub_incomplete, awa2, awa2_incomplete - model: cbm, cem + dataset: dag_asia, dag_sachs + model: cbm, cem, c2bm model.train_inference._target_: - torch_concepts.nn.DeterministicInference + torch_concepts.nn.DeterministicInference, + torch_concepts.nn.IndependentInference, + torch_concepts.nn.AncestralSamplingInference # --- inference params # +model.train_inference.detach: false, true # +model.train_inference.p: 0.4 @@ -43,7 +45,7 @@ trainer: # logger: wandb # log_model: true # whether to save checkpoint on wandb save_top_k: 1. # whether to save checkpoint locally - max_epochs: 20 + max_epochs: 200 patience: 20 matmul_precision: medium From fc95fee612db3f78002d2dfdb9c7c3a33460a3e3 Mon Sep 17 00:00:00 2001 From: francescoTheSantis Date: Fri, 24 Apr 2026 16:49:25 +0000 Subject: [PATCH 16/16] add tests for .io functions --- tests/data/test_io.py | 87 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/tests/data/test_io.py b/tests/data/test_io.py index ddac3996..3a92d084 100644 --- a/tests/data/test_io.py +++ b/tests/data/test_io.py @@ -14,6 +14,10 @@ save_pickle, load_pickle, download_url, + download_url_wget, + zip_is_valid, + wget_available, + DownloadProgressBar, ) @@ -151,3 +155,86 @@ def test_download_custom_filename(self): # Verify assert os.path.exists(path) assert os.path.basename(path) == custom_name + + +class TestZipIsValid: + """Test zip file validation.""" + + def test_valid_zip(self): + """zip_is_valid returns True for a well-formed zip.""" + with tempfile.TemporaryDirectory() as tmpdir: + zip_path = os.path.join(tmpdir, "good.zip") + with zipfile.ZipFile(zip_path, 'w') as zf: + zf.writestr("hello.txt", "hello world") + assert zip_is_valid(zip_path) is True + + def test_invalid_zip_bad_file(self): + """zip_is_valid returns False for a file that is not a zip.""" + with tempfile.TemporaryDirectory() as tmpdir: + bad_path = os.path.join(tmpdir, "bad.zip") + with open(bad_path, 'wb') as f: + f.write(b"this is not a zip file at all") + assert zip_is_valid(bad_path) is False + + def test_invalid_zip_truncated(self): + """zip_is_valid returns False for a truncated/corrupt zip.""" + with tempfile.TemporaryDirectory() as tmpdir: + zip_path = os.path.join(tmpdir, "truncated.zip") + with zipfile.ZipFile(zip_path, 'w') as zf: + zf.writestr("data.txt", "some data") + # Corrupt it by truncating + with open(zip_path, 'r+b') as f: + f.truncate(10) + assert zip_is_valid(zip_path) is False + + +class TestWgetAvailable: + """Test wget availability detection.""" + + def test_returns_bool(self): + """wget_available always returns a bool.""" + result = wget_available() + assert isinstance(result, bool) + + +class TestDownloadUrlWget: + """Test download_url_wget.""" + + def test_download_creates_file(self): + """download_url_wget downloads a small file successfully.""" + with tempfile.TemporaryDirectory() as tmpdir: + url = "https://raw.githubusercontent.com/pytorch/pytorch/main/README.md" + dest = os.path.join(tmpdir, "README.md") + download_url_wget(url, dest) + assert os.path.exists(dest) + assert os.path.getsize(dest) > 0 + + def test_download_resume(self): + """download_url_wget does not overwrite a pre-existing file of the same name.""" + with tempfile.TemporaryDirectory() as tmpdir: + url = "https://raw.githubusercontent.com/pytorch/pytorch/main/README.md" + dest = os.path.join(tmpdir, "README.md") + # First download + download_url_wget(url, dest) + size_first = os.path.getsize(dest) + # Second download (resume / skip) + download_url_wget(url, dest) + size_second = os.path.getsize(dest) + assert size_second >= size_first + + +class TestDownloadProgressBar: + """Test DownloadProgressBar.update_to.""" + + def test_update_to_sets_total(self): + """update_to sets self.total when tsize is provided.""" + with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, + desc="test", disable=True) as bar: + bar.update_to(b=1, bsize=1, tsize=1024) + assert bar.total == 1024 + + def test_update_to_without_tsize(self): + """update_to works without tsize (no total set).""" + with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, + desc="test", disable=True) as bar: + bar.update_to(b=2, bsize=512) # should not raise