diff --git a/README.md b/README.md
index 1f118aa..f7791f6 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,83 @@
-# DSB
\ No newline at end of file
+# DSB
+
+一个简单的 **Diffusion Schrödinger Bridge**(扩散薛定谔桥)PyTorch 示例实现,使用**公开数据集**(`torchvision`)作为输入,不再生成合成数据集。
+
+包含内容:
+
+- 数据集输入(MNIST / FashionMNIST / CIFAR10)
+- 模型实现(时间条件 MLP)
+- 模型训练
+- 验证集评估指标
+
+## 1. 环境依赖
+
+```bash
+pip install torch torchvision
+```
+
+## 2. 数据构造方式(公开数据集)
+
+脚本会从公开视觉数据集里选两个类别,分别作为桥两端:
+
+- `x0`: `source_class` 类别样本
+- `x1`: `target_class` 类别样本
+
+然后随机配对,形成 `(x0, x1)` 训练样本。
+
+支持数据集:
+
+- `mnist`
+- `fashionmnist`
+- `cifar10`
+
+## 3. 训练示例
+
+### 3.1 MNIST: 数字 1 -> 数字 7
+
+```bash
+python dsb_torch.py \
+    --dataset mnist \
+    --source_class 1 \
+    --target_class 7 \
+    --max_pairs 5000 \
+    --epochs 20
+```
+
+### 3.2 CIFAR10: airplane(0) -> ship(8)
+
+```bash
+python dsb_torch.py \
+    --dataset cifar10 \
+    --source_class 0 \
+    --target_class 8 \
+    --max_pairs 5000 \
+    --epochs 20
+```
+
+## 4. 输出指标(验证集)
+
+每个 epoch 输出:
+
+- `val_loss`:验证总损失(端点重建平方误差总和 / 样本数)
+- `endpoint_mse`:端点重建 MSE(`x0/x1`)
+- `endpoint_mae`:端点重建 MAE(`x0/x1`)
+- `path_mse`:中间时刻路径均值重建 MSE
+
+示例:
+
+```text
+Epoch 005 | train_loss=0.111111 | val_loss=85.432100 | endpoint_mse=0.054321 | endpoint_mae=0.151200 | path_mse=0.020345
+```
+
+## 5. 方法说明(简化版)
+
+- 对配对样本 `(x0, x1)`,随机采样时间 `t ~ U(0,1)`
+- 构造桥噪声点:
+
+  `x_t = (1-t)x0 + t x1 + sigma * sqrt(t(1-t)) * eps`
+- 模型输入 `(x_t, t)`,输出 `(x0_hat, x1_hat)`
+- 训练目标:
+
+  `MSE(x0_hat, x0) + MSE(x1_hat, x1)`
+
+> 说明:本实现是教学用最小版本,用来演示“公开数据集输入 + DSB 训练/验证流程”。
diff --git a/dsb_torch.py b/dsb_torch.py
new file mode 100644
index 0000000..64f70e9
--- /dev/null
+++ b/dsb_torch.py
@@ -0,0 +1,255 @@
+import argparse
+from dataclasses import dataclass
+from typing import Tuple
+
+import torch
+from torch import nn
+from torch.utils.data import DataLoader, Dataset, random_split
+from torchvision import datasets, transforms
+
+
+class BridgePairDataset(Dataset):
+    """Paired endpoints (x0, x1) for Schrödinger bridge training."""
+
+    def __init__(self, x0: torch.Tensor, x1: torch.Tensor):
+        if x0.shape != x1.shape:
+            raise ValueError(f"x0 and x1 must have same shape, got {x0.shape} and {x1.shape}")
+        self.x0 = x0.float()
+        self.x1 = x1.float()
+
+    def __len__(self) -> int:
+        return self.x0.shape[0]
+
+    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        return self.x0[idx], self.x1[idx]
+
+
+class PublicVisionBridgeDatasetBuilder:
+    """Build paired endpoints from a public torchvision dataset by class-to-class pairing."""
+
+    _SUPPORTED = {
+        "mnist": datasets.MNIST,
+        "fashionmnist": datasets.FashionMNIST,
+        "cifar10": datasets.CIFAR10,
+    }
+
+    @classmethod
+    def build(
+        cls,
+        root: str,
+        dataset_name: str,
+        source_class: int,
+        target_class: int,
+        max_pairs: int,
+        seed: int,
+        train_split: bool = True,
+    ) -> BridgePairDataset:
+        dataset_name = dataset_name.lower()
+        if dataset_name not in cls._SUPPORTED:
+            raise ValueError(f"Unsupported dataset '{dataset_name}'. Choose from {list(cls._SUPPORTED.keys())}")
+
+        transform = transforms.ToTensor()
+        dataset_cls = cls._SUPPORTED[dataset_name]
+        ds = dataset_cls(root=root, train=train_split, download=True, transform=transform)
+
+        imgs = []
+        labels = []
+        for img, label in ds:
+            imgs.append(img)
+            labels.append(int(label))
+
+        x = torch.stack(imgs, dim=0)
+        y = torch.tensor(labels, dtype=torch.long)
+
+        src_idx = (y == source_class).nonzero(as_tuple=False).squeeze(1)
+        tgt_idx = (y == target_class).nonzero(as_tuple=False).squeeze(1)
+
+        if src_idx.numel() == 0 or tgt_idx.numel() == 0:
+            raise ValueError("source_class or target_class has no samples in selected split")
+
+        n_pairs = min(src_idx.numel(), tgt_idx.numel(), max_pairs)
+        gen = torch.Generator().manual_seed(seed)
+        src_perm = src_idx[torch.randperm(src_idx.numel(), generator=gen)[:n_pairs]]
+        tgt_perm = tgt_idx[torch.randperm(tgt_idx.numel(), generator=gen)[:n_pairs]]
+
+        x0 = x[src_perm].flatten(start_dim=1)
+        x1 = x[tgt_perm].flatten(start_dim=1)
+        return BridgePairDataset(x0, x1)
+
+
+class TimeEmbedding(nn.Module):
+    def __init__(self, emb_dim: int):
+        super().__init__()
+        self.emb_dim = emb_dim
+
+    def forward(self, t: torch.Tensor) -> torch.Tensor:
+        half = self.emb_dim // 2
+        freqs = torch.exp(torch.linspace(0, torch.log(torch.tensor(1000.0)), half, device=t.device, dtype=t.dtype))
+        args = t * freqs.unsqueeze(0)
+        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
+        if self.emb_dim % 2 == 1:
+            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
+        return emb
+
+
+class BridgeDenoiser(nn.Module):
+    def __init__(self, data_dim: int, hidden_dim: int = 256, time_dim: int = 64):
+        super().__init__()
+        self.time_emb = TimeEmbedding(time_dim)
+        self.net = nn.Sequential(
+            nn.Linear(data_dim + time_dim, hidden_dim),
+            nn.SiLU(),
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.SiLU(),
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.SiLU(),
+            nn.Linear(hidden_dim, 2 * data_dim),
+        )
+
+    def forward(self, x_t: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        te = self.time_emb(t)
+        pred = self.net(torch.cat([x_t, te], dim=-1))
+        return torch.chunk(pred, chunks=2, dim=-1)
+
+
+@dataclass
+class TrainConfig:
+    epochs: int = 20
+    batch_size: int = 128
+    lr: float = 1e-3
+    sigma: float = 0.20
+    val_ratio: float = 0.2
+    seed: int = 42
+
+
+def sample_noisy_bridge(x0: torch.Tensor, x1: torch.Tensor, sigma: float) -> Tuple[torch.Tensor, torch.Tensor]:
+    bsz = x0.shape[0]
+    t = torch.rand(bsz, 1, device=x0.device)
+    eps = torch.randn_like(x0)
+    var = torch.clamp(t * (1.0 - t), min=1e-4)
+    x_t = (1.0 - t) * x0 + t * x1 + sigma * torch.sqrt(var) * eps
+    return x_t, t
+
+
+def compute_metrics(model: nn.Module, loader: DataLoader, sigma: float, device: torch.device) -> dict:
+    model.eval()
+    mse_sum = nn.MSELoss(reduction="sum")
+    mae_sum = nn.L1Loss(reduction="sum")
+    total = 0
+    endpoint_mse_sum = 0.0
+    endpoint_mae_sum = 0.0
+    path_mse_sum = 0.0
+
+    with torch.no_grad():
+        for x0, x1 in loader:
+            x0 = x0.to(device)
+            x1 = x1.to(device)
+            x_t, t = sample_noisy_bridge(x0, x1, sigma)
+            x0_hat, x1_hat = model(x_t, t)
+
+            mu_true = (1.0 - t) * x0 + t * x1
+            mu_hat = (1.0 - t) * x0_hat + t * x1_hat
+
+            bsz = x0.shape[0]
+            total += bsz
+            endpoint_mse_sum += (mse_sum(x0_hat, x0) + mse_sum(x1_hat, x1)).item()
+            endpoint_mae_sum += (mae_sum(x0_hat, x0) + mae_sum(x1_hat, x1)).item()
+            path_mse_sum += mse_sum(mu_hat, mu_true).item()
+
+    dim = loader.dataset[0][0].numel()
+    return {
+        "val_loss": endpoint_mse_sum / total,
+        "endpoint_mse": endpoint_mse_sum / (total * 2 * dim),
+        "endpoint_mae": endpoint_mae_sum / (total * 2 * dim),
+        "path_mse": path_mse_sum / (total * dim),
+    }
+
+
+def train(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, cfg: TrainConfig, device: torch.device):
+    opt = torch.optim.Adam(model.parameters(), lr=cfg.lr)
+    mse = nn.MSELoss()
+
+    for epoch in range(1, cfg.epochs + 1):
+        model.train()
+        total_loss = 0.0
+        total = 0
+
+        for x0, x1 in train_loader:
+            x0 = x0.to(device)
+            x1 = x1.to(device)
+            x_t, t = sample_noisy_bridge(x0, x1, cfg.sigma)
+            x0_hat, x1_hat = model(x_t, t)
+            loss = mse(x0_hat, x0) + mse(x1_hat, x1)
+
+            opt.zero_grad()
+            loss.backward()
+            opt.step()
+
+            bsz = x0.shape[0]
+            total_loss += loss.item() * bsz
+            total += bsz
+
+        metrics = compute_metrics(model, val_loader, cfg.sigma, device)
+        print(
+            f"Epoch {epoch:03d} | train_loss={total_loss / total:.6f} | "
+            f"val_loss={metrics['val_loss']:.6f} | endpoint_mse={metrics['endpoint_mse']:.6f} | "
+            f"endpoint_mae={metrics['endpoint_mae']:.6f} | path_mse={metrics['path_mse']:.6f}"
+        )
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Diffusion Schrödinger Bridge with public torchvision datasets")
+    parser.add_argument("--data_root", type=str, default="./data")
+    parser.add_argument("--dataset", type=str, default="mnist", choices=["mnist", "fashionmnist", "cifar10"])
+    parser.add_argument("--source_class", type=int, default=1, help="x0 endpoint class id")
+    parser.add_argument("--target_class", type=int, default=7, help="x1 endpoint class id")
+    parser.add_argument("--max_pairs", type=int, default=5000)
+    parser.add_argument("--use_test_split", action="store_true", help="Use test split instead of train split")
+    parser.add_argument("--epochs", type=int, default=20)
+    parser.add_argument("--batch_size", type=int, default=128)
+    parser.add_argument("--lr", type=float, default=1e-3)
+    parser.add_argument("--sigma", type=float, default=0.20)
+    parser.add_argument("--val_ratio", type=float, default=0.2)
+    parser.add_argument("--seed", type=int, default=42)
+    args = parser.parse_args()
+
+    torch.manual_seed(args.seed)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    dataset = PublicVisionBridgeDatasetBuilder.build(
+        root=args.data_root,
+        dataset_name=args.dataset,
+        source_class=args.source_class,
+        target_class=args.target_class,
+        max_pairs=args.max_pairs,
+        seed=args.seed,
+        train_split=not args.use_test_split,
+    )
+
+    n_val = max(1, int(len(dataset) * args.val_ratio))  # guard: compute_metrics divides by len(val); 0 would crash
+    n_train = len(dataset) - n_val
+    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(args.seed))
+
+    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
+    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False)
+
+    model = BridgeDenoiser(data_dim=dataset[0][0].numel()).to(device)
+    cfg = TrainConfig(
+        epochs=args.epochs,
+        batch_size=args.batch_size,
+        lr=args.lr,
+        sigma=args.sigma,
+        val_ratio=args.val_ratio,
+        seed=args.seed,
+    )
+
+    print(f"Using device: {device}")
+    print(
+        f"Dataset={args.dataset}, class {args.source_class} -> class {args.target_class}, "
+        f"pairs={len(dataset)}, train/val={n_train}/{n_val}"
+    )
+    train(model, train_loader, val_loader, cfg, device)
+
+
+if __name__ == "__main__":
+    main()