From 996767c636dc643262b04c79fcda29234c4ccbf0 Mon Sep 17 00:00:00 2001 From: Tuomas Date: Mon, 13 Apr 2026 16:05:19 -0700 Subject: [PATCH 1/2] added initial pseudo-label support --- .../loading-data/celeba_clip_with_backbone.py | 258 +++++++++++++++ torch_concepts/data/datasets/__init__.py | 3 + torch_concepts/data/datasets/celeba_clip.py | 305 ++++++++++++++++++ 3 files changed, 566 insertions(+) create mode 100644 examples/loading-data/celeba_clip_with_backbone.py create mode 100644 torch_concepts/data/datasets/celeba_clip.py diff --git a/examples/loading-data/celeba_clip_with_backbone.py b/examples/loading-data/celeba_clip_with_backbone.py new file mode 100644 index 00000000..e6eeadd7 --- /dev/null +++ b/examples/loading-data/celeba_clip_with_backbone.py @@ -0,0 +1,258 @@ +""" +CelebA Concept Bottleneck Model (Low-Level Interface) +====================================================== + +This example demonstrates how to: +1. Load the CelebA dataset using PyC's dataset utilities +2. Use a pretrained backbone (ResNet50) for feature extraction +3. Build a Concept Bottleneck Model using the low-level API +4. Train the model to predict facial attributes (concepts) and a target task + +Key Components: +- CelebADataset: PyC dataset wrapper for CelebA with concept annotations +- Backbone: Pretrained feature extractor (ResNet50, VGG, EfficientNet, DINOv2, etc.) +- LinearLatentToConcept: Maps latent embeddings to concept predictions +- LinearConceptToConcept: Maps concept predictions to task predictions + +Dataset: CelebA with 40 binary facial attributes +Task: Predict 'Attractive' attribute from other concept attributes +""" +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from sklearn.metrics import accuracy_score +from tqdm import tqdm + +from torch_concepts import seed_everything +from torch_concepts.data.datasets import CelebADataset, CelebACLIPDataset +from torch_concepts.data.backbone import Backbone +from torch_concepts.nn import LinearLatentToConcept, LinearConceptToConcept + + +def main(): + # ========================================================================= + # Configuration + # ========================================================================= + seed_everything(42) + + # Training hyperparameters + batch_size = 16 + n_epochs = 100 + learning_rate = 0.01 + concept_weight = 10 # Weight for concept loss + task_weight = 1 # Weight for task loss + + # Model configuration + backbone_name = 'resnet50' # Options: 'resnet18', 'resnet50', 'vgg16', 'efficientnet_b0', etc. + latent_dims = 256 # Dimension of latent space after backbone + + # Task configuration - which attribute to predict as the main task + task_attribute = 'Attractive' + + # Device + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + + # ========================================================================= + # Load CelebA Dataset + # ========================================================================= + print("\n1. Loading CelebA dataset...") + + # CelebADataset will try to automatically download the raw data if not present + # in the root directory. If this fails, please manually download the required files + # ["img_align_celeba.zip", "list_attr_celeba.txt", "list_eval_partition.txt"] + # and place them in the target root directory. 
+ # Note: CelebA is a large dataset (~1.4GB for images) + #dataset = CelebADataset(root='./data/celeba') + dataset = CelebACLIPDataset(root='./data/celeba') + + # Get annotations for concepts + annotations = dataset.annotations.get_axis_annotation(1) + + print(f" Dataset size: {len(dataset)} samples") + print(f" Number of concepts: {len(annotations.labels)}") + print(f" Task attribute: {task_attribute}") + + # Get concept and task indices from annotations + all_labels = dataset.annotations[1].labels + concept_indices = [all_labels.index(c) for c in all_labels if c != task_attribute] + task_index = all_labels.index(task_attribute) + + # ========================================================================= + # Initialize Backbone for Feature Extraction + # ========================================================================= + print(f"\n2. Loading backbone: {backbone_name}...") + + backbone = Backbone(name=backbone_name, device=device) + + # Freeze backbone parameters - we only train the CBM layers + for param in backbone.parameters(): + param.requires_grad = False + + # ========================================================================= + # Build Concept Bottleneck Model (Low-Level API) + # ========================================================================= + print("\n3. Building CBM architecture...") + + concept_dims = len(concept_indices) # all binary concepts + task_dims = 1 # Binary classification + + # Latent encoder: reduces backbone features to latent space + latent_encoder = nn.Sequential( + nn.Linear(backbone.out_features, latent_dims), + torch.nn.LeakyReLU(), + ) + + # Concept encoder: maps latent space to concept predictions + # Uses PyC's LinearLatentToConcept layer + concept_encoder = LinearLatentToConcept( + in_latent=latent_dims, + out_concepts=concept_dims + ) + + # Task predictor: maps concepts to task prediction + # Uses PyC's LinearConceptToConcept layer + task_predictor = LinearConceptToConcept( + in_concepts=concept_dims, + out_concepts=task_dims + ) + + # Combine into a ModuleDict for easy management + model = nn.ModuleDict({ + 'backbone': backbone, + 'latent_encoder': latent_encoder, + 'concept_encoder': concept_encoder, + 'task_predictor': task_predictor, + }).to(device) + + print(f" Latent dims: {latent_dims}") + print(f" Concept dims: {concept_dims}") + print(f" Task dims: {task_dims}") + + # ========================================================================= + # Create DataLoader + # ========================================================================= + print("\n4. Creating DataLoader...") + + # Use a smaller subset for this example to speed up training + max_samples = 100 + subset_indices = list(range(min(max_samples, len(dataset)))) + subset = torch.utils.data.Subset(dataset, subset_indices) + + dataloader = DataLoader( + subset, + batch_size=batch_size, + shuffle=True, + num_workers=0, # Set to 0 for debugging; increase for production + pin_memory=True if device.type == 'cuda' else False, + ) + + print(f" Subset size: {len(subset)} samples") + print(f" Batches per epoch: {len(dataloader)}") + + # ========================================================================= + # Training Loop + # ========================================================================= + print("\n5. 
Training CBM...") + + # Only optimize parameters that require gradients (excludes frozen backbone) + trainable_params = filter(lambda p: p.requires_grad, model.parameters()) + optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate) + loss_fn = nn.BCEWithLogitsLoss() + + model.train() + for epoch in range(n_epochs): + epoch_concept_loss = 0.0 + epoch_task_loss = 0.0 + all_concept_preds = [] + all_concept_targets = [] + all_task_preds = [] + all_task_targets = [] + + progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{n_epochs}") + for batch in progress_bar: + # Extract inputs and targets from batch + x = batch['inputs']['x'].to(device) # Images: (B, C, H, W) + c = batch['concepts']['c'].to(device) # All concepts: (B, n_concepts) + + # Separate concept targets and task target + c_targets = c[:, concept_indices].float() # Concept targets + y_targets = c[:, task_index:task_index+1].float() # Task target + + optimizer.zero_grad() + + # Forward pass through CBM + # 1. Backbone extracts visual features + features = model['backbone'](x) # (B, backbone_out_features) + + # 2. Latent encoder compresses features + latent = model['latent_encoder'](features) # (B, latent_dims) + + # 3. Concept encoder predicts concepts + c_pred = model['concept_encoder'](latent=latent) # (B, concept_dims) + + # 4. Task predictor predicts task from concepts + y_pred = model['task_predictor'](concepts=c_pred) # (B, task_dims) + + # Compute losses + concept_loss = loss_fn(c_pred, c_targets) + task_loss = loss_fn(y_pred, y_targets) + total_loss = concept_weight * concept_loss + task_weight * task_loss + + # Backward pass + total_loss.backward() + optimizer.step() + + # Track metrics + epoch_concept_loss += concept_loss.item() + epoch_task_loss += task_loss.item() + + all_concept_preds.append((c_pred.detach() > 0).cpu()) + all_concept_targets.append(c_targets.cpu()) + all_task_preds.append((y_pred.detach() > 0).cpu()) + all_task_targets.append(y_targets.cpu()) + + progress_bar.set_postfix({ + 'c_loss': f'{concept_loss.item():.3f}', + 't_loss': f'{task_loss.item():.3f}' + }) + + # Compute epoch metrics + all_concept_preds = torch.cat(all_concept_preds, dim=0) + all_concept_targets = torch.cat(all_concept_targets, dim=0) + all_task_preds = torch.cat(all_task_preds, dim=0) + all_task_targets = torch.cat(all_task_targets, dim=0) + + concept_acc = accuracy_score( + (all_concept_targets >= 0.5).float().numpy().flatten(), + all_concept_preds.numpy().flatten() + ) + task_acc = accuracy_score( + (all_task_targets >= 0.5).float().numpy().flatten(), + all_task_preds.numpy().flatten() + ) + + avg_concept_loss = epoch_concept_loss / len(dataloader) + avg_task_loss = epoch_task_loss / len(dataloader) + + print(f"\nEpoch {epoch+1} Summary:") + print(f" Concept Loss: {avg_concept_loss:.4f} | Concept Acc: {concept_acc:.4f}") + print(f" Task Loss: {avg_task_loss:.4f} | Task Acc: {task_acc:.4f}") + + # ========================================================================= + # Summary + # ========================================================================= + print("\n" + "="*60) + print("Training Complete!") + print("="*60) + print(f"\nThis example demonstrated:") + print(f" 1. Loading CelebA dataset") + print(f" 2. Using {backbone_name} backbone for feature extraction") + print(f" 3. Building a CBM with low-level PyC layers:") + print(f" - LinearLatentToConcept: {latent_dims} -> {concept_dims}") + print(f" - LinearConceptToConcept: {concept_dims} -> {task_dims}") + print(f" 4. 
Training to predict '{task_attribute}' from intermediate concepts") + + +if __name__ == "__main__": + main() diff --git a/torch_concepts/data/datasets/__init__.py b/torch_concepts/data/datasets/__init__.py index 9e23a7d3..50bad460 100644 --- a/torch_concepts/data/datasets/__init__.py +++ b/torch_concepts/data/datasets/__init__.py @@ -2,6 +2,7 @@ from .toy import ToyDataset, CompletenessDataset from .categorical_toy_dag import ToyDAGDataset from .celeba import CelebADataset +from .celeba_clip import CelebACLIPDataset, DEFAULT_CLIP_CONCEPT_PROMPTS __all__: list[str] = [ "BnLearnDataset", @@ -10,5 +11,7 @@ "ToyFunctionDAGDataset", "CompletenessDataset", "CelebADataset", + "CelebACLIPDataset", + "DEFAULT_CLIP_CONCEPT_PROMPTS", ] diff --git a/torch_concepts/data/datasets/celeba_clip.py b/torch_concepts/data/datasets/celeba_clip.py new file mode 100644 index 00000000..eb7e0dda --- /dev/null +++ b/torch_concepts/data/datasets/celeba_clip.py @@ -0,0 +1,305 @@ +import os +import re +import logging +import torch +import pandas as pd +import numpy as np +from typing import Dict, List, Optional, Union +from PIL import Image +from tqdm import tqdm + +try: + import open_clip +except ImportError as e: + raise ImportError( + "open_clip is required for CelebACLIPDataset. " + "Install it with: pip install open-clip-torch" + ) from e + +from torch_concepts import Annotations, AxisAnnotation +from .celeba import CelebADataset + +logger = logging.getLogger(__name__) + + +# Default prompts mirror the 40 CelebA attributes so the class works +# as a drop-in replacement with no configuration required. +DEFAULT_CLIP_CONCEPT_PROMPTS: Dict[str, str] = { + "5_o_Clock_Shadow": "a photo of a person with 5 o'clock shadow stubble", + "Arched_Eyebrows": "a photo of a person with arched eyebrows", + "Attractive": "a photo of an attractive person", + "Bags_Under_Eyes": "a photo of a person with bags under their eyes", + "Bald": "a photo of a bald person", + "Bangs": "a photo of a person with bangs", + "Big_Lips": "a photo of a person with big lips", + "Big_Nose": "a photo of a person with a big nose", + "Black_Hair": "a photo of a person with black hair", + "Blond_Hair": "a photo of a person with blond hair", + "Blurry": "a blurry photo of a person", + "Brown_Hair": "a photo of a person with brown hair", + "Bushy_Eyebrows": "a photo of a person with bushy eyebrows", + "Chubby": "a photo of a chubby person", + "Double_Chin": "a photo of a person with a double chin", + "Eyeglasses": "a photo of a person wearing eyeglasses", + "Goatee": "a photo of a person with a goatee", + "Gray_Hair": "a photo of a person with gray hair", + "Heavy_Makeup": "a photo of a person wearing heavy makeup", + "High_Cheekbones": "a photo of a person with high cheekbones", + "Male": "a photo of a man", + "Mouth_Slightly_Open": "a photo of a person with their mouth slightly open", + "Mustache": "a photo of a person with a mustache", + "Narrow_Eyes": "a photo of a person with narrow eyes", + "No_Beard": "a photo of a person with no beard", + "Oval_Face": "a photo of a person with an oval face", + "Pale_Skin": "a photo of a person with pale skin", + "Pointy_Nose": "a photo of a person with a pointy nose", + "Receding_Hairline": "a photo of a person with a receding hairline", + "Rosy_Cheeks": "a photo of a person with rosy cheeks", + "Sideburns": "a photo of a person with sideburns", + "Smiling": "a photo of a smiling person", + "Straight_Hair": "a photo of a person with straight hair", + "Wavy_Hair": "a photo of a person with wavy hair", + 
"Wearing_Earrings": "a photo of a person wearing earrings", + "Wearing_Hat": "a photo of a person wearing a hat", + "Wearing_Lipstick": "a photo of a person wearing lipstick", + "Wearing_Necklace": "a photo of a person wearing a necklace", + "Wearing_Necktie": "a photo of a person wearing a necktie", + "Young": "a photo of a young person", +} + + +class CelebACLIPDataset(CelebADataset): + """CelebA dataset with CLIP-generated concept pseudo-labels. + + Replaces the 40 hand-annotated CelebA binary attributes with binary + pseudo-labels derived by thresholding cosine similarities between image + embeddings and user-supplied text prompts, computed with an + ``open_clip`` model. + + The CelebA images are downloaded and cached exactly as in + :class:`CelebADataset`. The CLIP pseudo-labels are computed once and + stored next to the other processed files; subsequent instantiations load + them from disk without re-running inference. + + Args: + root: Root directory for the dataset. Defaults to + ``/data/celeba``. + concept_prompts: Concept vocabulary as either + + * a ``dict`` mapping concept name → text prompt, or + * a ``list`` of text prompts (concept names will be the prompt + strings themselves). + + Defaults to :data:`DEFAULT_CLIP_CONCEPT_PROMPTS`, which mirrors + all 40 CelebA attributes. + clip_model: ``open_clip`` model name. + Default: ``'ViT-SO400M-14-SigLIP2-384'`` (SigLIP2). + clip_pretrained: ``open_clip`` pretrained weights tag. + Default: ``'webli'``. + clip_device: Device used for CLIP inference. Defaults to CUDA when + available, otherwise CPU. + inference_batch_size: Number of images processed per CLIP forward + pass. Default: ``64``. + concept_subset: Optional list of concept names to retain after + pseudo-label generation. + label_descriptions: Optional dict mapping concept names to + human-readable descriptions (metadata only, not used in training). + + Example:: + + from torch_concepts.data.datasets import CelebACLIPDataset + + # Drop-in replacement — uses default SigLIP2 prompts for all 40 attrs + dataset = CelebACLIPDataset(root='./data/celeba') + + # Custom concept vocabulary + dataset = CelebACLIPDataset( + root='./data/celeba', + concept_prompts={ + 'smiling': 'a photo of a smiling person', + 'blonde': 'a photo of a person with blond hair', + 'glasses': 'a photo of a person wearing glasses', + }, + ) + """ + + def __init__( + self, + root: Optional[str] = None, + concept_prompts: Optional[Union[Dict[str, str], List[str]]] = None, + clip_model: str = 'ViT-B-16-SigLIP2', + clip_pretrained: str = 'webli', + clip_device: Optional[str] = None, + threshold: float = 0.0, + inference_batch_size: int = 64, + concept_subset: Optional[List[str]] = None, + label_descriptions: Optional[dict] = None, + ): + # Normalise concept_prompts to a dict before super().__init__ is + # called, because __init__ triggers load() → build() → CLIP inference. + if concept_prompts is None: + concept_prompts = DEFAULT_CLIP_CONCEPT_PROMPTS + elif isinstance(concept_prompts, list): + concept_prompts = {p: p for p in concept_prompts} + + # Store all CLIP-specific state on self *before* calling super so that + # overridden processed_filenames / build / load_raw can access them. 
+        self._concept_prompts: Dict[str, str] = concept_prompts
+        self._clip_model_name: str = clip_model
+        self._clip_pretrained: str = clip_pretrained
+        self._clip_device: str = clip_device or ('cuda' if torch.cuda.is_available() else 'cpu')
+        self._threshold: float = threshold
+        self._inference_batch_size: int = inference_batch_size
+
+        super().__init__(
+            root=root,
+            concept_subset=concept_subset,
+            label_descriptions=label_descriptions,
+        )
+
+    # ------------------------------------------------------------------
+    # Helpers
+    # ------------------------------------------------------------------
+
+    @property
+    def _model_key(self) -> str:
+        """Filesystem-safe identifier for the current model + pretrained tag."""
+        raw = f"{self._clip_model_name}_{self._clip_pretrained}"
+        return re.sub(r'[^a-zA-Z0-9_]', '_', raw)
+
+    # ------------------------------------------------------------------
+    # ConceptDataset interface overrides
+    # ------------------------------------------------------------------
+
+    @property
+    def processed_filenames(self) -> List[str]:
+        """Processed files: parent's four files plus two CLIP-specific ones."""
+        return [
+            "filenames.txt",                           # [0] shared
+            "concepts.h5",                             # [1] parent annotated concepts
+            "annotations.pt",                          # [2] parent annotations
+            "split_mapping.h5",                        # [3] split labels
+            f"clip_concepts_{self._model_key}.h5",     # [4] CLIP pseudo-labels
+            f"clip_annotations_{self._model_key}.pt",  # [5] CLIP annotations
+        ]
+
+    def build(self):
+        """Build processed files: parent images/splits then CLIP pseudo-labels."""
+        # Ensure raw CelebA files are downloaded, extracted, and the parent's
+        # four processed files (filenames, annotated concepts, annotations,
+        # splits) are written to disk.
+        super().build()
+
+        clip_concepts_path = self.processed_paths[4]
+        clip_annotations_path = self.processed_paths[5]
+
+        if os.path.exists(clip_concepts_path) and os.path.exists(clip_annotations_path):
+            logger.info("CLIP pseudo-labels already exist, skipping inference.")
+            return
+
+        self._compute_clip_pseudo_labels(clip_concepts_path, clip_annotations_path)
+
+    def load_raw(self):
+        """Load filenames and CLIP pseudo-labels from processed files."""
+        self.maybe_build()
+
+        logger.info(f"Loading CelebACLIPDataset from {self.root_dir}")
+
+        with open(self.processed_paths[0], 'r') as f:
+            filenames = f.read().strip().split('\n')
+
+        concepts = pd.read_hdf(self.processed_paths[4], "concepts")
+        annotations = torch.load(self.processed_paths[5], weights_only=False)
+
+        return filenames, concepts, annotations, None
+
+    # ------------------------------------------------------------------
+    # CLIP pseudo-label computation
+    # ------------------------------------------------------------------
+
+    def _compute_clip_pseudo_labels(
+        self,
+        concepts_out_path: str,
+        annotations_out_path: str,
+    ) -> None:
+        """Run CLIP inference over all images and save soft pseudo-labels.
+
+        Args:
+            concepts_out_path: Destination HDF5 file for the concept tensor.
+            annotations_out_path: Destination ``.pt`` file for the
+                :class:`Annotations` object.
+ """ + + device = torch.device(self._clip_device) + logger.info( + f"Loading CLIP model '{self._clip_model_name}' " + f"(pretrained='{self._clip_pretrained}') on {device} …" + ) + + model, _, preprocess = open_clip.create_model_and_transforms( + self._clip_model_name, + pretrained=self._clip_pretrained, + device=device, + ) + tokenizer = open_clip.get_tokenizer(self._clip_model_name) + model.eval() + + concept_names = list(self._concept_prompts.keys()) + prompts = list(self._concept_prompts.values()) + + # Encode text prompts once + logger.info(f"Encoding {len(prompts)} concept text prompts …") + with torch.no_grad(): + text_tokens = tokenizer(prompts).to(device) + text_features = model.encode_text(text_tokens) + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + + # Load filenames from parent's processed file + with open(self.processed_paths[0], 'r') as f: + filenames = f.read().strip().split('\n') + + n_samples = len(filenames) + n_concepts = len(concept_names) + pseudo_labels = torch.zeros(n_samples, n_concepts, dtype=torch.float32) + + img_dir = os.path.join(self.root, "raw", "img_align_celeba") + batch_size = self._inference_batch_size + + logger.info( + f"Running CLIP inference on {n_samples} images " + f"(batch_size={batch_size}, threshold={self._threshold}) …" + ) + + for start in tqdm(range(0, n_samples, batch_size), desc="CLIP pseudo-labels"): + end = min(start + batch_size, n_samples) + batch_imgs = [] + for fname in filenames[start:end]: + img = Image.open(os.path.join(img_dir, fname)).convert("RGB") + batch_imgs.append(preprocess(img)) + + batch_tensor = torch.stack(batch_imgs).to(device) + + with torch.no_grad(): + img_features = model.encode_image(batch_tensor) + img_features = img_features / img_features.norm(dim=-1, keepdim=True) + # Cosine similarity: (B, n_concepts) + logits = img_features @ text_features.T * model.logit_scale.exp() + model.logit_bias + probs = torch.sigmoid(logits) + + pseudo_labels[start:end] = probs.cpu() + + # Save as DataFrame so the parent's set_concepts() path (which expects + # a DataFrame with named columns) works without modification. + concepts_df = pd.DataFrame(pseudo_labels.numpy(), columns=concept_names) + concepts_df.to_hdf(concepts_out_path, key="concepts", mode="w") + + annotations = Annotations({ + 1: AxisAnnotation( + labels=concept_names, + cardinalities=tuple([1] * n_concepts), + metadata={name: {'type': 'discrete'} for name in concept_names}, + ) + }) + torch.save(annotations, annotations_out_path) + + logger.info(f"Saved CLIP pseudo-labels to {concepts_out_path}") From ca9d97c05b902eb01b3ac1690ff2ef906b512b67 Mon Sep 17 00:00:00 2001 From: Tuomas Date: Mon, 13 Apr 2026 19:10:34 -0700 Subject: [PATCH 2/2] updated example to use good sigmoid scaling --- .../loading-data/celeba_clip_with_backbone.py | 112 ++++++++++++------ torch_concepts/data/datasets/celeba_clip.py | 17 ++- 2 files changed, 90 insertions(+), 39 deletions(-) diff --git a/examples/loading-data/celeba_clip_with_backbone.py b/examples/loading-data/celeba_clip_with_backbone.py index e6eeadd7..3a936d2b 100644 --- a/examples/loading-data/celeba_clip_with_backbone.py +++ b/examples/loading-data/celeba_clip_with_backbone.py @@ -1,25 +1,28 @@ """ -CelebA Concept Bottleneck Model (Low-Level Interface) -====================================================== +CelebA Concept Bottleneck Model with CLIP Pseudo-Labels (Low-Level Interface) +============================================================================== This example demonstrates how to: 1. 
Load the CelebA dataset using PyC's dataset utilities -2. Use a pretrained backbone (ResNet50) for feature extraction -3. Build a Concept Bottleneck Model using the low-level API -4. Train the model to predict facial attributes (concepts) and a target task +2. Use CLIP pseudo-labels (SigLIP2) for concept supervision +3. Use ground-truth CelebA annotations for task supervision only +4. Use a pretrained backbone (ResNet50) for feature extraction +5. Build a Concept Bottleneck Model using the low-level API Key Components: -- CelebADataset: PyC dataset wrapper for CelebA with concept annotations +- CelebACLIPDataset: CLIP-derived pseudo-labels for concept supervision +- CelebADataset: Ground-truth annotations used only for the task label - Backbone: Pretrained feature extractor (ResNet50, VGG, EfficientNet, DINOv2, etc.) - LinearLatentToConcept: Maps latent embeddings to concept predictions - LinearConceptToConcept: Maps concept predictions to task predictions Dataset: CelebA with 40 binary facial attributes -Task: Predict 'Attractive' attribute from other concept attributes +Concept supervision: CLIP SigLIP2 pseudo-labels (no human annotations required) +Task supervision: Ground-truth 'Attractive' attribute from CelebA annotations """ import torch import torch.nn as nn -from torch.utils.data import DataLoader +from torch.utils.data import Dataset, DataLoader from sklearn.metrics import accuracy_score from tqdm import tqdm @@ -29,6 +32,32 @@ from torch_concepts.nn import LinearLatentToConcept, LinearConceptToConcept +class HybridConceptDataset(Dataset): + """Pairs a CLIP pseudo-label dataset with a ground-truth dataset. + + Returns images and CLIP concept pseudo-labels from ``clip_dataset``, + and appends the ground-truth concept tensor from ``gt_dataset`` under + the key ``'gt_concepts'``. The two datasets must be aligned + (same images in the same order). + """ + + def __init__(self, clip_dataset: CelebACLIPDataset, gt_dataset: CelebADataset): + self.clip_dataset = clip_dataset + self.gt_dataset = gt_dataset + + def __len__(self): + return len(self.clip_dataset) + + def __getitem__(self, idx): + clip_sample = self.clip_dataset[idx] + gt_sample = self.gt_dataset[idx] + return { + 'inputs': clip_sample['inputs'], # images from CLIP dataset + 'concepts': clip_sample['concepts'], # CLIP pseudo-labels + 'gt_concepts': gt_sample['concepts'], # ground-truth annotations + } + + def main(): # ========================================================================= # Configuration @@ -54,29 +83,34 @@ def main(): print(f"Using device: {device}") # ========================================================================= - # Load CelebA Dataset + # Load Datasets # ========================================================================= - print("\n1. Loading CelebA dataset...") - - # CelebADataset will try to automatically download the raw data if not present - # in the root directory. If this fails, please manually download the required files - # ["img_align_celeba.zip", "list_attr_celeba.txt", "list_eval_partition.txt"] - # and place them in the target root directory. - # Note: CelebA is a large dataset (~1.4GB for images) - #dataset = CelebADataset(root='./data/celeba') - dataset = CelebACLIPDataset(root='./data/celeba') + print("\n1. 
Loading datasets...") - # Get annotations for concepts - annotations = dataset.annotations.get_axis_annotation(1) - - print(f" Dataset size: {len(dataset)} samples") - print(f" Number of concepts: {len(annotations.labels)}") - print(f" Task attribute: {task_attribute}") - - # Get concept and task indices from annotations - all_labels = dataset.annotations[1].labels - concept_indices = [all_labels.index(c) for c in all_labels if c != task_attribute] - task_index = all_labels.index(task_attribute) + # CelebADataset / CelebACLIPDataset will try to automatically download the + # raw data if not present in the root directory. If this fails, manually + # place ["img_align_celeba.zip", "list_attr_celeba.txt", + # "list_eval_partition.txt"] in ./data/celeba/raw/. + # Note: CelebA is a large dataset (~1.4GB for images). + + # CLIP dataset: provides pseudo-labels used as concept supervision signal. + clip_dataset = CelebACLIPDataset(root='./data/celeba') + + # Ground-truth dataset: used only to read the task label (Attractive). + # Both datasets share the same root so images are not re-downloaded. + gt_dataset = CelebADataset(root='./data/celeba') + + # Concept indices come from the CLIP dataset's vocabulary. + clip_labels = clip_dataset.annotations[1].labels + concept_indices = [clip_labels.index(c) for c in clip_labels if c != task_attribute] + + # Task index is looked up in the ground-truth dataset's annotations. + gt_labels = gt_dataset.annotations[1].labels + task_index = gt_labels.index(task_attribute) + + print(f" Dataset size: {len(clip_dataset)} samples") + print(f" CLIP concept dims: {len(concept_indices)}") + print(f" Task attribute (GT): {task_attribute}") # ========================================================================= # Initialize Backbone for Feature Extraction @@ -136,8 +170,9 @@ def main(): # Use a smaller subset for this example to speed up training max_samples = 100 - subset_indices = list(range(min(max_samples, len(dataset)))) - subset = torch.utils.data.Subset(dataset, subset_indices) + hybrid_dataset = HybridConceptDataset(clip_dataset, gt_dataset) + subset_indices = list(range(min(max_samples, len(hybrid_dataset)))) + subset = torch.utils.data.Subset(hybrid_dataset, subset_indices) dataloader = DataLoader( subset, @@ -172,12 +207,14 @@ def main(): progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{n_epochs}") for batch in progress_bar: # Extract inputs and targets from batch - x = batch['inputs']['x'].to(device) # Images: (B, C, H, W) - c = batch['concepts']['c'].to(device) # All concepts: (B, n_concepts) - - # Separate concept targets and task target - c_targets = c[:, concept_indices].float() # Concept targets - y_targets = c[:, task_index:task_index+1].float() # Task target + x = batch['inputs']['x'].to(device) # Images: (B, C, H, W) + c_clip = batch['concepts']['c'].to(device) # CLIP pseudo-labels: (B, n_concepts) + c_gt = batch['gt_concepts']['c'].to(device) # GT annotations: (B, n_concepts) + + # Concept supervision: CLIP pseudo-labels (excludes task attribute) + c_targets = c_clip[:, concept_indices].float() + # Task supervision: ground-truth annotation only + y_targets = c_gt[:, task_index:task_index+1].float() optimizer.zero_grad() @@ -250,6 +287,7 @@ def main(): print(f" 2. Using {backbone_name} backbone for feature extraction") print(f" 3. 
Building a CBM with low-level PyC layers:")
     print(f"     - LinearLatentToConcept: {latent_dims} -> {concept_dims}")
+    print(f"     - SigLIP2 pseudo-labels supervise the concept layer")
     print(f"     - LinearConceptToConcept: {concept_dims} -> {task_dims}")
     print(f"  4. Training to predict '{task_attribute}' from intermediate concepts")
 
diff --git a/torch_concepts/data/datasets/celeba_clip.py b/torch_concepts/data/datasets/celeba_clip.py
index eb7e0dda..dc267f30 100644
--- a/torch_concepts/data/datasets/celeba_clip.py
+++ b/torch_concepts/data/datasets/celeba_clip.py
@@ -67,6 +67,13 @@
     "Young": "a photo of a young person",
 }
 
+SIGMOID_TEMP_PARAMS = {
+    "ViT-SO400M-14-SigLIP-384": {"t": 58, "b": -7.54},
+    "ViT-L-16-SigLIP-384": {"t": 60, "b": -5.4},
+    "ViT-L-14-336": {"t": 58, "b": -14.5},
+    "ViT-B-16-SigLIP2": {"t": 64, "b": -8.32},
+}
+
 
 class CelebACLIPDataset(CelebADataset):
     """CelebA dataset with CLIP-generated concept pseudo-labels.
 
@@ -244,6 +251,20 @@ def _compute_clip_pseudo_labels(
         tokenizer = open_clip.get_tokenizer(self._clip_model_name)
         model.eval()
 
+        # Use sigmoid temperature/bias parameters tuned on ImageNet classes and
+        # superclasses: the default SigLIP parameters tend to drive almost all
+        # concept probabilities to zero for prompts like these.
+        if self._clip_model_name in SIGMOID_TEMP_PARAMS:
+            sigmoid_t = SIGMOID_TEMP_PARAMS[self._clip_model_name]["t"]
+            sigmoid_b = SIGMOID_TEMP_PARAMS[self._clip_model_name]["b"]
+        else:
+            # Fall back to the model's own calibration; models trained without
+            # a sigmoid objective may have no logit_bias, so default it to zero.
+            sigmoid_t = model.logit_scale.exp()
+            sigmoid_b = getattr(model, "logit_bias", None)
+            if sigmoid_b is None:
+                sigmoid_b = 0.0
+
         concept_names = list(self._concept_prompts.keys())
         prompts = list(self._concept_prompts.values())
 
@@ -283,9 +304,8 @@ def _compute_clip_pseudo_labels(
             img_features = model.encode_image(batch_tensor)
             img_features = img_features / img_features.norm(dim=-1, keepdim=True)
             # Cosine similarity: (B, n_concepts)
-                logits = img_features @ text_features.T * model.logit_scale.exp() + model.logit_bias
+                logits = img_features @ text_features.T * sigmoid_t + sigmoid_b
             probs = torch.sigmoid(logits)
-
         pseudo_labels[start:end] = probs.cpu()
 
     # Save as DataFrame so the parent's set_concepts() path (which expects
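
A note on inspecting the output: after the first build() the calibrated
pseudo-labels are cached on disk, so their distribution can be checked before
any training run. Below is a minimal sketch; the processed/ subdirectory is an
assumption about CelebADataset's on-disk layout (the parent class, which is
not shown in this patch), while the file name follows the
clip_concepts_{model_key}.h5 pattern from processed_filenames with the default
model and pretrained tag:

    import pandas as pd

    # Model key for the defaults: "ViT-B-16-SigLIP2" + "webli", with
    # non-alphanumeric characters replaced by "_" (see _model_key).
    path = "./data/celeba/processed/clip_concepts_ViT_B_16_SigLIP2_webli.h5"
    probs = pd.read_hdf(path, "concepts")  # rows: images, cols: concepts

    # Per-concept calibration check: means pinned near 0 suggest the sigmoid
    # scaling is off, which is what SIGMOID_TEMP_PARAMS is meant to correct.
    print(probs.describe().T[["mean", "std"]])
    print((probs > 0.5).mean().sort_values())  # fraction of "active" images

If most concept means collapse toward zero, the CLIP variant in use probably
needs its own entry in SIGMOID_TEMP_PARAMS rather than relying on the
logit_scale / logit_bias fallback.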