From 7519f1949fc47e1987bdb29eb30535bdd044bce3 Mon Sep 17 00:00:00 2001 From: Tumb1eweed <995213108@qq.com> Date: Thu, 12 Feb 2026 18:24:29 +0800 Subject: [PATCH 1/2] =?UTF-8?q?Add=20simple=20PyTorch=20diffusion=20Schr?= =?UTF-8?q?=C3=B6dinger=20bridge=20example?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 71 +++++++++++++- dsb_torch.py | 267 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 337 insertions(+), 1 deletion(-) create mode 100644 dsb_torch.py diff --git a/README.md b/README.md index 1f118aa..08483e4 100644 --- a/README.md +++ b/README.md @@ -1 +1,70 @@ -# DSB \ No newline at end of file +# DSB + +一个简单的 **Diffusion Schrödinger Bridge**(扩散薛定谔桥)PyTorch 示例实现,包含: + +- 数据集输入(CSV) +- 模型实现(时间条件 MLP) +- 训练循环 +- 验证集评估与指标输出 + +## 1. 环境依赖 + +```bash +pip install torch pandas numpy +``` + +## 2. 数据格式 + +CSV 需要包含以下列(按维度扩展): + +- `x0_0, x0_1, ...`:起点样本 +- `x1_0, x1_1, ...`:终点样本 + +例如二维数据: + +```text +x0_0,x0_1,x1_0,x1_1 +-1.2,0.3,2.7,1.1 +... +``` + +## 3. 训练方式 + +### 3.1 使用内置合成数据直接训练 + +```bash +python dsb_torch.py --epochs 20 --batch_size 256 +``` + +### 3.2 先导出合成 CSV,再从 CSV 读取训练 + +```bash +python dsb_torch.py --make_synth_csv data/bridge.csv --n_samples 5000 --dim 2 +python dsb_torch.py --csv data/bridge.csv --epochs 20 +``` + +## 4. 输出指标(验证集) + +每个 epoch 会输出: + +- `val_loss`:验证总损失 +- `endpoint_mse`:端点重建 MSE(x0/x1) +- `endpoint_mae`:端点重建 MAE(x0/x1) +- `path_mse`:中间时刻路径均值重建 MSE + +示例输出: + +```text +Epoch 005 | train_loss=0.123456 | val_loss=0.118901 | endpoint_mse=0.029725 | endpoint_mae=0.128000 | path_mse=0.012345 +``` + +## 5. 方法说明(简化版) + +- 对配对样本 `(x0, x1)`,随机采样时间 `t ~ U(0,1)`。 +- 构造桥上的噪声点: + + `x_t = (1-t)x0 + t x1 + sigma * sqrt(t(1-t)) * eps` +- 模型输入 `x_t, t`,预测 `x0_hat, x1_hat`。 +- 训练损失:`MSE(x0_hat, x0) + MSE(x1_hat, x1)`。 + +这是一个教学用最小示例,便于后续替换为更标准的 score/势函数参数化版本。 diff --git a/dsb_torch.py b/dsb_torch.py new file mode 100644 index 0000000..53a317a --- /dev/null +++ b/dsb_torch.py @@ -0,0 +1,267 @@ +import argparse +import math +import os +from dataclasses import dataclass +from typing import Tuple + +import numpy as np +import pandas as pd +import torch +from torch import nn +from torch.utils.data import DataLoader, Dataset, random_split + + +class BridgePairDataset(Dataset): + """Dataset for paired Schrödinger-bridge endpoints (x0, x1).""" + + def __init__(self, x0: torch.Tensor, x1: torch.Tensor): + if x0.shape != x1.shape: + raise ValueError(f"x0 and x1 must have the same shape, got {x0.shape} and {x1.shape}") + self.x0 = x0.float() + self.x1 = x1.float() + + @classmethod + def from_csv(cls, csv_path: str) -> "BridgePairDataset": + df = pd.read_csv(csv_path) + x0_cols = [c for c in df.columns if c.startswith("x0_")] + x1_cols = [c for c in df.columns if c.startswith("x1_")] + if not x0_cols or not x1_cols: + raise ValueError( + "CSV must contain x0_* and x1_* columns, e.g. x0_0,x0_1,...,x1_0,x1_1,..." + ) + x0_cols.sort() + x1_cols.sort() + x0 = torch.tensor(df[x0_cols].to_numpy(), dtype=torch.float32) + x1 = torch.tensor(df[x1_cols].to_numpy(), dtype=torch.float32) + return cls(x0, x1) + + def __len__(self) -> int: + return self.x0.shape[0] + + def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: + return self.x0[idx], self.x1[idx] + + +def make_synthetic_bridge_dataset(n_samples: int = 5000, dim: int = 2, seed: int = 42) -> BridgePairDataset: + """Build toy paired data: x0~N([-2,0],I), x1=R*x0 + shift + noise.""" + rng = np.random.default_rng(seed) + x0 = rng.normal(size=(n_samples, dim)).astype(np.float32) + x0[:, 0] -= 2.0 + + theta = math.pi / 3 + rot = np.array( + [ + [math.cos(theta), -math.sin(theta)], + [math.sin(theta), math.cos(theta)], + ], + dtype=np.float32, + ) + + if dim == 2: + x1 = x0 @ rot.T + else: + # for dim>2, rotate first 2 dims and keep others affine-shifted + x1 = x0.copy() + x1[:, :2] = x0[:, :2] @ rot.T + x1[:, 2:] = 0.7 * x0[:, 2:] + + x1 += np.array([4.0] + [0.5] * (dim - 1), dtype=np.float32) + x1 += 0.1 * rng.normal(size=x1.shape).astype(np.float32) + + return BridgePairDataset(torch.from_numpy(x0), torch.from_numpy(x1)) + + +class TimeEmbedding(nn.Module): + def __init__(self, emb_dim: int): + super().__init__() + self.emb_dim = emb_dim + + def forward(self, t: torch.Tensor) -> torch.Tensor: + """Sinusoidal time embedding.""" + # t: [B,1] + half_dim = self.emb_dim // 2 + freqs = torch.exp( + torch.linspace(0, math.log(1000), half_dim, device=t.device, dtype=t.dtype) + ) + args = t * freqs.unsqueeze(0) + emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1) + if self.emb_dim % 2 == 1: + emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1) + return emb + + +class BridgeDenoiser(nn.Module): + """Given noisy bridge point x_t and time t, predict (x0, x1).""" + + def __init__(self, data_dim: int, hidden_dim: int = 128, time_dim: int = 32): + super().__init__() + self.time_emb = TimeEmbedding(time_dim) + self.net = nn.Sequential( + nn.Linear(data_dim + time_dim, hidden_dim), + nn.SiLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.SiLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.SiLU(), + nn.Linear(hidden_dim, 2 * data_dim), + ) + + def forward(self, x_t: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + te = self.time_emb(t) + pred = self.net(torch.cat([x_t, te], dim=-1)) + x0_hat, x1_hat = torch.chunk(pred, chunks=2, dim=-1) + return x0_hat, x1_hat + + +@dataclass +class TrainConfig: + epochs: int = 50 + batch_size: int = 256 + lr: float = 1e-3 + sigma: float = 0.25 + val_ratio: float = 0.2 + seed: int = 42 + + +def sample_noisy_bridge(x0: torch.Tensor, x1: torch.Tensor, sigma: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz = x0.shape[0] + t = torch.rand(bsz, 1, device=x0.device) + eps = torch.randn_like(x0) + var = torch.clamp(t * (1.0 - t), min=1e-4) + x_t = (1.0 - t) * x0 + t * x1 + sigma * torch.sqrt(var) * eps + return x_t, t, eps + + +def compute_metrics(model: nn.Module, loader: DataLoader, sigma: float, device: torch.device) -> dict: + model.eval() + mse = nn.MSELoss(reduction="sum") + mae = nn.L1Loss(reduction="sum") + total = 0 + loss_sum = 0.0 + endpoint_mse = 0.0 + endpoint_mae = 0.0 + path_mse = 0.0 + + with torch.no_grad(): + for x0, x1 in loader: + x0 = x0.to(device) + x1 = x1.to(device) + x_t, t, _ = sample_noisy_bridge(x0, x1, sigma) + x0_hat, x1_hat = model(x_t, t) + + loss = mse(x0_hat, x0) + mse(x1_hat, x1) + mu_true = (1.0 - t) * x0 + t * x1 + mu_hat = (1.0 - t) * x0_hat + t * x1_hat + + bsz = x0.shape[0] + total += bsz + loss_sum += loss.item() + endpoint_mse += loss.item() + endpoint_mae += (mae(x0_hat, x0) + mae(x1_hat, x1)).item() + path_mse += mse(mu_hat, mu_true).item() + + n_dims = loader.dataset[0][0].numel() * 2 + return { + "val_loss": loss_sum / total, + "endpoint_mse": endpoint_mse / (total * n_dims), + "endpoint_mae": endpoint_mae / (total * n_dims), + "path_mse": path_mse / (total * (n_dims // 2)), + } + + +def train(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, cfg: TrainConfig, device: torch.device): + optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr) + mse = nn.MSELoss() + + for epoch in range(1, cfg.epochs + 1): + model.train() + train_loss = 0.0 + count = 0 + + for x0, x1 in train_loader: + x0 = x0.to(device) + x1 = x1.to(device) + x_t, t, _ = sample_noisy_bridge(x0, x1, cfg.sigma) + + x0_hat, x1_hat = model(x_t, t) + loss = mse(x0_hat, x0) + mse(x1_hat, x1) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + train_loss += loss.item() * x0.shape[0] + count += x0.shape[0] + + metrics = compute_metrics(model, val_loader, cfg.sigma, device) + print( + f"Epoch {epoch:03d} | train_loss={train_loss / count:.6f} | " + f"val_loss={metrics['val_loss']:.6f} | endpoint_mse={metrics['endpoint_mse']:.6f} | " + f"endpoint_mae={metrics['endpoint_mae']:.6f} | path_mse={metrics['path_mse']:.6f}" + ) + + +def save_dataset_csv(dataset: BridgePairDataset, path: str): + os.makedirs(os.path.dirname(path) or ".", exist_ok=True) + dim = dataset.x0.shape[1] + data = {} + for i in range(dim): + data[f"x0_{i}"] = dataset.x0[:, i].numpy() + for i in range(dim): + data[f"x1_{i}"] = dataset.x1[:, i].numpy() + pd.DataFrame(data).to_csv(path, index=False) + print(f"Saved synthetic dataset to {path}") + + +def main(): + parser = argparse.ArgumentParser(description="Simple diffusion Schrödinger bridge (PyTorch)") + parser.add_argument("--csv", type=str, default="", help="Optional CSV path with x0_*/x1_* columns") + parser.add_argument("--make_synth_csv", type=str, default="", help="If set, save synthetic dataset to this CSV path and exit") + parser.add_argument("--n_samples", type=int, default=5000) + parser.add_argument("--dim", type=int, default=2) + parser.add_argument("--epochs", type=int, default=50) + parser.add_argument("--batch_size", type=int, default=256) + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--sigma", type=float, default=0.25) + parser.add_argument("--val_ratio", type=float, default=0.2) + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + + if args.csv: + dataset = BridgePairDataset.from_csv(args.csv) + else: + dataset = make_synthetic_bridge_dataset(args.n_samples, args.dim, args.seed) + + if args.make_synth_csv: + save_dataset_csv(dataset, args.make_synth_csv) + return + + n_val = int(len(dataset) * args.val_ratio) + n_train = len(dataset) - n_val + train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(args.seed)) + + train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True) + val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False) + + data_dim = dataset[0][0].shape[0] + model = BridgeDenoiser(data_dim).to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) + + cfg = TrainConfig( + epochs=args.epochs, + batch_size=args.batch_size, + lr=args.lr, + sigma=args.sigma, + val_ratio=args.val_ratio, + seed=args.seed, + ) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + print(f"Train/Val size: {n_train}/{n_val}") + train(model, train_loader, val_loader, cfg, device) + + +if __name__ == "__main__": + main() From 4235d1e4f85bffbe31e8c762615d6f979a2cd17f Mon Sep 17 00:00:00 2001 From: Tumb1eweed <995213108@qq.com> Date: Thu, 12 Feb 2026 18:56:57 +0800 Subject: [PATCH 2/2] Use public torchvision datasets instead of synthetic data --- README.md | 77 ++++++++++------- dsb_torch.py | 234 ++++++++++++++++++++++++--------------------------- 2 files changed, 156 insertions(+), 155 deletions(-) diff --git a/README.md b/README.md index 08483e4..f7791f6 100644 --- a/README.md +++ b/README.md @@ -1,70 +1,83 @@ # DSB -一个简单的 **Diffusion Schrödinger Bridge**(扩散薛定谔桥)PyTorch 示例实现,包含: +一个简单的 **Diffusion Schrödinger Bridge**(扩散薛定谔桥)PyTorch 示例实现,使用**公开数据集**(`torchvision`)作为输入,不再生成合成数据集。 -- 数据集输入(CSV) +包含内容: + +- 数据集输入(MNIST / FashionMNIST / CIFAR10) - 模型实现(时间条件 MLP) -- 训练循环 -- 验证集评估与指标输出 +- 模型训练 +- 验证集评估指标 ## 1. 环境依赖 ```bash -pip install torch pandas numpy +pip install torch torchvision ``` -## 2. 数据格式 +## 2. 数据构造方式(公开数据集) -CSV 需要包含以下列(按维度扩展): +脚本会从公开视觉数据集里选两个类别,分别作为桥两端: -- `x0_0, x0_1, ...`:起点样本 -- `x1_0, x1_1, ...`:终点样本 +- `x0`: `source_class` 类别样本 +- `x1`: `target_class` 类别样本 -例如二维数据: +然后随机配对,形成 `(x0, x1)` 训练样本。 -```text -x0_0,x0_1,x1_0,x1_1 --1.2,0.3,2.7,1.1 -... -``` +支持数据集: -## 3. 训练方式 +- `mnist` +- `fashionmnist` +- `cifar10` -### 3.1 使用内置合成数据直接训练 +## 3. 训练示例 + +### 3.1 MNIST: 数字 1 -> 数字 7 ```bash -python dsb_torch.py --epochs 20 --batch_size 256 +python dsb_torch.py \ + --dataset mnist \ + --source_class 1 \ + --target_class 7 \ + --max_pairs 5000 \ + --epochs 20 ``` -### 3.2 先导出合成 CSV,再从 CSV 读取训练 +### 3.2 CIFAR10: airplane(0) -> ship(8) ```bash -python dsb_torch.py --make_synth_csv data/bridge.csv --n_samples 5000 --dim 2 -python dsb_torch.py --csv data/bridge.csv --epochs 20 +python dsb_torch.py \ + --dataset cifar10 \ + --source_class 0 \ + --target_class 8 \ + --max_pairs 5000 \ + --epochs 20 ``` ## 4. 输出指标(验证集) -每个 epoch 会输出: +每个 epoch 输出: -- `val_loss`:验证总损失 -- `endpoint_mse`:端点重建 MSE(x0/x1) -- `endpoint_mae`:端点重建 MAE(x0/x1) +- `val_loss`:验证总损失(端点重建平方误差总和 / 样本数) +- `endpoint_mse`:端点重建 MSE(`x0/x1`) +- `endpoint_mae`:端点重建 MAE(`x0/x1`) - `path_mse`:中间时刻路径均值重建 MSE -示例输出: +示例: ```text -Epoch 005 | train_loss=0.123456 | val_loss=0.118901 | endpoint_mse=0.029725 | endpoint_mae=0.128000 | path_mse=0.012345 +Epoch 005 | train_loss=0.111111 | val_loss=85.432100 | endpoint_mse=0.054321 | endpoint_mae=0.151200 | path_mse=0.020345 ``` ## 5. 方法说明(简化版) -- 对配对样本 `(x0, x1)`,随机采样时间 `t ~ U(0,1)`。 -- 构造桥上的噪声点: +- 对配对样本 `(x0, x1)`,随机采样时间 `t ~ U(0,1)` +- 构造桥噪声点: `x_t = (1-t)x0 + t x1 + sigma * sqrt(t(1-t)) * eps` -- 模型输入 `x_t, t`,预测 `x0_hat, x1_hat`。 -- 训练损失:`MSE(x0_hat, x0) + MSE(x1_hat, x1)`。 +- 模型输入 `(x_t, t)`,输出 `(x0_hat, x1_hat)` +- 训练目标: + + `MSE(x0_hat, x0) + MSE(x1_hat, x1)` -这是一个教学用最小示例,便于后续替换为更标准的 score/势函数参数化版本。 +> 说明:本实现是教学用最小版本,用来演示“公开数据集输入 + DSB 训练/验证流程”。 diff --git a/dsb_torch.py b/dsb_torch.py index 53a317a..64f70e9 100644 --- a/dsb_torch.py +++ b/dsb_torch.py @@ -1,40 +1,22 @@ import argparse -import math -import os from dataclasses import dataclass from typing import Tuple -import numpy as np -import pandas as pd import torch from torch import nn from torch.utils.data import DataLoader, Dataset, random_split +from torchvision import datasets, transforms class BridgePairDataset(Dataset): - """Dataset for paired Schrödinger-bridge endpoints (x0, x1).""" + """Paired endpoints (x0, x1) for Schrödinger bridge training.""" def __init__(self, x0: torch.Tensor, x1: torch.Tensor): if x0.shape != x1.shape: - raise ValueError(f"x0 and x1 must have the same shape, got {x0.shape} and {x1.shape}") + raise ValueError(f"x0 and x1 must have same shape, got {x0.shape} and {x1.shape}") self.x0 = x0.float() self.x1 = x1.float() - @classmethod - def from_csv(cls, csv_path: str) -> "BridgePairDataset": - df = pd.read_csv(csv_path) - x0_cols = [c for c in df.columns if c.startswith("x0_")] - x1_cols = [c for c in df.columns if c.startswith("x1_")] - if not x0_cols or not x1_cols: - raise ValueError( - "CSV must contain x0_* and x1_* columns, e.g. x0_0,x0_1,...,x1_0,x1_1,..." - ) - x0_cols.sort() - x1_cols.sort() - x0 = torch.tensor(df[x0_cols].to_numpy(), dtype=torch.float32) - x1 = torch.tensor(df[x1_cols].to_numpy(), dtype=torch.float32) - return cls(x0, x1) - def __len__(self) -> int: return self.x0.shape[0] @@ -42,33 +24,57 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: return self.x0[idx], self.x1[idx] -def make_synthetic_bridge_dataset(n_samples: int = 5000, dim: int = 2, seed: int = 42) -> BridgePairDataset: - """Build toy paired data: x0~N([-2,0],I), x1=R*x0 + shift + noise.""" - rng = np.random.default_rng(seed) - x0 = rng.normal(size=(n_samples, dim)).astype(np.float32) - x0[:, 0] -= 2.0 - - theta = math.pi / 3 - rot = np.array( - [ - [math.cos(theta), -math.sin(theta)], - [math.sin(theta), math.cos(theta)], - ], - dtype=np.float32, - ) - - if dim == 2: - x1 = x0 @ rot.T - else: - # for dim>2, rotate first 2 dims and keep others affine-shifted - x1 = x0.copy() - x1[:, :2] = x0[:, :2] @ rot.T - x1[:, 2:] = 0.7 * x0[:, 2:] +class PublicVisionBridgeDatasetBuilder: + """Build paired endpoints from a public torchvision dataset by class-to-class pairing.""" - x1 += np.array([4.0] + [0.5] * (dim - 1), dtype=np.float32) - x1 += 0.1 * rng.normal(size=x1.shape).astype(np.float32) + _SUPPORTED = { + "mnist": datasets.MNIST, + "fashionmnist": datasets.FashionMNIST, + "cifar10": datasets.CIFAR10, + } - return BridgePairDataset(torch.from_numpy(x0), torch.from_numpy(x1)) + @classmethod + def build( + cls, + root: str, + dataset_name: str, + source_class: int, + target_class: int, + max_pairs: int, + seed: int, + train_split: bool = True, + ) -> BridgePairDataset: + dataset_name = dataset_name.lower() + if dataset_name not in cls._SUPPORTED: + raise ValueError(f"Unsupported dataset '{dataset_name}'. Choose from {list(cls._SUPPORTED.keys())}") + + transform = transforms.ToTensor() + dataset_cls = cls._SUPPORTED[dataset_name] + ds = dataset_cls(root=root, train=train_split, download=True, transform=transform) + + imgs = [] + labels = [] + for img, label in ds: + imgs.append(img) + labels.append(int(label)) + + x = torch.stack(imgs, dim=0) + y = torch.tensor(labels, dtype=torch.long) + + src_idx = (y == source_class).nonzero(as_tuple=False).squeeze(1) + tgt_idx = (y == target_class).nonzero(as_tuple=False).squeeze(1) + + if src_idx.numel() == 0 or tgt_idx.numel() == 0: + raise ValueError("source_class or target_class has no samples in selected split") + + n_pairs = min(src_idx.numel(), tgt_idx.numel(), max_pairs) + gen = torch.Generator().manual_seed(seed) + src_perm = src_idx[torch.randperm(src_idx.numel(), generator=gen)[:n_pairs]] + tgt_perm = tgt_idx[torch.randperm(tgt_idx.numel(), generator=gen)[:n_pairs]] + + x0 = x[src_perm].flatten(start_dim=1) + x1 = x[tgt_perm].flatten(start_dim=1) + return BridgePairDataset(x0, x1) class TimeEmbedding(nn.Module): @@ -77,12 +83,8 @@ def __init__(self, emb_dim: int): self.emb_dim = emb_dim def forward(self, t: torch.Tensor) -> torch.Tensor: - """Sinusoidal time embedding.""" - # t: [B,1] - half_dim = self.emb_dim // 2 - freqs = torch.exp( - torch.linspace(0, math.log(1000), half_dim, device=t.device, dtype=t.dtype) - ) + half = self.emb_dim // 2 + freqs = torch.exp(torch.linspace(0, torch.log(torch.tensor(1000.0)), half, device=t.device, dtype=t.dtype)) args = t * freqs.unsqueeze(0) emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1) if self.emb_dim % 2 == 1: @@ -91,9 +93,7 @@ def forward(self, t: torch.Tensor) -> torch.Tensor: class BridgeDenoiser(nn.Module): - """Given noisy bridge point x_t and time t, predict (x0, x1).""" - - def __init__(self, data_dim: int, hidden_dim: int = 128, time_dim: int = 32): + def __init__(self, data_dim: int, hidden_dim: int = 256, time_dim: int = 64): super().__init__() self.time_emb = TimeEmbedding(time_dim) self.net = nn.Sequential( @@ -109,135 +109,122 @@ def __init__(self, data_dim: int, hidden_dim: int = 128, time_dim: int = 32): def forward(self, x_t: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: te = self.time_emb(t) pred = self.net(torch.cat([x_t, te], dim=-1)) - x0_hat, x1_hat = torch.chunk(pred, chunks=2, dim=-1) - return x0_hat, x1_hat + return torch.chunk(pred, chunks=2, dim=-1) @dataclass class TrainConfig: - epochs: int = 50 - batch_size: int = 256 + epochs: int = 20 + batch_size: int = 128 lr: float = 1e-3 - sigma: float = 0.25 + sigma: float = 0.20 val_ratio: float = 0.2 seed: int = 42 -def sample_noisy_bridge(x0: torch.Tensor, x1: torch.Tensor, sigma: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +def sample_noisy_bridge(x0: torch.Tensor, x1: torch.Tensor, sigma: float) -> Tuple[torch.Tensor, torch.Tensor]: bsz = x0.shape[0] t = torch.rand(bsz, 1, device=x0.device) eps = torch.randn_like(x0) var = torch.clamp(t * (1.0 - t), min=1e-4) x_t = (1.0 - t) * x0 + t * x1 + sigma * torch.sqrt(var) * eps - return x_t, t, eps + return x_t, t def compute_metrics(model: nn.Module, loader: DataLoader, sigma: float, device: torch.device) -> dict: model.eval() - mse = nn.MSELoss(reduction="sum") - mae = nn.L1Loss(reduction="sum") + mse_sum = nn.MSELoss(reduction="sum") + mae_sum = nn.L1Loss(reduction="sum") total = 0 - loss_sum = 0.0 - endpoint_mse = 0.0 - endpoint_mae = 0.0 - path_mse = 0.0 + endpoint_mse_sum = 0.0 + endpoint_mae_sum = 0.0 + path_mse_sum = 0.0 with torch.no_grad(): for x0, x1 in loader: x0 = x0.to(device) x1 = x1.to(device) - x_t, t, _ = sample_noisy_bridge(x0, x1, sigma) + x_t, t = sample_noisy_bridge(x0, x1, sigma) x0_hat, x1_hat = model(x_t, t) - loss = mse(x0_hat, x0) + mse(x1_hat, x1) mu_true = (1.0 - t) * x0 + t * x1 mu_hat = (1.0 - t) * x0_hat + t * x1_hat bsz = x0.shape[0] total += bsz - loss_sum += loss.item() - endpoint_mse += loss.item() - endpoint_mae += (mae(x0_hat, x0) + mae(x1_hat, x1)).item() - path_mse += mse(mu_hat, mu_true).item() + endpoint_mse_sum += (mse_sum(x0_hat, x0) + mse_sum(x1_hat, x1)).item() + endpoint_mae_sum += (mae_sum(x0_hat, x0) + mae_sum(x1_hat, x1)).item() + path_mse_sum += mse_sum(mu_hat, mu_true).item() - n_dims = loader.dataset[0][0].numel() * 2 + dim = loader.dataset[0][0].numel() return { - "val_loss": loss_sum / total, - "endpoint_mse": endpoint_mse / (total * n_dims), - "endpoint_mae": endpoint_mae / (total * n_dims), - "path_mse": path_mse / (total * (n_dims // 2)), + "val_loss": endpoint_mse_sum / total, + "endpoint_mse": endpoint_mse_sum / (total * 2 * dim), + "endpoint_mae": endpoint_mae_sum / (total * 2 * dim), + "path_mse": path_mse_sum / (total * dim), } def train(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, cfg: TrainConfig, device: torch.device): - optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr) + opt = torch.optim.Adam(model.parameters(), lr=cfg.lr) mse = nn.MSELoss() for epoch in range(1, cfg.epochs + 1): model.train() - train_loss = 0.0 - count = 0 + total_loss = 0.0 + total = 0 for x0, x1 in train_loader: x0 = x0.to(device) x1 = x1.to(device) - x_t, t, _ = sample_noisy_bridge(x0, x1, cfg.sigma) - + x_t, t = sample_noisy_bridge(x0, x1, cfg.sigma) x0_hat, x1_hat = model(x_t, t) loss = mse(x0_hat, x0) + mse(x1_hat, x1) - optimizer.zero_grad() + opt.zero_grad() loss.backward() - optimizer.step() + opt.step() - train_loss += loss.item() * x0.shape[0] - count += x0.shape[0] + bsz = x0.shape[0] + total_loss += loss.item() * bsz + total += bsz metrics = compute_metrics(model, val_loader, cfg.sigma, device) print( - f"Epoch {epoch:03d} | train_loss={train_loss / count:.6f} | " + f"Epoch {epoch:03d} | train_loss={total_loss / total:.6f} | " f"val_loss={metrics['val_loss']:.6f} | endpoint_mse={metrics['endpoint_mse']:.6f} | " f"endpoint_mae={metrics['endpoint_mae']:.6f} | path_mse={metrics['path_mse']:.6f}" ) -def save_dataset_csv(dataset: BridgePairDataset, path: str): - os.makedirs(os.path.dirname(path) or ".", exist_ok=True) - dim = dataset.x0.shape[1] - data = {} - for i in range(dim): - data[f"x0_{i}"] = dataset.x0[:, i].numpy() - for i in range(dim): - data[f"x1_{i}"] = dataset.x1[:, i].numpy() - pd.DataFrame(data).to_csv(path, index=False) - print(f"Saved synthetic dataset to {path}") - - def main(): - parser = argparse.ArgumentParser(description="Simple diffusion Schrödinger bridge (PyTorch)") - parser.add_argument("--csv", type=str, default="", help="Optional CSV path with x0_*/x1_* columns") - parser.add_argument("--make_synth_csv", type=str, default="", help="If set, save synthetic dataset to this CSV path and exit") - parser.add_argument("--n_samples", type=int, default=5000) - parser.add_argument("--dim", type=int, default=2) - parser.add_argument("--epochs", type=int, default=50) - parser.add_argument("--batch_size", type=int, default=256) + parser = argparse.ArgumentParser(description="Diffusion Schrödinger Bridge with public torchvision datasets") + parser.add_argument("--data_root", type=str, default="./data") + parser.add_argument("--dataset", type=str, default="mnist", choices=["mnist", "fashionmnist", "cifar10"]) + parser.add_argument("--source_class", type=int, default=1, help="x0 endpoint class id") + parser.add_argument("--target_class", type=int, default=7, help="x1 endpoint class id") + parser.add_argument("--max_pairs", type=int, default=5000) + parser.add_argument("--use_test_split", action="store_true", help="Use test split instead of train split") + parser.add_argument("--epochs", type=int, default=20) + parser.add_argument("--batch_size", type=int, default=128) parser.add_argument("--lr", type=float, default=1e-3) - parser.add_argument("--sigma", type=float, default=0.25) + parser.add_argument("--sigma", type=float, default=0.20) parser.add_argument("--val_ratio", type=float, default=0.2) parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() torch.manual_seed(args.seed) - np.random.seed(args.seed) - - if args.csv: - dataset = BridgePairDataset.from_csv(args.csv) - else: - dataset = make_synthetic_bridge_dataset(args.n_samples, args.dim, args.seed) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if args.make_synth_csv: - save_dataset_csv(dataset, args.make_synth_csv) - return + dataset = PublicVisionBridgeDatasetBuilder.build( + root=args.data_root, + dataset_name=args.dataset, + source_class=args.source_class, + target_class=args.target_class, + max_pairs=args.max_pairs, + seed=args.seed, + train_split=not args.use_test_split, + ) n_val = int(len(dataset) * args.val_ratio) n_train = len(dataset) - n_val @@ -246,9 +233,7 @@ def main(): train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True) val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False) - data_dim = dataset[0][0].shape[0] - model = BridgeDenoiser(data_dim).to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) - + model = BridgeDenoiser(data_dim=dataset[0][0].numel()).to(device) cfg = TrainConfig( epochs=args.epochs, batch_size=args.batch_size, @@ -257,9 +242,12 @@ def main(): val_ratio=args.val_ratio, seed=args.seed, ) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") - print(f"Train/Val size: {n_train}/{n_val}") + print( + f"Dataset={args.dataset}, class {args.source_class} -> class {args.target_class}, " + f"pairs={len(dataset)}, train/val={n_train}/{n_val}" + ) train(model, train_loader, val_loader, cfg, device)