Add simple PyTorch diffusion Schrödinger bridge training example#2
Open
Tumb1eweed wants to merge 2 commits intomainfrom
Open
Add simple PyTorch diffusion Schrödinger bridge training example#2Tumb1eweed wants to merge 2 commits intomainfrom
Tumb1eweed wants to merge 2 commits intomainfrom
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
Description
dsb_torch.py,包含配对数据集类BridgePairDataset、合成数据生成器make_synthetic_bridge_dataset、以及 CSV 导入/导出接口。TimeEmbedding与时间条件 MLP 模型BridgeDenoiser,模型输入为(x_t, t),输出为(x0_hat, x1_hat)。sample_noisy_bridge、训练循环train)以及验证指标计算compute_metrics,并在每 epoch 打印val_loss、endpoint_mse、endpoint_mae和path_mse。README.md,补充依赖、数据格式、使用示例命令和指标说明(包含--make_synth_csv用法)。Testing
python -m py_compile dsb_torch.py(通过)。python dsb_torch.py --epochs 2 --n_samples 1000 --batch_size 128)时出现ModuleNotFoundError: No module named 'numpy'(失败,依赖缺失)。pip install numpy pandas --quiet)在当前环境中因网络/代理限制失败(失败,无法从 PyPI 拉取包)。Codex Task