diff --git a/model/UNet.py b/model/UNet.py index 5f200d7..7ffef2f 100644 --- a/model/UNet.py +++ b/model/UNet.py @@ -124,6 +124,40 @@ def sigma2(gamma_x): def alpha(gamma_x): return np.sqrt(1 - sigma2(gamma_x)) +# Sample function +def sample_step(self, rng, i, T, z_t, conditioning, guidance_weight=0.): + # rng_body = jax.random.fold_in(rng, i) + # eps = jax.random.normal(rng_body, z_t.shape) + # returns Generator object that manages state and generates the random bits, + # which are then transformed into random values from useful distributions + rng = np.random.default_rng(i) + eps = rng.standard_normal(z_t.shape) + t = (T - i)/T + s = (T - i - 1) / T + + g_s = self.gamma(s) + g_t = self.gamma(t) + + cond = self.embedding_vectors(conditioning) + + eps_hat_cond = self.score_model( + z_t, + g_t * np.ones((z_t.shape[0],), z_t.dtype), + cond,) + + eps_hat_uncond = self.score_model( + z_t, + g_t * np.ones((z_t.shape[0],), z_t.dtype), + cond * 0.,) + eps_hat = (1. + guidance_weight) * eps_hat_cond - guidance_weight * eps_hat_uncond + + + a = nn.sigmoid(g_s) + b = nn.sigmoid(g_t) + c = -np.expm1(g_t - g_s) + sigma_t = np.sqrt(sigma2(g_t)) + z_s = np.sqrt(a / b) * (z_t - sigma_t * c * eps_hat) + np.sqrt((1. - a) * c) * eps + return z_s class ResNet(nn.Module): def __init__(self, in_ch, out_ch, num_blocks=4, num_layers=4, num_filters=64, kernel_size=3, stride=1, padding=1,