From 3e4e3459faf67c74319895103eddf91b57c032fd Mon Sep 17 00:00:00 2001 From: austinroose Date: Wed, 11 Jan 2023 15:53:47 +0100 Subject: [PATCH 1/3] add base function for sampling --- model/UNet.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/model/UNet.py b/model/UNet.py index 5f200d7..6973e3b 100644 --- a/model/UNet.py +++ b/model/UNet.py @@ -124,6 +124,36 @@ def sigma2(gamma_x): def alpha(gamma_x): return np.sqrt(1 - sigma2(gamma_x)) +# Sample function +def sample_step(self, rng, i, T, z_t, conditioning, guidance_weight=0.): + rng_body = jax.random.fold_in(rng, i) + eps = random.normal(rng_body, z_t.shape) + t = (T - i)/T + s = (T - i - 1) / T + + g_s = self.gamma(s) + g_t = self.gamma(t) + + cond = self.embedding_vectors(conditioning) + + eps_hat_cond = self.score_model( + z_t, + g_t * np.ones((z_t.shape[0],), z_t.dtype), + cond,) + + eps_hat_uncond = self.score_model( + z_t, + g_t * np.ones((z_t.shape[0],), z_t.dtype), + cond * 0.,) + eps_hat = (1. + guidance_weight) * eps_hat_cond - guidance_weight * eps_hat_uncond + + + a = nn.sigmoid(g_s) + b = nn.sigmoid(g_t) + c = -np.expm1(g_t - g_s) + sigma_t = np.sqrt(sigma2(g_t)) + z_s = np.sqrt(a / b) * (z_t - sigma_t * c * eps_hat) + np.sqrt((1. - a) * c) * eps + return z_s class ResNet(nn.Module): def __init__(self, in_ch, out_ch, num_blocks=4, num_layers=4, num_filters=64, kernel_size=3, stride=1, padding=1, From f75e7ea1e5ea1eaafdaa26331856937aed204e53 Mon Sep 17 00:00:00 2001 From: austinroose Date: Wed, 11 Jan 2023 16:02:45 +0100 Subject: [PATCH 2/3] imports --- model/UNet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/model/UNet.py b/model/UNet.py index 6973e3b..5aa2695 100644 --- a/model/UNet.py +++ b/model/UNet.py @@ -12,6 +12,7 @@ import torch.nn.functional as F import torchvision import numpy as np +import jax # https://github.com/g2archie/UNet-MRI-Reconstruction @@ -127,7 +128,7 @@ def alpha(gamma_x): # Sample function def sample_step(self, rng, i, T, z_t, conditioning, guidance_weight=0.): rng_body = jax.random.fold_in(rng, i) - eps = random.normal(rng_body, z_t.shape) + eps = jax.random.normal(rng_body, z_t.shape) t = (T - i)/T s = (T - i - 1) / T From cd3671fa1000813696335d7364b1eb3b33b340c0 Mon Sep 17 00:00:00 2001 From: Austin Roose Date: Thu, 12 Jan 2023 18:26:31 +0100 Subject: [PATCH 3/3] use numpy instead of jax to get values from normal distribution --- model/UNet.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/model/UNet.py b/model/UNet.py index 5aa2695..7ffef2f 100644 --- a/model/UNet.py +++ b/model/UNet.py @@ -12,7 +12,6 @@ import torch.nn.functional as F import torchvision import numpy as np -import jax # https://github.com/g2archie/UNet-MRI-Reconstruction @@ -127,8 +126,12 @@ def alpha(gamma_x): # Sample function def sample_step(self, rng, i, T, z_t, conditioning, guidance_weight=0.): - rng_body = jax.random.fold_in(rng, i) - eps = jax.random.normal(rng_body, z_t.shape) + # rng_body = jax.random.fold_in(rng, i) + # eps = jax.random.normal(rng_body, z_t.shape) + # returns Generator object that manages state and generates the random bits, + # which are then transformed into random values from useful distributions + rng = np.random.default_rng(i) + eps = rng.standard_normal(z_t.shape) t = (T - i)/T s = (T - i - 1) / T