84 changes: 83 additions & 1 deletion README.md
@@ -1 +1,83 @@
# DSB

A simple PyTorch example implementation of a **Diffusion Schrödinger Bridge** (DSB) that uses **public datasets** (`torchvision`) as input instead of generating a synthetic dataset.

What's included:

- Dataset input (MNIST / FashionMNIST / CIFAR10)
- Model implementation (a time-conditioned MLP)
- Model training
- Validation-set evaluation metrics

## 1. Requirements

```bash
pip install torch torchvision
```

## 2. Data construction (public datasets)

The script selects two classes from a public vision dataset and uses them as the two endpoints of the bridge:

- `x0`: samples from the `source_class` class
- `x1`: samples from the `target_class` class

The two sets are then paired at random to form `(x0, x1)` training samples; a minimal sketch of this pairing step follows the list below.

Supported datasets:

- `mnist`
- `fashionmnist`
- `cifar10`
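
For illustration only, the pairing step looks roughly like the sketch below. The concrete values (MNIST, classes `1 -> 7`, 5000 pairs, seed 42) mirror the training example in section 3.1; the actual logic lives in `PublicVisionBridgeDatasetBuilder.build` in `dsb_torch.py`.

```python
import torch
from torchvision import datasets, transforms

# Load a public dataset and keep images as tensors in [0, 1].
ds = datasets.MNIST(root="./data", train=True, download=True, transform=transforms.ToTensor())
x = torch.stack([img for img, _ in ds], dim=0)   # (N, 1, 28, 28)
y = ds.targets                                   # (N,) class labels

# Indices of the two endpoint classes (source_class=1, target_class=7 here).
src_idx = (y == 1).nonzero(as_tuple=False).squeeze(1)
tgt_idx = (y == 7).nonzero(as_tuple=False).squeeze(1)

# Randomly pair up to max_pairs samples from each class.
n_pairs = min(src_idx.numel(), tgt_idx.numel(), 5000)
gen = torch.Generator().manual_seed(42)
x0 = x[src_idx[torch.randperm(src_idx.numel(), generator=gen)[:n_pairs]]].flatten(start_dim=1)
x1 = x[tgt_idx[torch.randperm(tgt_idx.numel(), generator=gen)[:n_pairs]]].flatten(start_dim=1)
# Row i of x0 and row i of x1 together form one (x0, x1) bridge pair.
```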

## 3. Training examples

### 3.1 MNIST: digit 1 -> digit 7

```bash
python dsb_torch.py \
--dataset mnist \
--source_class 1 \
--target_class 7 \
--max_pairs 5000 \
--epochs 20
```

### 3.2 CIFAR10: airplane(0) -> ship(8)

```bash
python dsb_torch.py \
--dataset cifar10 \
--source_class 0 \
--target_class 8 \
--max_pairs 5000 \
--epochs 20
```

## 4. Output metrics (validation set)

Each epoch reports:

- `val_loss`: total validation loss (sum of squared endpoint-reconstruction errors / number of samples)
- `endpoint_mse`: endpoint reconstruction MSE (`x0`/`x1`)
- `endpoint_mae`: endpoint reconstruction MAE (`x0`/`x1`)
- `path_mse`: reconstruction MSE of the bridge mean at intermediate times

Example:

```text
Epoch 005 | train_loss=0.111111 | val_loss=85.432100 | endpoint_mse=0.054321 | endpoint_mae=0.151200 | path_mse=0.020345
```
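
These four numbers differ mainly in normalization. As a rough sketch (the helper and variable names here are illustrative, not part of the script; the real computation is `compute_metrics` in `dsb_torch.py`):

```python
def normalize_metrics(sq_err_sum: float, abs_err_sum: float, path_sq_err_sum: float,
                      n_samples: int, dim: int) -> dict:
    """Turn error sums accumulated over the validation set into the printed metrics."""
    return {
        "val_loss": sq_err_sum / n_samples,                  # summed over all pixels of x0 and x1, per sample
        "endpoint_mse": sq_err_sum / (n_samples * 2 * dim),  # per-pixel average over both endpoints
        "endpoint_mae": abs_err_sum / (n_samples * 2 * dim),
        "path_mse": path_sq_err_sum / (n_samples * dim),     # per-pixel error of the reconstructed bridge mean
    }
```

For MNIST (`dim = 784`), this is why `val_loss` in the example above is roughly `2 * dim ≈ 1568` times larger than `endpoint_mse`.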

## 5. Method (simplified)

- For each paired sample `(x0, x1)`, draw a random time `t ~ U(0,1)`
- Construct a noisy bridge point:

  `x_t = (1-t) * x0 + t * x1 + sigma * sqrt(t * (1-t)) * eps`
- The model takes `(x_t, t)` as input and outputs `(x0_hat, x1_hat)`
- Training objective:

  `MSE(x0_hat, x0) + MSE(x1_hat, x1)`

> Note: this implementation is a minimal teaching version that demonstrates the "public-dataset input + DSB training/validation" workflow.
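
Condensed into code, one training step of the procedure above looks roughly like the following sketch (the helper name `bridge_training_step` is illustrative; the full implementation in `dsb_torch.py` below additionally clamps `t*(1-t)` away from zero for numerical stability):

```python
import torch
import torch.nn.functional as F

def bridge_training_step(model: torch.nn.Module, x0: torch.Tensor, x1: torch.Tensor, sigma: float = 0.2) -> torch.Tensor:
    """One DSB training step: sample a noisy bridge point and reconstruct both endpoints."""
    t = torch.rand(x0.shape[0], 1, device=x0.device)                     # t ~ U(0, 1)
    eps = torch.randn_like(x0)
    x_t = (1 - t) * x0 + t * x1 + sigma * torch.sqrt(t * (1 - t)) * eps  # noisy bridge point
    x0_hat, x1_hat = model(x_t, t)                                       # model predicts both endpoints
    return F.mse_loss(x0_hat, x0) + F.mse_loss(x1_hat, x1)               # training objective
```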
255 changes: 255 additions & 0 deletions dsb_torch.py
@@ -0,0 +1,255 @@
import argparse
import math
from dataclasses import dataclass
from typing import Tuple

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import datasets, transforms


class BridgePairDataset(Dataset):
"""Paired endpoints (x0, x1) for Schrödinger bridge training."""

def __init__(self, x0: torch.Tensor, x1: torch.Tensor):
if x0.shape != x1.shape:
raise ValueError(f"x0 and x1 must have same shape, got {x0.shape} and {x1.shape}")
self.x0 = x0.float()
self.x1 = x1.float()

def __len__(self) -> int:
return self.x0.shape[0]

def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
return self.x0[idx], self.x1[idx]


class PublicVisionBridgeDatasetBuilder:
"""Build paired endpoints from a public torchvision dataset by class-to-class pairing."""

_SUPPORTED = {
"mnist": datasets.MNIST,
"fashionmnist": datasets.FashionMNIST,
"cifar10": datasets.CIFAR10,
}

@classmethod
def build(
cls,
root: str,
dataset_name: str,
source_class: int,
target_class: int,
max_pairs: int,
seed: int,
train_split: bool = True,
) -> BridgePairDataset:
dataset_name = dataset_name.lower()
if dataset_name not in cls._SUPPORTED:
raise ValueError(f"Unsupported dataset '{dataset_name}'. Choose from {list(cls._SUPPORTED.keys())}")

transform = transforms.ToTensor()
dataset_cls = cls._SUPPORTED[dataset_name]
ds = dataset_cls(root=root, train=train_split, download=True, transform=transform)

imgs = []
labels = []
for img, label in ds:
imgs.append(img)
labels.append(int(label))

x = torch.stack(imgs, dim=0)
y = torch.tensor(labels, dtype=torch.long)

src_idx = (y == source_class).nonzero(as_tuple=False).squeeze(1)
tgt_idx = (y == target_class).nonzero(as_tuple=False).squeeze(1)

if src_idx.numel() == 0 or tgt_idx.numel() == 0:
raise ValueError("source_class or target_class has no samples in selected split")

n_pairs = min(src_idx.numel(), tgt_idx.numel(), max_pairs)
gen = torch.Generator().manual_seed(seed)
src_perm = src_idx[torch.randperm(src_idx.numel(), generator=gen)[:n_pairs]]
tgt_perm = tgt_idx[torch.randperm(tgt_idx.numel(), generator=gen)[:n_pairs]]

x0 = x[src_perm].flatten(start_dim=1)
x1 = x[tgt_perm].flatten(start_dim=1)
return BridgePairDataset(x0, x1)


class TimeEmbedding(nn.Module):
    """Sinusoidal embedding of the bridge time t in [0, 1]."""

    def __init__(self, emb_dim: int):
        super().__init__()
        self.emb_dim = emb_dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half = self.emb_dim // 2
        # Geometrically spaced frequencies in [1, 1000].
        freqs = torch.exp(torch.linspace(0.0, math.log(1000.0), half, device=t.device, dtype=t.dtype))
        args = t * freqs.unsqueeze(0)
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        if self.emb_dim % 2 == 1:
            # Pad with one zero column so the output has exactly emb_dim features.
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
        return emb


class BridgeDenoiser(nn.Module):
def __init__(self, data_dim: int, hidden_dim: int = 256, time_dim: int = 64):
super().__init__()
self.time_emb = TimeEmbedding(time_dim)
self.net = nn.Sequential(
nn.Linear(data_dim + time_dim, hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, 2 * data_dim),
)

def forward(self, x_t: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
te = self.time_emb(t)
pred = self.net(torch.cat([x_t, te], dim=-1))
return torch.chunk(pred, chunks=2, dim=-1)


@dataclass
class TrainConfig:
epochs: int = 20
batch_size: int = 128
lr: float = 1e-3
sigma: float = 0.20
val_ratio: float = 0.2
seed: int = 42


def sample_noisy_bridge(x0: torch.Tensor, x1: torch.Tensor, sigma: float) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample a point x_t on the Brownian bridge between x0 and x1 at a random time t ~ U(0, 1)."""
    bsz = x0.shape[0]
    t = torch.rand(bsz, 1, device=x0.device)
    eps = torch.randn_like(x0)
    # Clamp t*(1-t) away from zero so the noise scale stays numerically stable near the endpoints.
    var = torch.clamp(t * (1.0 - t), min=1e-4)
    x_t = (1.0 - t) * x0 + t * x1 + sigma * torch.sqrt(var) * eps
    return x_t, t


def compute_metrics(model: nn.Module, loader: DataLoader, sigma: float, device: torch.device) -> dict:
    """Evaluate endpoint and path reconstruction errors on a held-out loader; returns per-sample / per-pixel averages."""
    model.eval()
mse_sum = nn.MSELoss(reduction="sum")
mae_sum = nn.L1Loss(reduction="sum")
total = 0
endpoint_mse_sum = 0.0
endpoint_mae_sum = 0.0
path_mse_sum = 0.0

with torch.no_grad():
for x0, x1 in loader:
x0 = x0.to(device)
x1 = x1.to(device)
x_t, t = sample_noisy_bridge(x0, x1, sigma)
x0_hat, x1_hat = model(x_t, t)

mu_true = (1.0 - t) * x0 + t * x1
mu_hat = (1.0 - t) * x0_hat + t * x1_hat

bsz = x0.shape[0]
total += bsz
endpoint_mse_sum += (mse_sum(x0_hat, x0) + mse_sum(x1_hat, x1)).item()
endpoint_mae_sum += (mae_sum(x0_hat, x0) + mae_sum(x1_hat, x1)).item()
path_mse_sum += mse_sum(mu_hat, mu_true).item()

dim = loader.dataset[0][0].numel()
return {
"val_loss": endpoint_mse_sum / total,
"endpoint_mse": endpoint_mse_sum / (total * 2 * dim),
"endpoint_mae": endpoint_mae_sum / (total * 2 * dim),
"path_mse": path_mse_sum / (total * dim),
}


def train(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, cfg: TrainConfig, device: torch.device):
opt = torch.optim.Adam(model.parameters(), lr=cfg.lr)
mse = nn.MSELoss()

for epoch in range(1, cfg.epochs + 1):
model.train()
total_loss = 0.0
total = 0

for x0, x1 in train_loader:
x0 = x0.to(device)
x1 = x1.to(device)
x_t, t = sample_noisy_bridge(x0, x1, cfg.sigma)
x0_hat, x1_hat = model(x_t, t)
loss = mse(x0_hat, x0) + mse(x1_hat, x1)

opt.zero_grad()
loss.backward()
opt.step()

bsz = x0.shape[0]
total_loss += loss.item() * bsz
total += bsz

metrics = compute_metrics(model, val_loader, cfg.sigma, device)
print(
f"Epoch {epoch:03d} | train_loss={total_loss / total:.6f} | "
f"val_loss={metrics['val_loss']:.6f} | endpoint_mse={metrics['endpoint_mse']:.6f} | "
f"endpoint_mae={metrics['endpoint_mae']:.6f} | path_mse={metrics['path_mse']:.6f}"
)


def main():
parser = argparse.ArgumentParser(description="Diffusion Schrödinger Bridge with public torchvision datasets")
parser.add_argument("--data_root", type=str, default="./data")
parser.add_argument("--dataset", type=str, default="mnist", choices=["mnist", "fashionmnist", "cifar10"])
parser.add_argument("--source_class", type=int, default=1, help="x0 endpoint class id")
parser.add_argument("--target_class", type=int, default=7, help="x1 endpoint class id")
parser.add_argument("--max_pairs", type=int, default=5000)
parser.add_argument("--use_test_split", action="store_true", help="Use test split instead of train split")
parser.add_argument("--epochs", type=int, default=20)
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--sigma", type=float, default=0.20)
parser.add_argument("--val_ratio", type=float, default=0.2)
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()

torch.manual_seed(args.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = PublicVisionBridgeDatasetBuilder.build(
root=args.data_root,
dataset_name=args.dataset,
source_class=args.source_class,
target_class=args.target_class,
max_pairs=args.max_pairs,
seed=args.seed,
train_split=not args.use_test_split,
)

n_val = int(len(dataset) * args.val_ratio)
n_train = len(dataset) - n_val
train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(args.seed))

train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False)

model = BridgeDenoiser(data_dim=dataset[0][0].numel()).to(device)
cfg = TrainConfig(
epochs=args.epochs,
batch_size=args.batch_size,
lr=args.lr,
sigma=args.sigma,
val_ratio=args.val_ratio,
seed=args.seed,
)

print(f"Using device: {device}")
print(
f"Dataset={args.dataset}, class {args.source_class} -> class {args.target_class}, "
f"pairs={len(dataset)}, train/val={n_train}/{n_val}"
)
train(model, train_loader, val_loader, cfg, device)


if __name__ == "__main__":
main()