From 61912c2b95977974f2cf9ca4b43c0f767b9d16de Mon Sep 17 00:00:00 2001
From: rganeshk
Date: Sat, 26 Apr 2025 23:22:05 +0000
Subject: [PATCH 1/5] implemented rope

---
 ldm/modules/encoders/modules.py | 46 ++++++++++++++++++-
 ldm/modules/rope_utils.py       | 20 ++++++++
 scripts/train_clip_rope.py      | 81 +++++++++++++++++++++++++++++++++
 3 files changed, 146 insertions(+), 1 deletion(-)
 create mode 100644 ldm/modules/rope_utils.py
 create mode 100644 scripts/train_clip_rope.py

diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index ededbe43e..8cc27f1df 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -1,10 +1,12 @@
 import torch
 import torch.nn as nn
 from functools import partial
-import clip
+import open_clip as clip
 from einops import rearrange, repeat
 from transformers import CLIPTokenizer, CLIPTextModel
 import kornia
+from ldm.modules.rope_utils import build_rope_cache, apply_rope
+

 from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
@@ -140,10 +142,17 @@ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_l
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained(version)
         self.transformer = CLIPTextModel.from_pretrained(version)
+        # === Inject RoPE into attention layers ===
+        for name, module in self.transformer.named_modules():
+            if isinstance(module, torch.nn.MultiheadAttention):  # note: HF CLIPTextModel uses CLIPAttention, so this check may not match any module
+                setattr(self.transformer, name, RoPEAttentionWrapper(module))
+                print(f"[RoPE] Wrapped attention module: {name}")
+
         self.device = device
         self.max_length = max_length
         self.freeze()
+

     def freeze(self):
         self.transformer = self.transformer.eval()
         for param in self.parameters():
@@ -227,6 +236,41 @@ def forward(self, x):
         # x is assumed to be in range [-1,1]
         return self.model.encode_image(self.preprocess(x))

+class RoPEAttentionWrapper(nn.Module):
+    def __init__(self, attn_layer):
+        super().__init__()
+        self.attn = attn_layer
+        self.rope_cache = None
+
+    def forward(self, x, *args, **kwargs):
+        B, S, C = x.shape  # batch, seq_len, channels
+        device = x.device
+        num_heads = self.attn.num_heads
+        head_dim = C // num_heads
+
+        # Linear projection to get QKV (fully qualified name; modules.py does not import F)
+        qkv = torch.nn.functional.linear(x, self.attn.in_proj_weight, self.attn.in_proj_bias)
+        qkv = qkv.view(B, S, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]
+
+        # Build rope cache if not existing
+        if self.rope_cache is None or self.rope_cache[0].shape[2] != S:
+            self.rope_cache = build_rope_cache(S, head_dim, device)
+
+        # Apply RoPE
+        q = apply_rope(q, self.rope_cache)
+        k = apply_rope(k, self.rope_cache)
+
+        # Attention calculation
+        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * (head_dim ** -0.5)
+        attn_weights = attn_weights.softmax(dim=-1)
+        attn_output = torch.matmul(attn_weights, v)
+
+        attn_output = attn_output.transpose(1, 2).reshape(B, S, C)
+        output = self.attn.out_proj(attn_output)
+
+        return output
+
 if __name__ == "__main__":
     from ldm.util import count_params

diff --git a/ldm/modules/rope_utils.py b/ldm/modules/rope_utils.py
new file mode 100644
index 000000000..d15263cf1
--- /dev/null
+++ b/ldm/modules/rope_utils.py
@@ -0,0 +1,20 @@
+# ldm/modules/rope_utils.py
+
+import torch
+
+def build_rope_cache(seq_len, head_dim, device):
+    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
+    t = torch.arange(seq_len, device=device).type_as(inv_freq)
+    freqs = torch.einsum('i,j->ij', t, inv_freq)  # (seq_len, head_dim/2)
+    # keep sin/cos at head_dim/2 so they broadcast against the even/odd halves split in apply_rope
+    sin_emb = freqs.sin()[None, None, :, :]  # (1, 1, seq_len, head_dim/2)
+    cos_emb = freqs.cos()[None, None, :, :]
+    return sin_emb, cos_emb
+
+def apply_rope(x, rope_cache):
+    sin_emb, cos_emb = rope_cache  # each (1, 1, seq_len, head_dim/2)
+    x1 = x[..., ::2]
+    x2 = x[..., 1::2]
+    x_out = torch.cat([x1 * cos_emb - x2 * sin_emb,
+                       x1 * sin_emb + x2 * cos_emb], dim=-1)
+    return x_out
diff --git a/scripts/train_clip_rope.py b/scripts/train_clip_rope.py
new file mode 100644
index 000000000..481ec2558
--- /dev/null
+++ b/scripts/train_clip_rope.py
@@ -0,0 +1,81 @@
+import sys
+import os
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from datasets import load_dataset
+
+from ldm.modules.encoders.modules import FrozenCLIPEmbedder
+
+# === Config ===
+device = "cuda" if torch.cuda.is_available() else "cpu"
+batch_size = 32
+epochs = 3
+lr = 1e-5
+max_length = 77
+save_path = "./clip_rope_finetuned.pth"
+
+# === Dataset ===
+class CocoCountingDataset(torch.utils.data.Dataset):
+    def __init__(self, split="train", tokenizer=None, max_length=77):
+        self.dataset = load_dataset("conceptual_captions", split=split)
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+        self.number_words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten']
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        caption = self.dataset[idx]['caption'].lower()
+
+        if not any(word in caption for word in self.number_words):
+            caption = "one object."  # fallback dummy caption
+
+        encoding = self.tokenizer(caption, truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt")
+        input_ids = encoding["input_ids"].squeeze(0)
+        attention_mask = encoding["attention_mask"].squeeze(0)
+        return input_ids, attention_mask
+
+# === Model ===
+model = FrozenCLIPEmbedder(version="openai/clip-vit-large-patch14", device=device, max_length=max_length)
+
+# ❗ Unfreeze only transformer parameters
+for param in model.transformer.parameters():
+    param.requires_grad = True
+
+model = model.to(device)
+
+# ❗ Optimizer on transformer parameters
+optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.transformer.parameters()), lr=lr)
+
+# === Dataloader ===
+dataset = CocoCountingDataset(split="train", tokenizer=model.tokenizer, max_length=max_length)
+dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
+
+# === Training ===
+model.train()
+for epoch in range(epochs):
+    total_loss = 0
+    for input_ids, attention_mask in tqdm(dataloader):
+        input_ids = input_ids.to(device)
+        attention_mask = attention_mask.to(device)
+
+        outputs = model.transformer(input_ids=input_ids, attention_mask=attention_mask)
+        embeddings = outputs.last_hidden_state
+
+        # Simple L2 loss
+        loss = torch.mean(torch.norm(embeddings, dim=-1))
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        total_loss += loss.item()
+
+    print(f"Epoch {epoch+1}/{epochs}: Loss={total_loss/len(dataloader):.4f}")
+
+# === Save the fine-tuned transformer
+torch.save(model.transformer.state_dict(), save_path)
+print(f"Fine-tuned text encoder saved to {save_path}")

From 0afc84a1300bf6eb15c469867c23fc1cff1c95ef Mon Sep 17 00:00:00 2001
From: rganeshk
Date: Sat, 26 Apr 2025 23:41:29 +0000
Subject: [PATCH 2/5] added finetune script

---
 scripts/finetune_encoder.py | 109 ++++++++++++++++++++++++++++++++++++
 1 file
changed, 109 insertions(+) create mode 100644 scripts/finetune_encoder.py diff --git a/scripts/finetune_encoder.py b/scripts/finetune_encoder.py new file mode 100644 index 000000000..fe2fec130 --- /dev/null +++ b/scripts/finetune_encoder.py @@ -0,0 +1,109 @@ +import sys +import os +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm +from datasets import load_dataset +from sklearn.metrics import precision_recall_fscore_support +import torch.nn.functional as F + +from ldm.modules.encoders.modules import FrozenCLIPEmbedder + +# === Config === +device = "cuda" if torch.cuda.is_available() else "cpu" +batch_size = 32 +epochs = 3 +lr = 1e-5 +max_length = 77 +save_dir = "./checkpoints" +os.makedirs(save_dir, exist_ok=True) +save_every_n_steps = 1000 # Save every 1000 batches + +# === Dataset === +class CocoCountingDataset(torch.utils.data.Dataset): + def __init__(self, split="train", tokenizer=None, max_length=77): + self.dataset = load_dataset("conceptual_captions", split=split) + self.tokenizer = tokenizer + self.max_length = max_length + self.number_words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten'] + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + caption = self.dataset[idx]['caption'].lower() + label = int(any(word in caption for word in self.number_words)) # label 1 if counting word exists + + if label == 0: + caption = "one object." + + encoding = self.tokenizer(caption, truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt") + input_ids = encoding["input_ids"].squeeze(0) + attention_mask = encoding["attention_mask"].squeeze(0) + return input_ids, attention_mask, label + +# === Model === +model = FrozenCLIPEmbedder(version="openai/clip-vit-large-patch14", device=device, max_length=max_length) + +for param in model.transformer.parameters(): + param.requires_grad = True + +model = model.to(device) +optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.transformer.parameters()), lr=lr) + +# === Dataloader === +dataset = CocoCountingDataset(split="train", tokenizer=model.tokenizer, max_length=max_length) +dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4) + +# === Training === +model.train() +global_step = 0 +for epoch in range(epochs): + total_loss = 0 + preds, targets = [], [] + + for batch_idx, (input_ids, attention_mask, labels) in enumerate(tqdm(dataloader)): + input_ids = input_ids.to(device) + attention_mask = attention_mask.to(device) + labels = labels.to(device) + + outputs = model.transformer(input_ids=input_ids, attention_mask=attention_mask) + embeddings = outputs.last_hidden_state + + loss = torch.mean(torch.norm(embeddings, dim=-1)) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + total_loss += loss.item() + + # Mock "classification" for precision/recall: use embedding norm as pseudo-score + scores = torch.norm(embeddings[:, 0, :], dim=-1) # CLS token norm + pred_labels = (scores > scores.mean()).long() + + preds.extend(pred_labels.cpu().tolist()) + targets.extend(labels.cpu().tolist()) + + global_step += 1 + + # === Save checkpoint mid-epoch + if global_step % save_every_n_steps == 0: + checkpoint_path = os.path.join(save_dir, f"clip_rope_step{global_step}.pth") + torch.save(model.transformer.state_dict(), checkpoint_path) + print(f"[Checkpoint] Saved at step {global_step}") + + # === End of epoch logging 
=== + precision, recall, f1, _ = precision_recall_fscore_support(targets, preds, average='binary') + print(f"Epoch {epoch+1}/{epochs}: Loss={total_loss/len(dataloader):.4f}") + print(f"Precision: {precision:.4f} Recall: {recall:.4f} F1: {f1:.4f}") + + # Save after each epoch + checkpoint_path = os.path.join(save_dir, f"clip_rope_epoch{epoch+1}.pth") + torch.save(model.transformer.state_dict(), checkpoint_path) + print(f"[Checkpoint] Saved model after epoch {epoch+1}") + +# === Final Save === +torch.save(model.transformer.state_dict(), "./clip_rope_finetuned_final.pth") +print("[Final Save] Fine-tuned text encoder saved!") From ff6faef35b077cb59a109008093581b50e3ab542 Mon Sep 17 00:00:00 2001 From: rganeshk Date: Sat, 26 Apr 2025 23:46:47 +0000 Subject: [PATCH 3/5] Added finetuning script --- scripts/finetune.py | 109 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 scripts/finetune.py diff --git a/scripts/finetune.py b/scripts/finetune.py new file mode 100644 index 000000000..fe2fec130 --- /dev/null +++ b/scripts/finetune.py @@ -0,0 +1,109 @@ +import sys +import os +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm +from datasets import load_dataset +from sklearn.metrics import precision_recall_fscore_support +import torch.nn.functional as F + +from ldm.modules.encoders.modules import FrozenCLIPEmbedder + +# === Config === +device = "cuda" if torch.cuda.is_available() else "cpu" +batch_size = 32 +epochs = 3 +lr = 1e-5 +max_length = 77 +save_dir = "./checkpoints" +os.makedirs(save_dir, exist_ok=True) +save_every_n_steps = 1000 # Save every 1000 batches + +# === Dataset === +class CocoCountingDataset(torch.utils.data.Dataset): + def __init__(self, split="train", tokenizer=None, max_length=77): + self.dataset = load_dataset("conceptual_captions", split=split) + self.tokenizer = tokenizer + self.max_length = max_length + self.number_words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten'] + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + caption = self.dataset[idx]['caption'].lower() + label = int(any(word in caption for word in self.number_words)) # label 1 if counting word exists + + if label == 0: + caption = "one object." 
+ + encoding = self.tokenizer(caption, truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt") + input_ids = encoding["input_ids"].squeeze(0) + attention_mask = encoding["attention_mask"].squeeze(0) + return input_ids, attention_mask, label + +# === Model === +model = FrozenCLIPEmbedder(version="openai/clip-vit-large-patch14", device=device, max_length=max_length) + +for param in model.transformer.parameters(): + param.requires_grad = True + +model = model.to(device) +optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.transformer.parameters()), lr=lr) + +# === Dataloader === +dataset = CocoCountingDataset(split="train", tokenizer=model.tokenizer, max_length=max_length) +dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4) + +# === Training === +model.train() +global_step = 0 +for epoch in range(epochs): + total_loss = 0 + preds, targets = [], [] + + for batch_idx, (input_ids, attention_mask, labels) in enumerate(tqdm(dataloader)): + input_ids = input_ids.to(device) + attention_mask = attention_mask.to(device) + labels = labels.to(device) + + outputs = model.transformer(input_ids=input_ids, attention_mask=attention_mask) + embeddings = outputs.last_hidden_state + + loss = torch.mean(torch.norm(embeddings, dim=-1)) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + total_loss += loss.item() + + # Mock "classification" for precision/recall: use embedding norm as pseudo-score + scores = torch.norm(embeddings[:, 0, :], dim=-1) # CLS token norm + pred_labels = (scores > scores.mean()).long() + + preds.extend(pred_labels.cpu().tolist()) + targets.extend(labels.cpu().tolist()) + + global_step += 1 + + # === Save checkpoint mid-epoch + if global_step % save_every_n_steps == 0: + checkpoint_path = os.path.join(save_dir, f"clip_rope_step{global_step}.pth") + torch.save(model.transformer.state_dict(), checkpoint_path) + print(f"[Checkpoint] Saved at step {global_step}") + + # === End of epoch logging === + precision, recall, f1, _ = precision_recall_fscore_support(targets, preds, average='binary') + print(f"Epoch {epoch+1}/{epochs}: Loss={total_loss/len(dataloader):.4f}") + print(f"Precision: {precision:.4f} Recall: {recall:.4f} F1: {f1:.4f}") + + # Save after each epoch + checkpoint_path = os.path.join(save_dir, f"clip_rope_epoch{epoch+1}.pth") + torch.save(model.transformer.state_dict(), checkpoint_path) + print(f"[Checkpoint] Saved model after epoch {epoch+1}") + +# === Final Save === +torch.save(model.transformer.state_dict(), "./clip_rope_finetuned_final.pth") +print("[Final Save] Fine-tuned text encoder saved!") From 47e21b613cd0e41328df2e5e29fc6cde41f572ea Mon Sep 17 00:00:00 2001 From: Aiza Usman Date: Sat, 26 Apr 2025 21:29:06 -0400 Subject: [PATCH 4/5] added logic for bounded attention --- ldm/modules/bounded_attention.py | 114 ++++++++++++++++ scripts/eval_bounded_attention.py | 213 ++++++++++++++++++++++++++++++ scripts/txt2img.py | 161 ++++++++++++---------- 3 files changed, 421 insertions(+), 67 deletions(-) create mode 100644 ldm/modules/bounded_attention.py create mode 100644 scripts/eval_bounded_attention.py diff --git a/ldm/modules/bounded_attention.py b/ldm/modules/bounded_attention.py new file mode 100644 index 000000000..427507780 --- /dev/null +++ b/ldm/modules/bounded_attention.py @@ -0,0 +1,114 @@ +""" +Bounded Attention patch for Stable‑Diffusion v1 UNet +--------------------------------------------------- +This module injects a lightweight masking step into all 
+*cross‑attention* and *self‑attention* layers used during sampling so +that distinct groups of prompt tokens ("subjects") can only attend to +keys originating from their own group plus an optional background +bucket. + +Implementation follows "Be Yourself: Bounded Attention for Multi‑Subject T2I" (Dahary et al. 2024). +""" +from __future__ import annotations +import contextlib +from dataclasses import dataclass +from typing import List, Optional, Tuple +import torch +import torch.nn.functional as F + + +@dataclass +class SubjectMask: + """Half‑open token span [start, end) identifying one subject.""" + start: int + end: int + + def slice(self, length: int) -> slice: + # clip to sequence length to avoid IndexError + return slice(max(0, self.start), min(length, self.end)) + + +def _masked_softmax(attn: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor: + """Softmax with a `0/‑inf` mask on the *key* dimension (`dim`).""" + attn = attn.masked_fill_(~mask, float("-inf")) + return F.softmax(attn, dim=dim, dtype=torch.float32) + + +def _build_key_mask(seq_len: int, subjects: List[SubjectMask], device) -> torch.Tensor: + """Return a [Nsubj, 1, 1, seq_len] boolean tensor indicating where each subject may attend.""" + masks = [] + full = torch.zeros(seq_len, dtype=torch.bool, device=device) + for sm in subjects: + m = full.clone() + m[sm.slice(seq_len)] = True # allow attention within subject span + masks.append(m) + return torch.stack(masks)[:, None, None, :] # [N,1,1,L] + + +_patch_handle: Optional[Tuple] = None + +def enable_bounded_attention(model, subjects: List[SubjectMask]): + """Inject bounded‑attention into *all* CrossAttention blocks of ``model``. + + Args: + model: ``ldm.models.diffusion.ddpm.LatentDiffusion`` or similar UNet wrapper. + subjects: list of ``SubjectMask`` specifying *text* token spans for each subject. + """ + global _patch_handle + if _patch_handle is not None: + raise RuntimeError("Bounded attention already enabled; call disable_bounded_attention() first") + + # capture original forward fn to restore later + from ldm.modules.attention import CrossAttention + original_forward = CrossAttention.forward + + def forward_patched(self, x, context=None, mask=None): + out = original_forward(self, x, context, mask) + if context is None: + return out # self‑attention inside UNet; leave unchanged for now + + # context shape: [B, Lctx, C]; only first B entries refer to prompt tokens + B, Lctx, _ = context.shape + device = context.device + + # Build or cache key mask once per batch size / seq_len + if not hasattr(self, "_ba_mask") or self._ba_mask.shape[-1] != Lctx: + self._ba_mask = _build_key_mask(Lctx, subjects, device) # [N,1,1,Lctx] + + # queries: [B*H, Lq, Dh]; keys: [B*H, Lk, Dh]; attn_scores: [B*H, Lq, Lk] + q, k, v = self.to_qkv(x, context) + h = self.heads + q, k, v = map(lambda t: t.reshape(B, -1, h, self.head_dim).transpose(1, 2), (q, k, v)) + attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale + + # broadcast subject mask to match attn_scores; assume each subject occupies equal batch chunk + # For B==1 this is trivial; for larger batch user must supply matching spans per sample. 
+ subj_mask = self._ba_mask[0] # [1,1,Lk] + attn_probs = _masked_softmax(attn_scores, subj_mask, dim=-1) + out = torch.matmul(attn_probs, v) + out = out.transpose(1, 2).reshape(B, -1, h * self.head_dim) + return self.to_out(out) + + # apply patch + CrossAttention.forward = forward_patched + _patch_handle = (CrossAttention, original_forward) + + +def disable_bounded_attention(): + """Restore original CrossAttention implementation.""" + global _patch_handle + if _patch_handle is None: + return + cls, orig = _patch_handle + cls.forward = orig + _patch_handle = None + + +@contextlib.contextmanager +def bounded_attention(model, subjects: List[SubjectMask]): + """Context manager for temporarily enabling bounded attention.""" + enable_bounded_attention(model, subjects) + try: + yield + finally: + disable_bounded_attention() \ No newline at end of file diff --git a/scripts/eval_bounded_attention.py b/scripts/eval_bounded_attention.py new file mode 100644 index 000000000..d8e9e50fc --- /dev/null +++ b/scripts/eval_bounded_attention.py @@ -0,0 +1,213 @@ +import argparse +import os +import torch +import numpy as np +from PIL import Image +from tqdm import tqdm +from transformers import CLIPProcessor, CLIPModel +import torchvision.transforms as transforms +from torchvision.models.detection import fasterrcnn_resnet50_fpn +from torchvision.transforms.functional import to_tensor +import json + +def load_clip(): + """Load CLIP model for semantic similarity scoring.""" + model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") + return model, processor + +def load_object_detector(): + """Load FasterRCNN for object detection.""" + model = fasterrcnn_resnet50_fpn(pretrained=True) + model.eval() + return model + +def compute_clip_score(model, processor, image_path, text): + """Compute CLIP score between image and text.""" + image = Image.open(image_path) + inputs = processor(images=image, text=text, return_tensors="pt", padding=True) + outputs = model(**inputs) + return outputs.logits_per_image.item() + +def compute_clip_count_accuracy(detector, image_path, target_classes, expected_counts, confidence_threshold=0.7): + """Compute counting accuracy for specific object classes.""" + image = Image.open(image_path) + transform = transforms.Compose([transforms.ToTensor()]) + image_tensor = transform(image).unsqueeze(0) + + with torch.no_grad(): + predictions = detector(image_tensor) + + pred_classes = predictions[0]['labels'].numpy() + pred_scores = predictions[0]['scores'].numpy() + + # Filter predictions by confidence threshold + confident_mask = pred_scores > confidence_threshold + confident_preds = pred_classes[confident_mask] + + # Count occurrences of each target class + actual_counts = {} + for cls in target_classes: + actual_counts[str(cls)] = np.sum(confident_preds == cls) + + # Calculate counting accuracy + count_accuracies = {} + for cls_str, expected in expected_counts.items(): + actual = actual_counts.get(cls_str, 0) + accuracy = 1.0 - abs(expected - actual) / max(expected, actual) if max(expected, actual) > 0 else 1.0 + count_accuracies[cls_str] = { + 'expected': expected, + 'actual': actual, + 'accuracy': accuracy + } + + return count_accuracies + +def compute_object_metrics(detector, image_path, target_classes, confidence_threshold=0.7): + """Compute precision, recall, F1 for specific object classes.""" + image = Image.open(image_path) + transform = transforms.Compose([transforms.ToTensor()]) + image_tensor = 
transform(image).unsqueeze(0) + + with torch.no_grad(): + predictions = detector(image_tensor) + + pred_classes = predictions[0]['labels'].numpy() + pred_scores = predictions[0]['scores'].numpy() + + # Filter predictions by confidence threshold + confident_preds = pred_classes[pred_scores > confidence_threshold] + + # Count predictions for target classes + true_positives = sum(1 for c in confident_preds if c in target_classes) + false_positives = len(confident_preds) - true_positives + false_negatives = len(target_classes) - true_positives + + precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0 + recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0 + f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 + + return { + 'precision': precision, + 'recall': recall, + 'f1': f1, + 'true_positives': true_positives, + 'false_positives': false_positives, + 'false_negatives': false_negatives + } + +def convert_to_json_serializable(obj): + """Convert numpy types to Python native types for JSON serialization.""" + if isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, dict): + return {k: convert_to_json_serializable(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [convert_to_json_serializable(item) for item in obj] + return obj + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--samples_dir', type=str, required=True, help='Directory containing generated samples') + parser.add_argument('--prompt', type=str, required=True, help='Original prompt used for generation') + parser.add_argument('--target_objects', type=str, required=True, help='Comma-separated list of target object classes') + parser.add_argument('--expected_counts', type=str, required=True, help='Comma-separated list of expected counts (same order as target_objects)') + args = parser.parse_args() + + # Load models + print("Loading CLIP model...") + clip_model, clip_processor = load_clip() + + print("Loading object detector...") + detector = load_object_detector() + + # Parse target objects and expected counts + target_objects = [obj.strip() for obj in args.target_objects.split(',')] + expected_counts_list = [int(c.strip()) for c in args.expected_counts.split(',')] + + # Map target objects to COCO class indices + coco_class_mapping = { + 'cup': 47, + 'plate': 48, + # Add more mappings as needed + } + target_classes = [coco_class_mapping[obj] for obj in target_objects if obj in coco_class_mapping] + expected_counts = {str(coco_class_mapping[obj]): count + for obj, count in zip(target_objects, expected_counts_list) + if obj in coco_class_mapping} + + # Evaluate all samples + results = [] + clip_scores = [] + count_accuracies = [] + + print("Evaluating samples...") + for image_file in tqdm(os.listdir(args.samples_dir)): + if not image_file.endswith(('.png', '.jpg', '.jpeg')): + continue + + image_path = os.path.join(args.samples_dir, image_file) + + # Compute CLIP score + clip_score = compute_clip_score(clip_model, clip_processor, image_path, args.prompt) + clip_scores.append(clip_score) + + # Compute object detection metrics + metrics = compute_object_metrics(detector, image_path, target_classes) + + # Compute counting accuracy + count_acc = compute_clip_count_accuracy(detector, image_path, target_classes, expected_counts) 
+ count_accuracies.append(count_acc) + + results.append({ + 'image': image_file, + 'clip_score': clip_score, + 'count_accuracy': convert_to_json_serializable(count_acc), + **convert_to_json_serializable(metrics) + }) + + # Compute average metrics + avg_count_acc = {} + for cls in target_classes: + cls_str = str(cls) + accuracies = [r['count_accuracy'][cls_str]['accuracy'] for r in results] + avg_count_acc[cls_str] = float(np.mean(accuracies)) + + avg_metrics = { + 'avg_clip_score': float(np.mean(clip_scores)), + 'avg_precision': float(np.mean([r['precision'] for r in results])), + 'avg_recall': float(np.mean([r['recall'] for r in results])), + 'avg_f1': float(np.mean([r['f1'] for r in results])), + 'avg_count_accuracy': avg_count_acc + } + + # Save results + output_file = os.path.join(args.samples_dir, 'evaluation_results.json') + with open(output_file, 'w') as f: + json.dump({ + 'prompt': args.prompt, + 'target_objects': target_objects, + 'expected_counts': convert_to_json_serializable(expected_counts), + 'individual_results': results, + 'average_metrics': avg_metrics + }, f, indent=2) + + # Print summary + print("\nEvaluation Results:") + print(f"Average CLIP Score: {avg_metrics['avg_clip_score']:.4f}") + print(f"Average Precision: {avg_metrics['avg_precision']:.4f}") + print(f"Average Recall: {avg_metrics['avg_recall']:.4f}") + print(f"Average F1 Score: {avg_metrics['avg_f1']:.4f}") + print("\nCounting Accuracy per Class:") + for cls in target_classes: + cls_str = str(cls) + obj_name = [k for k, v in coco_class_mapping.items() if str(v) == cls_str][0] + print(f" {obj_name}: {avg_metrics['avg_count_accuracy'][cls_str]:.4f}") + print(f"\nDetailed results saved to: {output_file}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/txt2img.py b/scripts/txt2img.py index bc3864043..c01ffce5a 100644 --- a/scripts/txt2img.py +++ b/scripts/txt2img.py @@ -5,7 +5,7 @@ from omegaconf import OmegaConf from PIL import Image from tqdm import tqdm, trange -from imwatermark import WatermarkEncoder +# from imwatermark import WatermarkEncoder from itertools import islice from einops import rearrange from torchvision.utils import make_grid @@ -18,6 +18,7 @@ from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.dpm_solver import DPMSolverSampler +from ldm.modules.bounded_attention import enable_bounded_attention, SubjectMask, bounded_attention from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from transformers import AutoFeatureExtractor @@ -66,12 +67,12 @@ def load_model_from_config(config, ckpt, verbose=False): return model -def put_watermark(img, wm_encoder=None): - if wm_encoder is not None: - img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) - img = wm_encoder.encode(img, 'dwtDct') - img = Image.fromarray(img[:, :, ::-1]) - return img +# def put_watermark(img, wm_encoder=None): +# if wm_encoder is not None: +# img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) +# img = wm_encoder.encode(img, 'dwtDct') +# img = Image.fromarray(img[:, :, ::-1]) +# return img def load_replacement(x): @@ -232,6 +233,11 @@ def main(): choices=["full", "autocast"], default="autocast" ) + parser.add_argument( + "--subject_ranges", + type=str, + help="Comma-separated list of token ranges for subjects in format 'start1-end1,start2-end2,...'", + ) opt = parser.parse_args() if opt.laion400m: @@ -240,6 +246,18 @@ def main(): opt.ckpt = "models/ldm/text2img-large/model.ckpt" 
opt.outdir = "outputs/txt2img-samples-laion400m" + if opt.subject_ranges: + # Parse subject ranges from string like "2-5,5-8" into list of SubjectMask objects + try: + ranges = [tuple(map(int, r.split('-'))) for r in opt.subject_ranges.split(',')] + subject_masks = [SubjectMask(start, end) for start, end in ranges] + except: + print("Error parsing subject ranges. Format should be 'start1-end1,start2-end2,...'") + print("Example: --subject_ranges '2-5,5-8' for two subjects") + sys.exit(1) + else: + subject_masks = None + seed_everything(opt.seed) config = OmegaConf.load(f"{opt.config}") @@ -259,9 +277,9 @@ def main(): outpath = opt.outdir print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") - wm = "StableDiffusionV1" - wm_encoder = WatermarkEncoder() - wm_encoder.set_watermark('bytes', wm.encode('utf-8')) + # wm = "StableDiffusionV1" + # wm_encoder = WatermarkEncoder() + # wm_encoder.set_watermark('bytes', wm.encode('utf-8')) batch_size = opt.n_samples n_rows = opt.n_rows if opt.n_rows > 0 else batch_size @@ -286,63 +304,72 @@ def main(): start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) precision_scope = autocast if opt.precision=="autocast" else nullcontext - with torch.no_grad(): - with precision_scope("cuda"): - with model.ema_scope(): - tic = time.time() - all_samples = list() - for n in trange(opt.n_iter, desc="Sampling"): - for prompts in tqdm(data, desc="data"): - uc = None - if opt.scale != 1.0: - uc = model.get_learned_conditioning(batch_size * [""]) - if isinstance(prompts, tuple): - prompts = list(prompts) - c = model.get_learned_conditioning(prompts) - shape = [opt.C, opt.H // opt.f, opt.W // opt.f] - samples_ddim, _ = sampler.sample(S=opt.ddim_steps, - conditioning=c, - batch_size=opt.n_samples, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc, - eta=opt.ddim_eta, - x_T=start_code) - - x_samples_ddim = model.decode_first_stage(samples_ddim) - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() - - x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim) - - x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) - - if not opt.skip_save: - for x_sample in x_checked_image_torch: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') - img = Image.fromarray(x_sample.astype(np.uint8)) - img = put_watermark(img, wm_encoder) - img.save(os.path.join(sample_path, f"{base_count:05}.png")) - base_count += 1 - - if not opt.skip_grid: - all_samples.append(x_checked_image_torch) - - if not opt.skip_grid: - # additionally, save as grid - grid = torch.stack(all_samples, 0) - grid = rearrange(grid, 'n b c h w -> (n b) c h w') - grid = make_grid(grid, nrow=n_rows) - - # to image - grid = 255. 
* rearrange(grid, 'c h w -> h w c').cpu().numpy() - img = Image.fromarray(grid.astype(np.uint8)) - img = put_watermark(img, wm_encoder) - img.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) - grid_count += 1 - - toc = time.time() + + # Enable bounded attention if subject ranges were provided + if subject_masks: + print(f"Enabling bounded attention with subjects: {opt.subject_ranges}") + context_manager = bounded_attention(model, subject_masks) + else: + context_manager = nullcontext() + + with context_manager: + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + tic = time.time() + all_samples = list() + for n in trange(opt.n_iter, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + uc = None + if opt.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + shape = [opt.C, opt.H // opt.f, opt.W // opt.f] + samples_ddim, _ = sampler.sample(S=opt.ddim_steps, + conditioning=c, + batch_size=opt.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + x_T=start_code) + + x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() + + x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim) + + x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) + + if not opt.skip_save: + for x_sample in x_checked_image_torch: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + img = Image.fromarray(x_sample.astype(np.uint8)) + # img = put_watermark(img, wm_encoder) + img.save(os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + + if not opt.skip_grid: + all_samples.append(x_checked_image_torch) + + if not opt.skip_grid: + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + img = Image.fromarray(grid.astype(np.uint8)) + # img = put_watermark(img, wm_encoder) + img.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 + + toc = time.time() print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.") From c9913373b5e753f8be01794b169c5775d73bef98 Mon Sep 17 00:00:00 2001 From: Aiza Usman Date: Sat, 26 Apr 2025 23:36:42 -0400 Subject: [PATCH 5/5] edited logic for masks --- ldm/modules/bounded_attention.py | 181 ++++++++++++++++++------------- 1 file changed, 104 insertions(+), 77 deletions(-) diff --git a/ldm/modules/bounded_attention.py b/ldm/modules/bounded_attention.py index 427507780..e8c2eb420 100644 --- a/ldm/modules/bounded_attention.py +++ b/ldm/modules/bounded_attention.py @@ -1,114 +1,141 @@ """ -Bounded Attention patch for Stable‑Diffusion v1 UNet ---------------------------------------------------- -This module injects a lightweight masking step into all -*cross‑attention* and *self‑attention* layers used during sampling so -that distinct groups of prompt tokens ("subjects") can only attend to -keys originating from their own group plus an optional background -bucket. - -Implementation follows "Be Yourself: Bounded Attention for Multi‑Subject T2I" (Dahary et al. 2024). 
+Bounded Attention patch for Stable-Diffusion v1 UNet +(implements Dahary et al., 2024) """ from __future__ import annotations import contextlib from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import List, Optional import torch import torch.nn.functional as F - @dataclass class SubjectMask: - """Half‑open token span [start, end) identifying one subject.""" start: int - end: int + end: int # half-open [start, end) - def slice(self, length: int) -> slice: - # clip to sequence length to avoid IndexError - return slice(max(0, self.start), min(length, self.end)) + def slice(self, L: int) -> slice: + return slice(max(0, self.start), min(L, self.end)) -def _masked_softmax(attn: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor: - """Softmax with a `0/‑inf` mask on the *key* dimension (`dim`).""" - attn = attn.masked_fill_(~mask, float("-inf")) - return F.softmax(attn, dim=dim, dtype=torch.float32) +# ---------- helpers --------------------------------------------------------- +def _build_key_mask(L: int, subjects: List[SubjectMask], device) -> torch.Tensor: + """ + returns [Nsubj+1, 1, 1, L] bool + - first N slices = subjects + - last slice = background (everything not in any subject span) + """ + full = torch.zeros(L, dtype=torch.bool, device=device) + subj_masks = [] + covered = torch.zeros_like(full) -def _build_key_mask(seq_len: int, subjects: List[SubjectMask], device) -> torch.Tensor: - """Return a [Nsubj, 1, 1, seq_len] boolean tensor indicating where each subject may attend.""" - masks = [] - full = torch.zeros(seq_len, dtype=torch.bool, device=device) for sm in subjects: m = full.clone() - m[sm.slice(seq_len)] = True # allow attention within subject span - masks.append(m) - return torch.stack(masks)[:, None, None, :] # [N,1,1,L] + m[sm.slice(L)] = True + covered |= m + subj_masks.append(m) + # background “bucket” + background = ~covered + return torch.stack([*subj_masks, background])[:, None, None, :] # [N+1,1,1,L] -_patch_handle: Optional[Tuple] = None -def enable_bounded_attention(model, subjects: List[SubjectMask]): - """Inject bounded‑attention into *all* CrossAttention blocks of ``model``. +def _safe_softmax(attn: torch.Tensor, mask: torch.Tensor, dim=-1) -> torch.Tensor: + """ + softmax with -inf masking that **guarantees** each row has ≥1 valid key. + if a row would be all -inf we instead fall back to an un-masked softmax + for that row only (uniform attention ≈ no harm, avoids nans). + """ + max_neg = -torch.finfo(attn.dtype).max + attn = attn.masked_fill(~mask, max_neg) - Args: - model: ``ldm.models.diffusion.ddpm.LatentDiffusion`` or similar UNet wrapper. - subjects: list of ``SubjectMask`` specifying *text* token spans for each subject. 
+ # rows where everything is masked + all_masked = (mask.sum(dim=dim, keepdim=True) == 0) + if all_masked.any(): + attn = attn.masked_fill(all_masked, 0.0) + + return F.softmax(attn, dim=dim, dtype=torch.float32) + + +# ---------- monkey-patch machinery ----------------------------------------- + +_patch: Optional[tuple] = None + +def enable_bounded_attention(model, subjects: List[SubjectMask]): """ - global _patch_handle - if _patch_handle is not None: - raise RuntimeError("Bounded attention already enabled; call disable_bounded_attention() first") - - # capture original forward fn to restore later - from ldm.modules.attention import CrossAttention - original_forward = CrossAttention.forward - - def forward_patched(self, x, context=None, mask=None): - out = original_forward(self, x, context, mask) - if context is None: - return out # self‑attention inside UNet; leave unchanged for now - - # context shape: [B, Lctx, C]; only first B entries refer to prompt tokens - B, Lctx, _ = context.shape - device = context.device - - # Build or cache key mask once per batch size / seq_len - if not hasattr(self, "_ba_mask") or self._ba_mask.shape[-1] != Lctx: - self._ba_mask = _build_key_mask(Lctx, subjects, device) # [N,1,1,Lctx] - - # queries: [B*H, Lq, Dh]; keys: [B*H, Lk, Dh]; attn_scores: [B*H, Lq, Lk] - q, k, v = self.to_qkv(x, context) - h = self.heads - q, k, v = map(lambda t: t.reshape(B, -1, h, self.head_dim).transpose(1, 2), (q, k, v)) - attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale - - # broadcast subject mask to match attn_scores; assume each subject occupies equal batch chunk - # For B==1 this is trivial; for larger batch user must supply matching spans per sample. - subj_mask = self._ba_mask[0] # [1,1,Lk] - attn_probs = _masked_softmax(attn_scores, subj_mask, dim=-1) - out = torch.matmul(attn_probs, v) - out = out.transpose(1, 2).reshape(B, -1, h * self.head_dim) + Enable bounded attention on **all** CrossAttention layers in `model`. + """ + global _patch + if _patch is not None: + raise RuntimeError("already enabled") + + from ldm.modules.attention import CrossAttention # import locally + orig_forward = CrossAttention.forward + + def forward_ba(self, x, context=None, mask=None): + h = self.heads + context = x if context is None else context # self-attention + + B, Lq, _ = x.shape + Lk = context.shape[1] + device = context.device + + # build / cache masks + if (not hasattr(self, "_ba_kmask") + or self._ba_kmask.shape[-1] != Lk): + self._ba_kmask = _build_key_mask(Lk, subjects, device) # [N+1,1,1,Lk] + + # decide, **per query token**, which bucket to use + # rule of thumb: token ∈ subject_i → bucket i + # else → background bucket (−1 index) + bucket_ids = torch.full((Lk,), len(subjects), device=device) + for i, sm in enumerate(subjects): + bucket_ids[sm.slice(Lk)] = i # assign ids + + # projections + q = self.to_q(x) + k = self.to_k(context) + v = self.to_v(context) + + dim_head = q.shape[-1] // h + q, k, v = map(lambda t: t.view(B, -1, h, dim_head).transpose(1, 2), + (q, k, v)) # (B,h,Len,dh) + + attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale # (B,h,Lq,Lk) + + # broadcast key-mask to (B,h,Lq,Lk) by picking the right bucket for each query + # 1. Union-of-subject + background bucket → shape (1,1,1,Lk) + kmask = self._ba_kmask.any(0, keepdim=True) # (1,1,1,Lk) + + # 2. Bring in SD’s own mask (if it exists) + if mask is not None: # mask: (B,1,1,Lk) + kmask = kmask & mask + + # 3. 
Broadcast to (B,h,Lq,Lk) automatically + probs = _safe_softmax(attn, kmask, dim=-1) + + out = torch.matmul(probs, v) # (B,h,Lq,dh) + out = out.transpose(1, 2).reshape(B, Lq, h * dim_head) return self.to_out(out) - # apply patch - CrossAttention.forward = forward_patched - _patch_handle = (CrossAttention, original_forward) + CrossAttention.forward = forward_ba + _patch = (CrossAttention, orig_forward) def disable_bounded_attention(): - """Restore original CrossAttention implementation.""" - global _patch_handle - if _patch_handle is None: + global _patch + if _patch is None: return - cls, orig = _patch_handle - cls.forward = orig - _patch_handle = None + cls, orig_fwd = _patch + cls.forward = orig_fwd + _patch = None @contextlib.contextmanager def bounded_attention(model, subjects: List[SubjectMask]): - """Context manager for temporarily enabling bounded attention.""" enable_bounded_attention(model, subjects) try: yield finally: - disable_bounded_attention() \ No newline at end of file + disable_bounded_attention()
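
Reading PATCH 4 and PATCH 5 together, the intended flow is: the --subject_ranges flag added to scripts/txt2img.py is parsed into SubjectMask token spans, and _build_key_mask in ldm/modules/bounded_attention.py turns those spans into one boolean key mask per subject plus a background bucket, so each subject's queries should only see keys from its own span plus the background. The sketch below is a self-contained illustration of that mapping; parse_subject_ranges and build_key_mask are local stand-ins written to mirror the patched code rather than import it, and the token spans used are made up for the example. It runs with only PyTorch installed.

from dataclasses import dataclass
import torch

@dataclass
class SubjectMask:  # mirrors ldm/modules/bounded_attention.SubjectMask
    start: int
    end: int        # half-open span [start, end) over prompt tokens

    def slice(self, length: int) -> slice:
        return slice(max(0, self.start), min(length, self.end))

def parse_subject_ranges(arg: str):
    # same "start1-end1,start2-end2" format that txt2img.py accepts for --subject_ranges
    return [SubjectMask(*map(int, r.split("-"))) for r in arg.split(",")]

def build_key_mask(seq_len: int, subjects, device="cpu"):
    # mirrors _build_key_mask from PATCH 5: one row per subject plus a background row
    covered = torch.zeros(seq_len, dtype=torch.bool, device=device)
    rows = []
    for sm in subjects:
        row = torch.zeros(seq_len, dtype=torch.bool, device=device)
        row[sm.slice(seq_len)] = True   # keys inside this subject's span
        covered |= row
        rows.append(row)
    rows.append(~covered)               # background bucket: everything not in any span
    return torch.stack(rows)[:, None, None, :]  # [N_subjects + 1, 1, 1, L]

if __name__ == "__main__":
    subjects = parse_subject_ranges("2-5,5-8")       # two subjects; spans are illustrative
    kmask = build_key_mask(seq_len=12, subjects=subjects)
    print(kmask[:, 0, 0, :].int())
    # row 0 allows keys 2-4, row 1 allows keys 5-7, row 2 is the background bucket;
    # queries belonging to a subject are meant to attend only within their own row (plus background)

With the repository on the path, the same spans would instead be passed on the command line, e.g. --subject_ranges "2-5,5-8" to scripts/txt2img.py, which wraps sampling in the bounded_attention context manager so the patched CrossAttention.forward sees these masks during generation.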