diff --git a/examples/text_to_img.py b/examples/text_to_img.py
index 5abaa5f..8ca9542 100644
--- a/examples/text_to_img.py
+++ b/examples/text_to_img.py
@@ -27,10 +27,10 @@ def main():
     create_workdir(args.workdir)
 
     solver_config = munchify({'num_sampling': args.NFE })
-    callback = ComposeCallback(workdir=args.workdir,
-                               frequency=1,
-                               callbacks=["draw_noisy", 'draw_tweedie'])
-    # callback = None
+    # callback = ComposeCallback(workdir=args.workdir,
+    #                            frequency=1,
+    #                            callbacks=["draw_noisy", 'draw_tweedie'])
+    callback = None
 
     if args.model == "sdxl" or args.model == "sdxl_lightning":
diff --git a/latent_diffusion.py b/latent_diffusion.py
index 9f43448..93d02e0 100644
--- a/latent_diffusion.py
+++ b/latent_diffusion.py
@@ -3,7 +3,7 @@
 Forward operators follow DPS and DDRM/DDNM.
 """
 
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, Optional, Tuple
 
 import torch
 from diffusers import DDIMScheduler, StableDiffusionPipeline
@@ -26,6 +26,9 @@ def get_solver(name: str, **kwargs):
     return __SOLVER__[name](**kwargs)
 
 ########################
+# Helper functions
+# taken from comfyui
+########################
 
 def get_ancestral_step(sigma_from, sigma_to, eta=1.):
     """Calculates the noise level (sigma_down) to step down to and the amount
@@ -36,19 +39,11 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
     sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
     return sigma_down, sigma_up
 
-
 def append_zero(x):
     return torch.cat([x, x.new_zeros([1])])
 
-
-def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
-    """Constructs the noise schedule of Karras et al. (2022)."""
-    ramp = torch.linspace(0, 1, n, device=device)
-    min_inv_rho = sigma_min ** (1 / rho)
-    max_inv_rho = sigma_max ** (1 / rho)
-    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
-    return append_zero(sigmas).to(device)
-
+########################
+# Base classes
 ########################
 
 class StableDiffusion():
@@ -57,8 +52,19 @@ def __init__(self,
                  model_key:str="runwayml/stable-diffusion-v1-5",
                  device: Optional[torch.device]=None,
                  **kwargs):
-        self.device = device
+        """
+        The base class of LDM-based solvers for VP sampling.
+        We load pre-trained VAE, text-encoder, and U-Net models from diffusers.
+        We also compute pre-defined coefficients.
+        args:
+            solver_config (Dict): solver configurations (e.g. NFE)
+            model_key (str): model key for loading pre-trained models
+            device (torch.device): device
+            **kwargs: additional arguments
+        """
+        # pre-trained model loading
+        self.device = device
 
         self.dtype = kwargs.get("pipe_dtype", torch.float16)
         pipe = StableDiffusionPipeline.from_pretrained(model_key, torch_dtype=self.dtype).to(device)
         self.vae = pipe.vae
@@ -66,12 +72,10 @@ def __init__(self,
         self.text_encoder = pipe.text_encoder
         self.unet = pipe.unet
 
+        # load scheduler
         self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
-        self.total_alphas = self.scheduler.alphas_cumprod.clone()
-
-        self.sigmas = (1-self.total_alphas).sqrt() / self.total_alphas.sqrt()
-        self.log_sigmas = self.sigmas.log()
-
+
+        # time discretization
         total_timesteps = len(self.scheduler.timesteps)
         self.scheduler.set_timesteps(solver_config.num_sampling, device=device)
         self.skip = total_timesteps // solver_config.num_sampling
@@ -83,6 +87,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
         self.sample(*args, **kwargs)
 
     def sample(self, *args: Any, **kwargs: Any) -> Any:
+        """
+        The method that distinguishes each solver.
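+
+        A minimal usage sketch (hypothetical values, mirroring
+        examples/text_to_img.py):
+            solver = get_solver("ddim", solver_config=munchify({"num_sampling": 50}))
+            img = solver.sample(cfg_guidance=7.5, prompt=["", "a photo of a cat"])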
+        """
         raise NotImplementedError("Solver must implement sample() method.")
 
     def alpha(self, t):
@@ -90,7 +97,7 @@
         return at
 
     @torch.no_grad()
-    def get_text_embed(self, null_prompt, prompt):
+    def get_text_embed(self, null_prompt: str, prompt: str):
         """
         Get text embedding.
         args:
@@ -114,15 +121,19 @@
         return null_text_embed, text_embed
 
-    def encode(self, x):
+    def encode(self, x: torch.Tensor):
         """
-        xt -> zt
+        Encode image to latent features.
+        args:
+            x (torch.Tensor): image
         """
         return self.vae.encode(x).latent_dist.sample() * 0.18215
 
-    def decode(self, zt):
+    def decode(self, zt: torch.Tensor):
         """
-        zt -> xt
+        Decode latent features to image.
+        args:
+            zt (torch.Tensor): latent
         """
         zt = 1/0.18215 * zt
         img = self.vae.decode(zt).sample.float()
@@ -184,7 +195,16 @@ def inversion(self,
 
     def initialize_latent(self,
                           method: str='random',
                           src_img: Optional[torch.Tensor]=None,
+                          latent_size: tuple=(1, 4, 64, 64),
                           **kwargs):
+        """
+        Initialize latent features.
+        Either sample from a Gaussian distribution or run DDIM inversion.
+        args:
+            method (str): initialization method
+            src_img (torch.Tensor): source image
+            latent_size (tuple): shape of the latent sample
+            **kwargs: additional arguments
+        """
         if method == 'ddim':
             z = self.inversion(self.encode(src_img.to(self.dtype).to(self.device)),
                                kwargs.get('uc'),
@@ -196,43 +216,104 @@ def initialize_latent(self,
                                kwargs.get('c'),
                                cfg_guidance=1.0)
         elif method == 'random':
-            size = kwargs.get('latent_dim', (1, 4, 64, 64))
-            z = torch.randn(size).to(self.device)
+            z = torch.randn(latent_size).to(self.device)
+
         elif method == 'random_kdiffusion':
-            size = kwargs.get('latent_dim', (1, 4, 64, 64))
             sigmas = kwargs.get('sigmas', [14.6146])
-            z = torch.randn(size).to(self.device)
+            z = torch.randn(latent_size).to(self.device)
             z = z * (sigmas[0] ** 2 + 1) ** 0.5
         else:
             raise NotImplementedError
 
         return z.requires_grad_()
-
-    def timestep(self, sigma):
+
+    def calculate_denoised(self, x: torch.Tensor, model_pred: torch.Tensor, alpha: torch.FloatTensor):
+        """
+        Compute Tweedie's formula in VP sampling.
+        args:
+            x (torch.Tensor): noisy sample
+            model_pred (torch.Tensor): estimated noise
+            alpha (torch.FloatTensor): alpha
+        """
+        return (x - (1-alpha).sqrt() * model_pred) / alpha.sqrt()
+
+class Kdiffusion(StableDiffusion):
+    def __init__(self,
+                 solver_config: Dict,
+                 model_key:str="runwayml/stable-diffusion-v1-5",
+                 device: Optional[torch.device]=None,
+                 **kwargs):
+        """
+        Base class of LDM-based solvers built on Karras et al. (2022).
+        Contains methods for leveraging a VP diffusion model for VE sampling,
+        for solvers like DPM and DPM++.
+
+        args:
+            solver_config (Dict): solver configurations (e.g. NFE)
+            model_key (str): model key for loading pre-trained models
+            device (torch.device): device
+            **kwargs: additional arguments
+        """
+        super().__init__(solver_config, model_key, device, **kwargs)
+
+        # load scheduler once again, not saved to self
+        scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
+        # convert alphas to sigmas (VP -> VE)
+        total_sigmas = (1-scheduler.alphas_cumprod).sqrt() / scheduler.alphas_cumprod.sqrt()
+        self.log_sigmas = total_sigmas.log()
+        self.sigma_min, self.sigma_max = total_sigmas.min(), total_sigmas.max()
+
+        # get karras sigmas
+        self.k_sigmas = self.get_sigmas_karras(len(self.scheduler.timesteps), self.sigma_min, self.sigma_max)
+
+    def get_sigmas_karras(self, n: int, sigma_min: float, sigma_max: float, rho: float=7., device: str='cpu'):
+        """Constructs the noise schedule of Karras et al. (2022)."""
+        ramp = torch.linspace(0, 1, n, device=device)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return append_zero(sigmas).to(device)
+
+    def sigma_to_t(self, sigma: torch.FloatTensor):
+        """Convert sigma to timestep. (find the closest index)"""
         log_sigma = sigma.log()
         dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
-        return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)
+        return dists.abs().argmin(dim=0).view(sigma.shape)
 
-    def to_d(self, x, sigma, denoised):
-        '''converts a denoiser output to a Karras ODE derivative'''
+    def to_d(self, x: torch.Tensor, sigma: torch.FloatTensor, denoised: torch.Tensor):
+        '''
+        converts a denoiser output to a Karras ODE derivative
+        args:
+            x (torch.Tensor): noisy sample
+            sigma (torch.FloatTensor): noise level
+            denoised (torch.Tensor): denoised sample
+        '''
         return (x - denoised) / sigma.item()
 
-    def get_ancestral_step(self, sigma_from, sigma_to, eta=1.):
-        """Calculates the noise level (sigma_down) to step down to and the amount
-        of noise to add (sigma_up) when doing an ancestral sampling step."""
-        if not eta:
-            return sigma_to, 0.
-        sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
-        sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
-        return sigma_down, sigma_up
-
-    def calculate_input(self, x, sigma):
+    def calculate_input(self, x: torch.Tensor, sigma: float):
         return x / (sigma ** 2 + 1) ** 0.5
 
-    def calculate_denoised(self, x, model_pred, sigma):
+    def calculate_denoised(self, x: torch.Tensor, model_pred: torch.Tensor, sigma: torch.FloatTensor):
+        """
+        Compute Tweedie's formula in VE sampling.
+        args:
+            x (torch.Tensor): noisy sample
+            model_pred (torch.Tensor): estimated noise
+            sigma (torch.FloatTensor): noise level
+        """
         return x - model_pred * sigma
 
-    def kdiffusion_x_to_denoised(self, x, sigma, uc, c, cfg_guidance, t):
+    def x_to_denoised(self, x, sigma, uc, c, cfg_guidance, t):
+        """
+        Compute denoised estimates (guided and unconditional) from the noisy sample.
+        args:
+            x (torch.Tensor): noisy sample
+            sigma (float): noise level
+            uc (torch.Tensor): null-text embedding
+            c (torch.Tensor): text embedding
+            cfg_guidance (float): guidance scale
+            t (torch.Tensor): timestep
+        """
         xc = self.calculate_input(x, sigma)
         noise_uc, noise_c = self.predict_noise(xc, t, uc, c)
         noise_pred = noise_uc + cfg_guidance * (noise_c - noise_uc)
@@ -240,15 +321,16 @@
         uncond_denoised = self.calculate_denoised(x, noise_uc, sigma)
         return denoised, uncond_denoised
 
+
 ###########################################
-# Base version
+# VP version samplers
 ###########################################
 
 @register_solver("ddim")
-class BaseDDIM(StableDiffusion):
+class DDIM(StableDiffusion):
     """
     Basic DDIM solver for SD.
-    Useful for text-to-image generation
+    VP sampling.
     """
 
     @torch.autocast(device_type='cuda', dtype=torch.float16)
@@ -257,10 +339,6 @@ def sample(self,
                prompt=["",""],
                callback_fn=None,
                **kwargs):
-        """
-        Main function that defines each solver.
-        This will generate samples without considering measurements.
- """ # Text embedding uc, c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[1]) @@ -280,7 +358,7 @@ def sample(self, noise_pred = noise_uc + cfg_guidance * (noise_c - noise_uc) # tweedie - z0t = (zt - (1-at).sqrt() * noise_pred) / at.sqrt() + z0t = self.calculate_denoised(zt, noise_pred, at) # add noise zt = at_prev.sqrt() * z0t + (1-at_prev).sqrt() * noise_pred @@ -298,9 +376,121 @@ def sample(self, img = (img / 2 + 0.5).clamp(0, 1) return img.detach().cpu() +@register_solver("ddim_inversion") +class InversionDDIM(DDIM): + """ + Reconstruction after inversion. + Not for T2I generation. + """ + @torch.autocast(device_type='cuda', dtype=torch.float16) + def sample(self, + src_img: torch.Tensor, + cfg_guidance: float =7.5, + prompt: Tuple[str]=["",""], + callback_fn: Optional[Callable]=None, + **kwargs): + + # Text embedding + uc, c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[1]) + + # Initialize zT + zt = self.initialize_latent(method='ddim', + src_img=src_img, + uc=uc, + c=c, + cfg_guidance=cfg_guidance) + zt = zt.requires_grad_() + + # Sampling + pbar = tqdm(self.scheduler.timesteps, desc="SD") + for step, t in enumerate(pbar): + at = self.alpha(t) + at_prev = self.alpha(t - self.skip) + + with torch.no_grad(): + noise_uc, noise_c = self.predict_noise(zt, t, uc, c) + noise_pred = noise_uc + cfg_guidance * (noise_c - noise_uc) + + # tweedie + z0t = self.calculate_denoised(zt, noise_pred, at) + + # add noise + zt = at_prev.sqrt() * z0t + (1-at_prev).sqrt() * noise_pred + + if callback_fn is not None: + callback_kwargs = {'z0t': z0t.detach(), + 'zt': zt.detach(), + 'decode': self.decode} + callback_kwargs = callback_fn(step, t, callback_kwargs) + z0t = callback_kwargs["z0t"] + zt = callback_kwargs["zt"] + + # for the last step, do not add noise + img = self.decode(z0t) + img = (img / 2 + 0.5).clamp(0, 1) + return img.detach().cpu() + + +@register_solver("ddim_edit") +class EditWordSwapDDIM(InversionDDIM): + """ + Editing via WordSwap after inversion. + Useful for text-guided image editing. + Not for T2I generation. 
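+    Expects prompt = [null_prompt, source_prompt, target_prompt]; a hypothetical
+    example: ["", "a photo of a cat", "a photo of a dog"].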
+ """ + @torch.autocast(device_type='cuda', dtype=torch.float16) + def sample(self, + src_img: torch.Tensor, + cfg_guidance: float=7.5, + prompt: Tuple[str]=["","",""], + callback_fn: Optional[Callable]=None, + **kwargs): + + # Text embedding + uc, src_c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[1]) + _, tgt_c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[2]) + + # Initialize zT + zt = self.initialize_latent(method='ddim', + src_img=src_img, + uc=uc, + c=src_c, + cfg_guidance=cfg_guidance) + # Sampling + pbar = tqdm(self.scheduler.timesteps, desc="DDIM-edit") + for step, t in enumerate(pbar): + at = self.alpha(t) + at_prev = self.alpha(t - self.skip) + + with torch.no_grad(): + noise_uc, noise_c = self.predict_noise(zt, t, uc, tgt_c) + noise_pred = noise_uc + cfg_guidance * (noise_c - noise_uc) + + # tweedie + z0t = self.calculate_denoised(zt, noise_pred, at) + + # add noise + zt = at_prev.sqrt() * z0t + (1-at_prev).sqrt() * noise_pred + + if callback_fn is not None: + callback_kwargs = {'z0t': z0t.detach(), + 'zt': zt.detach(), + 'decode': self.decode} + callback_kwargs = callback_fn(step, t, callback_kwargs) + z0t = callback_kwargs["z0t"] + zt = callback_kwargs["zt"] + + # for the last step, do not add noise + img = self.decode(z0t) + img = (img / 2 + 0.5).clamp(0, 1) + return img.detach().cpu() +########################################### +# VE version samplers (K-diffusion) +########################################### + @register_solver("euler") -class EulerCFGSolver(StableDiffusion): +class EulerCFGSolver(Kdiffusion): """ Karras Euler (VE casted) """ @@ -309,36 +499,31 @@ def sample(self, cfg_guidance, prompt=["", ""], callback_fn=None, **kwargs): # Text embedding uc, c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[1]) - # perpare alphas and sigmas - timesteps = reversed(torch.linspace(0, 1000, len(self.scheduler.timesteps)+1).long()) - # convert to karras sigma scheduler - total_sigmas = (1-self.total_alphas).sqrt() / self.total_alphas.sqrt() - sigmas = get_sigmas_karras(len(self.scheduler.timesteps), total_sigmas.min(), total_sigmas.max(), rho=7.) 
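+        # The Karras sigma schedule is precomputed as self.k_sigmas in
+        # Kdiffusion.__init__; each step below applies the Euler update
+        # x <- denoised + d * sigma_next with d = (x - denoised) / sigma.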
# initialize x = self.initialize_latent(method="random_kdiffusion", latent_dim=(1, 4, 64, 64), - sigmas=sigmas).to(torch.float16) + sigmas=self.k_sigmas).to(torch.float16) # Sampling - pbar = tqdm(self.scheduler.timesteps, desc="SD") + pbar = tqdm(self.scheduler.timesteps, desc="Euler") for i, _ in enumerate(pbar): - sigma = sigmas[i] - t = self.timestep(sigma).to(self.device) + sigma = self.k_sigmas[i] + t = self.sigma_to_t(sigma).to(self.device) with torch.no_grad(): - denoised, _ = self.kdiffusion_x_to_denoised(x, sigma, uc, c, cfg_guidance, t) + denoised, _ = self.x_to_denoised(x, sigma, uc, c, cfg_guidance, t) d = self.to_d(x, sigma, denoised) # Euler method - x = denoised + d * sigmas[i+1] + x = denoised + d * self.k_sigmas[i+1] if callback_fn is not None: callback_kwargs = {'z0t': denoised.detach(), 'zt': x.detach(), 'decode': self.decode} callback_kwargs = callback_fn(i, t, callback_kwargs) - z0t = callback_kwargs["z0t"] - zt = callback_kwargs["zt"] + denoised = callback_kwargs["z0t"] + x = callback_kwargs["zt"] # for the last step, do not add noise img = self.decode(denoised) @@ -347,7 +532,7 @@ def sample(self, cfg_guidance, prompt=["", ""], callback_fn=None, **kwargs): @register_solver("euler_a") -class EulerAncestralCFGSolver(StableDiffusion): +class EulerAncestralCFGSolver(Kdiffusion): """ Karras Euler (VE casted) + Ancestral sampling """ @@ -355,27 +540,24 @@ class EulerAncestralCFGSolver(StableDiffusion): def sample(self, cfg_guidance, prompt=["", ""], callback_fn=None, **kwargs): # Text embedding uc, c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[1]) - # convert to karras sigma scheduler - total_sigmas = (1-self.total_alphas).sqrt() / self.total_alphas.sqrt() - sigmas = get_sigmas_karras(len(self.scheduler.timesteps), total_sigmas.min(), total_sigmas.max(), rho=7.) 
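+        # Ancestral sampling: each step moves deterministically down to
+        # sigma_down, then injects fresh noise of scale sigma_up
+        # (see get_ancestral_step above).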
# initialize x = self.initialize_latent(method="random_kdiffusion", latent_dim=(1, 4, 64, 64), - sigmas=sigmas).to(torch.float16) + sigmas=self.k_sigmas).to(torch.float16) # Sampling - pbar = tqdm(self.scheduler.timesteps, desc="SD") + pbar = tqdm(self.scheduler.timesteps, desc="Euler_a") for i, _ in enumerate(pbar): - sigma = sigmas[i] - t = self.timestep(sigma).to(self.device) - sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) + sigma = self.k_sigmas[i] + t = self.sigma_to_t(sigma).to(self.device) + sigma_down, sigma_up = get_ancestral_step(self.k_sigmas[i], self.k_sigmas[i + 1]) with torch.no_grad(): - denoised, _ = self.kdiffusion_x_to_denoised(x, sigma, uc, c, cfg_guidance, t) + denoised, _ = self.x_to_denoised(x, sigma, uc, c, cfg_guidance, t) # Euler method d = self.to_d(x, sigma, denoised) x = denoised + d * sigma_down - if sigmas[i + 1] > 0: + if self.k_sigmas[i + 1] > 0: x = x + torch.randn_like(x) * sigma_up if callback_fn is not None: @@ -383,45 +565,45 @@ def sample(self, cfg_guidance, prompt=["", ""], callback_fn=None, **kwargs): 'zt': x.detach(), 'decode': self.decode} callback_kwargs = callback_fn(i, t, callback_kwargs) + denoised = callback_kwargs["z0t"] + x = callback_kwargs["zt"] # for the last step, do not add noise - img = self.decode(denoised) + img = self.decode(x) img = (img / 2 + 0.5).clamp(0, 1) return img.detach().cpu() @register_solver("dpm++_2s_a") -class DPMpp2sAncestralCFGSolver(StableDiffusion): +class DPMpp2sAncestralCFGSolver(Kdiffusion): @torch.autocast(device_type='cuda', dtype=torch.float16) def sample(self, cfg_guidance, prompt=["", ""], callback_fn=None, **kwargs): t_fn = lambda sigma: sigma.log().neg() sigma_fn = lambda t: t.neg().exp() + # Text embedding uc, c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[1]) - # convert to karras sigma scheduler - total_sigmas = (1-self.total_alphas).sqrt() / self.total_alphas.sqrt() - sigmas = get_sigmas_karras(len(self.scheduler.timesteps), total_sigmas.min(), total_sigmas.max(), rho=7.) 
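+        # DPM-Solver++(2S): a second-order single-step update in t = -log(sigma),
+        # with one extra midpoint model evaluation at s = t + h/2 (r = 1/2).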
# initialize x = self.initialize_latent(method="random_kdiffusion", latent_dim=(1, 4, 64, 64), - sigmas=sigmas).to(torch.float16) + sigmas=self.k_sigmas).to(torch.float16) # Sampling - pbar = tqdm(self.scheduler.timesteps, desc="SD") + pbar = tqdm(self.scheduler.timesteps, desc="DPM++2s_a") for i, _ in enumerate(pbar): - sigma = sigmas[i] - new_t = self.timestep(sigma).to(self.device) + sigma = self.k_sigmas[i] + t_1 = self.sigma_to_t(sigma).to(self.device) with torch.no_grad(): - denoised, _ = self.kdiffusion_x_to_denoised(x, sigma, uc, c, cfg_guidance, new_t) + denoised, _ = self.x_to_denoised(x, sigma, uc, c, cfg_guidance, t_1) - sigma_down, sigma_up = self.get_ancestral_step(sigmas[i], sigmas[i + 1]) + sigma_down, sigma_up = get_ancestral_step(self.k_sigmas[i], self.k_sigmas[i + 1]) if sigma_down == 0: # Euler method - d = self.to_d(x, sigmas[i], denoised) + d = self.to_d(x, self.k_sigmas[i], denoised) x = denoised + d * sigma_down else: # DPM-Solver++(2S) - t, t_next = t_fn(sigmas[i]), t_fn(sigma_down) + t, t_next = t_fn(self.k_sigmas[i]), t_fn(sigma_down) r = 1 / 2 h = t_next - t s = t + r * h @@ -429,19 +611,20 @@ def sample(self, cfg_guidance, prompt=["", ""], callback_fn=None, **kwargs): with torch.no_grad(): sigma_s = sigma_fn(s) - t_2 = self.timestep(sigma_s).to(self.device) - denoised_2, _ = self.kdiffusion_x_to_denoised(x_2, sigma_s, uc, c, cfg_guidance, t_2) + t_2 = self.sigma_to_t(sigma_s).to(self.device) + denoised_2, _ = self.x_to_denoised(x_2, sigma_s, uc, c, cfg_guidance, t_2) - x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2 + x = denoised_2 - torch.exp(-h) * denoised_2 + (sigma_fn(t_next) / sigma_fn(t)) * x + # Noise addition - if sigmas[i + 1] > 0: + if self.k_sigmas[i + 1] > 0: x = x + torch.randn_like(x) * sigma_up if callback_fn is not None: callback_kwargs = { 'z0t': denoised.detach(), 'zt': x.detach(), 'decode': self.decode} - callback_kwargs = callback_fn(i, new_t, callback_kwargs) + callback_kwargs = callback_fn(i, t_1, callback_kwargs) denoised = callback_kwargs["z0t"] x = callback_kwargs["zt"] @@ -452,37 +635,33 @@ def sample(self, cfg_guidance, prompt=["", ""], callback_fn=None, **kwargs): @register_solver("dpm++_2m") -class DPMpp2mCFGSolver(StableDiffusion): +class DPMpp2mCFGSolver(Kdiffusion): @torch.autocast(device_type='cuda', dtype=torch.float16) def sample(self, cfg_guidance, prompt=["", ""], callback_fn=None, **kwargs): t_fn = lambda sigma: sigma.log().neg() - sigma_fn = lambda t: t.neg().exp() # Text embedding uc, c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[1]) - # convert to karras sigma scheduler - total_sigmas = (1-self.total_alphas).sqrt() / self.total_alphas.sqrt() - sigmas = get_sigmas_karras(len(self.scheduler.timesteps), total_sigmas.min(), total_sigmas.max(), rho=7.) 
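+        # DPM-Solver++(2M): a second-order multistep update that reuses the
+        # previous denoised estimate (old_denoised) instead of a midpoint call.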
# initialize x = self.initialize_latent(method="random_kdiffusion", latent_dim=(1, 4, 64, 64), - sigmas=sigmas).to(torch.float16) + sigmas=self.k_sigmas).to(torch.float16) old_denoised = None # buffer # Sampling - pbar = tqdm(self.scheduler.timesteps, desc="SD") + pbar = tqdm(self.scheduler.timesteps, desc="DPM++_2m") for i, _ in enumerate(pbar): - sigma = sigmas[i] - new_t = self.timestep(sigma).to(self.device) + sigma = self.k_sigmas[i] + t1 = self.sigma_to_t(sigma).to(self.device) with torch.no_grad(): - denoised, _ = self.kdiffusion_x_to_denoised(x, sigma, uc, c, cfg_guidance, new_t) + denoised, _ = self.x_to_denoised(x, sigma, uc, c, cfg_guidance, t1) # solve ODE one step - t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i+1]) + t, t_next = t_fn(self.k_sigmas[i]), t_fn(self.k_sigmas[i+1]) h = t_next - t - if old_denoised is None or sigmas[i+1] == 0: - x = denoised + self.to_d(x, sigmas[i], denoised) * sigmas[i+1] + if old_denoised is None or self.k_sigmas[i+1] == 0: + x = denoised + self.to_d(x, self.k_sigmas[i], denoised) * self.k_sigmas[i+1] else: - h_last = t - t_fn(sigmas[i-1]) + h_last = t - t_fn(self.k_sigmas[i-1]) r = h_last / h extra1 = -torch.exp(-h) * denoised - (-h).expm1() * (denoised - old_denoised) / (2*r) extra2 = torch.exp(-h) * x @@ -493,7 +672,7 @@ def sample(self, cfg_guidance, prompt=["", ""], callback_fn=None, **kwargs): callback_kwargs = { 'z0t': denoised.detach(), 'zt': x.detach(), 'decode': self.decode} - callback_kwargs = callback_fn(i, new_t, callback_kwargs) + callback_kwargs = callback_fn(i, t1, callback_kwargs) denoised = callback_kwargs["z0t"] x = callback_kwargs["zt"] @@ -503,17 +682,19 @@ def sample(self, cfg_guidance, prompt=["", ""], callback_fn=None, **kwargs): return img.detach().cpu() -@register_solver("ddim_inversion") -class InversionDDIM(BaseDDIM): +########################################### +# VP version samplers with CFG++ +########################################### + +@register_solver("ddim_cfg++") +class DDIMCFGpp(StableDiffusion): """ - Editing via WardSwap after inversion. - Useful for text-guided image editing. + DDIM solver for SD with CFG++. """ @torch.autocast(device_type='cuda', dtype=torch.float16) def sample(self, - src_img, cfg_guidance=7.5, - prompt=["","",""], + prompt=["",""], callback_fn=None, **kwargs): @@ -521,11 +702,7 @@ def sample(self, uc, c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[1]) # Initialize zT - zt = self.initialize_latent(method='ddim', - src_img=src_img, - uc=uc, - c=c, - cfg_guidance=cfg_guidance) + zt = self.initialize_latent() zt = zt.requires_grad_() # Sampling @@ -539,10 +716,10 @@ def sample(self, noise_pred = noise_uc + cfg_guidance * (noise_c - noise_uc) # tweedie - z0t = (zt - (1-at).sqrt() * noise_pred) / at.sqrt() + z0t = self.calculate_denoised(zt, noise_pred, at) # add noise - zt = at_prev.sqrt() * z0t + (1-at_prev).sqrt() * noise_pred + zt = at_prev.sqrt() * z0t + (1-at_prev).sqrt() * noise_uc if callback_fn is not None: callback_kwargs = {'z0t': z0t.detach(), @@ -556,47 +733,69 @@ def sample(self, img = self.decode(z0t) img = (img / 2 + 0.5).clamp(0, 1) return img.detach().cpu() - - -@register_solver("ddim_edit") -class EditWordSwapDDIM(InversionDDIM): + +@register_solver("ddim_inversion_cfg++") +class InversionDDIMCFGpp(DDIMCFGpp): """ - Editing via WordSwap after inversion. - Useful for text-guided image editing. + Reconstruction after inversion. + Not for T2I generation. 
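+    Note: the inversion loop and the re-noising step both use the unconditional
+    noise estimate (noise_uc), following CFG++.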
""" + @torch.no_grad() + def inversion(self, + z0: torch.Tensor, + uc: torch.Tensor, + c: torch.Tensor, + cfg_guidance: float=1.0): + + # initialize z_0 + zt = z0.clone().to(self.device) + + # loop + pbar = tqdm(reversed(self.scheduler.timesteps), desc='DDIM Inversion') + for _, t in enumerate(pbar): + at = self.alpha(t) + at_prev = self.alpha(t-self.skip) + + noise_uc, noise_c = self.predict_noise(zt, t, uc, c) + noise_pred = noise_uc + cfg_guidance * (noise_c - noise_uc) + + z0t = (zt - (1-at_prev).sqrt() * noise_uc) / at_prev.sqrt() + zt = at.sqrt() * z0t + (1-at).sqrt() * noise_pred + + return zt + @torch.autocast(device_type='cuda', dtype=torch.float16) def sample(self, - src_img, - cfg_guidance=7.5, - prompt=["","",""], - callback_fn=None, + src_img: torch.Tensor, + cfg_guidance: float =7.5, + prompt: Tuple[str]=["",""], + callback_fn: Optional[Callable]=None, **kwargs): - # Text embedding - uc, src_c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[1]) - _, tgt_c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[2]) + uc, c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[1]) # Initialize zT zt = self.initialize_latent(method='ddim', src_img=src_img, uc=uc, - c=src_c, + c=c, cfg_guidance=cfg_guidance) + # Sampling - pbar = tqdm(self.scheduler.timesteps, desc="DDIM-edit") + pbar = tqdm(self.scheduler.timesteps, desc="SD") for step, t in enumerate(pbar): at = self.alpha(t) at_prev = self.alpha(t - self.skip) with torch.no_grad(): - noise_uc, noise_c = self.predict_noise(zt, t, uc, tgt_c) + noise_uc, noise_c = self.predict_noise(zt, t, uc, c) noise_pred = noise_uc + cfg_guidance * (noise_c - noise_uc) # tweedie - z0t = (zt - (1-at).sqrt() * noise_pred) / at.sqrt() + z0t = self.calculate_denoised(zt, noise_pred, at) # add noise - zt = at_prev.sqrt() * z0t + (1-at_prev).sqrt() * noise_pred + zt = at_prev.sqrt() * z0t + (1-at_prev).sqrt() * noise_uc if callback_fn is not None: callback_kwargs = {'z0t': z0t.detach(), @@ -611,56 +810,43 @@ def sample(self, img = (img / 2 + 0.5).clamp(0, 1) return img.detach().cpu() - - - -########################################### -# CFG++ version -########################################### - -@register_solver("ddim_cfg++") -class BaseDDIMCFGpp(StableDiffusion): +@register_solver("ddim_edit_cfg++") +class EditWordSwapDDIMCFGpp(InversionDDIMCFGpp): """ - DDIM solver for SD with CFG++. - Useful for text-to-image generation + Editing via WordSwap after inversion. + Useful for text-guided image editing. + Not for T2I generation. """ - def __init__(self, - solver_config: Dict, - model_key:str="runwayml/stable-diffusion-v1-5", - device: Optional[torch.device]=None, - **kwargs): - super().__init__(solver_config, model_key, device, **kwargs) - @torch.autocast(device_type='cuda', dtype=torch.float16) def sample(self, - cfg_guidance=7.5, - prompt=["",""], - callback_fn=None, + src_img: torch.Tensor, + cfg_guidance: float=7.5, + prompt: Tuple[str]=["","",""], + callback_fn: Optional[Callable]=None, **kwargs): - """ - Main function that defines each solver. - This will generate samples without considering measurements. 
- """ # Text embedding - uc, c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[1]) + uc, src_c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[1]) + _, tgt_c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[2]) # Initialize zT - zt = self.initialize_latent() - zt = zt.requires_grad_() - + zt = self.initialize_latent(method='ddim', + src_img=src_img, + uc=uc, + c=src_c, + cfg_guidance=cfg_guidance) # Sampling - pbar = tqdm(self.scheduler.timesteps, desc="SD") + pbar = tqdm(self.scheduler.timesteps, desc="DDIM-edit") for step, t in enumerate(pbar): at = self.alpha(t) at_prev = self.alpha(t - self.skip) with torch.no_grad(): - noise_uc, noise_c = self.predict_noise(zt, t, uc, c) + noise_uc, noise_c = self.predict_noise(zt, t, uc, tgt_c) noise_pred = noise_uc + cfg_guidance * (noise_c - noise_uc) # tweedie - z0t = (zt - (1-at).sqrt() * noise_pred) / at.sqrt() + z0t = self.calculate_denoised(zt, noise_pred, at) # add noise zt = at_prev.sqrt() * z0t + (1-at_prev).sqrt() * noise_uc @@ -677,54 +863,52 @@ def sample(self, img = self.decode(z0t) img = (img / 2 + 0.5).clamp(0, 1) return img.detach().cpu() - - + +############################################### +# VE version samplers (K-diffusion) with CFG++ +############################################### + @register_solver("euler_cfg++") -class EulerCFGppSolver(StableDiffusion): +class EulerCFGppSolver(Kdiffusion): @torch.autocast(device_type='cuda', dtype=torch.float16) def sample(self, cfg_guidance, prompt=["", ""], callback_fn=None, **kwargs): # Text embedding uc, c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[1]) - # perpare alphas and sigmas - timesteps = reversed(torch.linspace(0, 1000, len(self.scheduler.timesteps)+1).long()) - # convert to karras sigma scheduler - total_sigmas = (1-self.total_alphas).sqrt() / self.total_alphas.sqrt() - sigmas = get_sigmas_karras(len(self.scheduler.timesteps), total_sigmas.min(), total_sigmas.max(), rho=7.) 
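+        # CFG++: the ODE derivative d is taken from uncond_denoised, while the
+        # guided estimate anchors the update x <- denoised + d * sigma_next.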
# initialize x = self.initialize_latent(method="random_kdiffusion", latent_dim=(1, 4, 64, 64), - sigmas=sigmas).to(torch.float16) + sigmas=self.k_sigmas).to(torch.float16) # Sampling - pbar = tqdm(self.scheduler.timesteps, desc="SD") + pbar = tqdm(self.scheduler.timesteps, desc="Euler_cpp") for i, _ in enumerate(pbar): - sigma = sigmas[i] - t = self.timestep(sigma).to(self.device) + sigma = self.k_sigmas[i] + t = self.sigma_to_t(sigma).to(self.device) with torch.no_grad(): - denoised, uncond_denoised = self.kdiffusion_x_to_denoised(x, sigma, uc, c, cfg_guidance, t) + denoised, uncond_denoised = self.x_to_denoised(x, sigma, uc, c, cfg_guidance, t) d = self.to_d(x, sigma, uncond_denoised) # Euler method - x = denoised + d * sigmas[i+1] + x = denoised + d * self.k_sigmas[i+1] if callback_fn is not None: callback_kwargs = {'z0t': denoised.detach(), 'zt': x.detach(), 'decode': self.decode} callback_kwargs = callback_fn(i, t, callback_kwargs) - z0t = callback_kwargs["z0t"] - zt = callback_kwargs["zt"] + denoised = callback_kwargs["z0t"] + x = callback_kwargs["zt"] # for the last step, do not add noise - img = self.decode(denoised) + img = self.decode(x) img = (img / 2 + 0.5).clamp(0, 1) return img.detach().cpu() @register_solver("euler_a_cfg++") -class EulerAncestralCFGppSolver(StableDiffusion): +class EulerAncestralCFGppSolver(Kdiffusion): """ Karras Euler (VE casted) + Ancestral sampling """ @@ -732,26 +916,23 @@ class EulerAncestralCFGppSolver(StableDiffusion): def sample(self, cfg_guidance, prompt=["", ""], callback_fn=None, **kwargs): # Text embedding uc, c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[1]) - # convert to karras sigma scheduler - total_sigmas = (1-self.total_alphas).sqrt() / self.total_alphas.sqrt() - sigmas = get_sigmas_karras(len(self.scheduler.timesteps), total_sigmas.min(), total_sigmas.max(), rho=7.) 
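+        # CFG++ ancestral: d again comes from uncond_denoised; the sigma_down /
+        # sigma_up noise split is unchanged.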
# initialize x = self.initialize_latent(method="random_kdiffusion", latent_dim=(1, 4, 64, 64), - sigmas=sigmas).to(torch.float16) + sigmas=self.k_sigmas).to(torch.float16) # Sampling - pbar = tqdm(self.scheduler.timesteps, desc="SD") + pbar = tqdm(self.scheduler.timesteps, desc="Euler_a_cpp") for i, _ in enumerate(pbar): - sigma = sigmas[i] - t = self.timestep(sigma).to(self.device) - sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) + sigma = self.k_sigmas[i] + t = self.sigma_to_t(sigma).to(self.device) + sigma_down, sigma_up = get_ancestral_step(self.k_sigmas[i], self.k_sigmas[i + 1]) with torch.no_grad(): - denoised, uncond_denoised = self.kdiffusion_x_to_denoised(x, sigma, uc, c, cfg_guidance, t) + denoised, uncond_denoised = self.x_to_denoised(x, sigma, uc, c, cfg_guidance, t) d = self.to_d(x, sigma, uncond_denoised) # Euler method x = denoised + d * sigma_down - if sigmas[i + 1] > 0: + if self.k_sigmas[i + 1] > 0: x = x + torch.randn_like(x) * sigma_up if callback_fn is not None: @@ -759,45 +940,44 @@ def sample(self, cfg_guidance, prompt=["", ""], callback_fn=None, **kwargs): 'zt': x.detach(), 'decode': self.decode} callback_kwargs = callback_fn(i, t, callback_kwargs) + denoised = callback_kwargs["z0t"] + x = callback_kwargs["zt"] # for the last step, do not add noise - img = self.decode(denoised) + img = self.decode(x) img = (img / 2 + 0.5).clamp(0, 1) return img.detach().cpu() @register_solver("dpm++_2s_a_cfg++") -class DPMpp2sAncestralCFGppSolver(StableDiffusion): +class DPMpp2sAncestralCFGppSolver(Kdiffusion): @torch.autocast(device_type='cuda', dtype=torch.float16) def sample(self, cfg_guidance, prompt=["", ""], callback_fn=None, **kwargs): t_fn = lambda sigma: sigma.log().neg() sigma_fn = lambda t: t.neg().exp() # Text embedding uc, c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[1]) - # convert to karras sigma scheduler - total_sigmas = (1-self.total_alphas).sqrt() / self.total_alphas.sqrt() - sigmas = get_sigmas_karras(len(self.scheduler.timesteps), total_sigmas.min(), total_sigmas.max(), rho=7.) 
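+        # CFG++ DPM-Solver++(2S): the exp(-h) term of the second-order update is
+        # applied to uncond_denoised_2 rather than the guided estimate.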
# initialize x = self.initialize_latent(method="random_kdiffusion", latent_dim=(1, 4, 64, 64), - sigmas=sigmas).to(torch.float16) + sigmas=self.k_sigmas).to(torch.float16) # Sampling - pbar = tqdm(self.scheduler.timesteps, desc="SD") + pbar = tqdm(self.scheduler.timesteps, desc="DPM++2s_a_cpp") for i, _ in enumerate(pbar): - sigma = sigmas[i] - new_t = self.timestep(sigma).to(self.device) + sigma = self.k_sigmas[i] + t_1 = self.sigma_to_t(sigma).to(self.device) with torch.no_grad(): - denoised, uncond_denoised = self.kdiffusion_x_to_denoised(x, sigma, uc, c, cfg_guidance, new_t) + denoised, uncond_denoised = self.x_to_denoised(x, sigma, uc, c, cfg_guidance, t_1) - sigma_down, sigma_up = self.get_ancestral_step(sigmas[i], sigmas[i + 1]) + sigma_down, sigma_up = get_ancestral_step(self.k_sigmas[i], self.k_sigmas[i + 1]) if sigma_down == 0: # Euler method - d = self.to_d(x, sigmas[i], uncond_denoised) + d = self.to_d(x, self.k_sigmas[i], uncond_denoised) x = denoised + d * sigma_down else: # DPM-Solver++(2S) - t, t_next = t_fn(sigmas[i]), t_fn(sigma_down) + t, t_next = t_fn(self.k_sigmas[i]), t_fn(sigma_down) r = 1 / 2 h = t_next - t s = t + r * h @@ -805,19 +985,19 @@ def sample(self, cfg_guidance, prompt=["", ""], callback_fn=None, **kwargs): with torch.no_grad(): sigma_s = sigma_fn(s) - t_2 = self.timestep(sigma_s).to(self.device) - denoised_2, uncond_denoised_2 = self.kdiffusion_x_to_denoised(x_2, sigma_s, uc, c, cfg_guidance, t_2) + t_2 = self.sigma_to_t(sigma_s).to(self.device) + denoised_2, uncond_denoised_2 = self.x_to_denoised(x_2, sigma_s, uc, c, cfg_guidance, t_2) x = denoised_2 - torch.exp(-h) * uncond_denoised_2 + (sigma_fn(t_next) / sigma_fn(t)) * x # Noise addition - if sigmas[i + 1] > 0: + if self.k_sigmas[i + 1] > 0: x = x + torch.randn_like(x) * sigma_up if callback_fn is not None: callback_kwargs = { 'z0t': denoised.detach(), 'zt': x.detach(), 'decode': self.decode} - callback_kwargs = callback_fn(i, new_t, callback_kwargs) + callback_kwargs = callback_fn(i, t_1, callback_kwargs) denoised = callback_kwargs["z0t"] x = callback_kwargs["zt"] @@ -828,37 +1008,33 @@ def sample(self, cfg_guidance, prompt=["", ""], callback_fn=None, **kwargs): @register_solver("dpm++_2m_cfg++") -class DPMpp2mCFGppSolver(StableDiffusion): +class DPMpp2mCFGppSolver(Kdiffusion): @torch.autocast(device_type='cuda', dtype=torch.float16) def sample(self, cfg_guidance, prompt=["", ""], callback_fn=None, **kwargs): t_fn = lambda sigma: sigma.log().neg() - sigma_fn = lambda t: t.neg().exp() # Text embedding uc, c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[1]) - # convert to karras sigma scheduler - total_sigmas = (1-self.total_alphas).sqrt() / self.total_alphas.sqrt() - sigmas = get_sigmas_karras(len(self.scheduler.timesteps), total_sigmas.min(), total_sigmas.max(), rho=7.) 
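+        # CFG++ DPM-Solver++(2M): extra1 is built from uncond_denoised; the
+        # multistep correction reuses old_denoised as in the plain 2M solver.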
# initialize x = self.initialize_latent(method="random_kdiffusion", latent_dim=(1, 4, 64, 64), - sigmas=sigmas).to(torch.float16) + sigmas=self.k_sigmas).to(torch.float16) old_denoised = None # buffer # Sampling - pbar = tqdm(self.scheduler.timesteps, desc="SD") + pbar = tqdm(self.scheduler.timesteps, desc="DPM++_2m_cpp") for i, _ in enumerate(pbar): - sigma = sigmas[i] - new_t = self.timestep(sigma).to(self.device) + sigma = self.k_sigmas[i] + t_1 = self.sigma_to_t(sigma).to(self.device) with torch.no_grad(): - denoised, uncond_denoised = self.kdiffusion_x_to_denoised(x, sigma, uc, c, cfg_guidance, new_t) + denoised, uncond_denoised = self.x_to_denoised(x, sigma, uc, c, cfg_guidance, t_1) # solve ODE one step - t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i+1]) + t, t_next = t_fn(self.k_sigmas[i]), t_fn(self.k_sigmas[i+1]) h = t_next - t - if old_denoised is None or sigmas[i+1] == 0: - x = denoised + self.to_d(x, sigmas[i], uncond_denoised) * sigmas[i+1] + if old_denoised is None or self.k_sigmas[i+1] == 0: + x = denoised + self.to_d(x, self.k_sigmas[i], uncond_denoised) * self.k_sigmas[i+1] else: - h_last = t - t_fn(sigmas[i-1]) + h_last = t - t_fn(self.k_sigmas[i-1]) r = h_last / h extra1 = -torch.exp(-h) * uncond_denoised - (-h).expm1() * (denoised - old_denoised) / (2*r) extra2 = torch.exp(-h) * x @@ -869,7 +1045,7 @@ def sample(self, cfg_guidance, prompt=["", ""], callback_fn=None, **kwargs): callback_kwargs = { 'z0t': denoised.detach(), 'zt': x.detach(), 'decode': self.decode} - callback_kwargs = callback_fn(i, new_t, callback_kwargs) + callback_kwargs = callback_fn(i, t_1, callback_kwargs) denoised = callback_kwargs["z0t"] x = callback_kwargs["zt"] @@ -879,137 +1055,6 @@ def sample(self, cfg_guidance, prompt=["", ""], callback_fn=None, **kwargs): return img.detach().cpu() -@register_solver("ddim_inversion_cfg++") -class InversionDDIMCFGpp(BaseDDIMCFGpp): - """ - Editing via WordSwap after inversion. - Useful for text-guided image editing. 
- """ - @torch.no_grad() - def inversion(self, - z0: torch.Tensor, - uc: torch.Tensor, - c: torch.Tensor, - cfg_guidance: float=1.0): - - # initialize z_0 - zt = z0.clone().to(self.device) - - # loop - pbar = tqdm(reversed(self.scheduler.timesteps), desc='DDIM Inversion') - for _, t in enumerate(pbar): - at = self.alpha(t) - at_prev = self.alpha(t-self.skip) - - noise_uc, noise_c = self.predict_noise(zt, t, uc, c) - noise_pred = noise_uc + cfg_guidance * (noise_c - noise_uc) - - z0t = (zt - (1-at_prev).sqrt() * noise_uc) / at_prev.sqrt() - zt = at.sqrt() * z0t + (1-at).sqrt() * noise_pred - - return zt - - @torch.autocast(device_type='cuda', dtype=torch.float16) - def sample(self, - src_img, - cfg_guidance=7.5, - prompt=["",""], - callback_fn=None, - **kwargs): - - # Text embedding - uc, c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[1]) - - # Initialize zT - zt = self.initialize_latent(method='ddim', - src_img=src_img, - uc=uc, - c=c, - cfg_guidance=cfg_guidance) - - # Sampling - pbar = tqdm(self.scheduler.timesteps, desc="SD") - for step, t in enumerate(pbar): - at = self.alpha(t) - at_prev = self.alpha(t - self.skip) - - with torch.no_grad(): - noise_uc, noise_c = self.predict_noise(zt, t, uc, c) - noise_pred = noise_uc + cfg_guidance * (noise_c - noise_uc) - - # tweedie - z0t = (zt - (1-at).sqrt() * noise_pred) / at.sqrt() - - # add noise - zt = at_prev.sqrt() * z0t + (1-at_prev).sqrt() * noise_uc - - if callback_fn is not None: - callback_kwargs = {'z0t': z0t.detach(), - 'zt': zt.detach(), - 'decode': self.decode} - callback_kwargs = callback_fn(step, t, callback_kwargs) - z0t = callback_kwargs["z0t"] - zt = callback_kwargs["zt"] - - # for the last step, do not add noise - img = self.decode(z0t) - img = (img / 2 + 0.5).clamp(0, 1) - return img.detach().cpu() - -@register_solver("ddim_edit_cfg++") -class EditWordSwapDDIMCFGpp(InversionDDIMCFGpp): - """ - Editing via WordSwap after inversion. - Useful for text-guided image editing. 
- """ - @torch.autocast(device_type='cuda', dtype=torch.float16) - def sample(self, - src_img, - cfg_guidance=7.5, - prompt=["","",""], - callback_fn=None, - **kwargs): - - # Text embedding - uc, src_c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[1]) - _, tgt_c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[2]) - - # Initialize zT - zt = self.initialize_latent(method='ddim', - src_img=src_img, - uc=uc, - c=src_c, - cfg_guidance=cfg_guidance) - # Sampling - pbar = tqdm(self.scheduler.timesteps, desc="DDIM-edit") - for step, t in enumerate(pbar): - at = self.alpha(t) - at_prev = self.alpha(t - self.skip) - - with torch.no_grad(): - noise_uc, noise_c = self.predict_noise(zt, t, uc, tgt_c) - noise_pred = noise_uc + cfg_guidance * (noise_c - noise_uc) - - # tweedie - z0t = (zt - (1-at).sqrt() * noise_pred) / at.sqrt() - - # add noise - zt = at_prev.sqrt() * z0t + (1-at_prev).sqrt() * noise_uc - - if callback_fn is not None: - callback_kwargs = {'z0t': z0t.detach(), - 'zt': zt.detach(), - 'decode': self.decode} - callback_kwargs = callback_fn(step, t, callback_kwargs) - z0t = callback_kwargs["z0t"] - zt = callback_kwargs["zt"] - - # for the last step, do not add noise - img = self.decode(z0t) - img = (img / 2 + 0.5).clamp(0, 1) - return img.detach().cpu() - - ############################# if __name__ == "__main__": diff --git a/latent_sdxl.py b/latent_sdxl.py index d6f164e..560a24a 100644 --- a/latent_sdxl.py +++ b/latent_sdxl.py @@ -1,13 +1,19 @@ -from typing import Any, Optional, Tuple +""" +This module is for sampling with SDXL. +""" + + import os -from safetensors.torch import load_file +from typing import Any, Callable, Dict, Optional, Tuple import torch -from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler +from diffusers import (AutoencoderKL, DDIMScheduler, EulerDiscreteScheduler, + StableDiffusionXLPipeline, UNet2DConditionModel) from diffusers.models.attention_processor import (AttnProcessor2_0, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor) +from safetensors.torch import load_file from tqdm import tqdm ####### Factory ####### @@ -27,17 +33,47 @@ def get_solver(name: str, **kwargs): return __SOLVER__[name](**kwargs) ######################## +# Helper functions +# taken from comfyui +######################## + +def get_ancestral_step(sigma_from, sigma_to, eta=1.): + """Calculates the noise level (sigma_down) to step down to and the amount + of noise to add (sigma_up) when doing an ancestral sampling step.""" + if not eta: + return sigma_to, 0. + sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5) + sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 + return sigma_down, sigma_up + +def append_zero(x): + return torch.cat([x, x.new_zeros([1])]) + +######################## +# Base classes +######################## class SDXL(): def __init__(self, - solver_config: dict, + solver_config: Dict, model_key:str="stabilityai/stable-diffusion-xl-base-1.0", dtype=torch.float16, device='cuda'): - + """ + The base class of LDM-based solvers for VP sampling, especially for SDXL. + We load pre-trained VAE, text-encoder, and U-Net models from diffusers. + Also, compute pre-defined coefficients. + + args: + solver_config (Dict): solver configurations (e.g. 
NFE)
+            model_key (str): model key for loading pre-trained models
+            dtype (torch.dtype): data type of the pipeline (default: torch.float16)
+            device (torch.device): device
+        """
+        # pre-trained model loading
         self.device = device
-        pipe = StableDiffusionXLPipeline.from_pretrained(model_key, torch_dtype=dtype).to(device)
         self.dtype = dtype
+        pipe = StableDiffusionXLPipeline.from_pretrained(model_key, torch_dtype=self.dtype).to(device)
 
         # avoid overflow in float16
         self.vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype).to(device)
@@ -50,9 +86,11 @@
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.default_sample_size = self.unet.config.sample_size
-
-        # sampling parameters
+
+        # load scheduler
         self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
+
+        # time discretization
         self.total_alphas = self.scheduler.alphas_cumprod.clone()
         N_ts = len(self.scheduler.timesteps)
         self.scheduler.set_timesteps(solver_config.num_sampling, device=device)
@@ -69,7 +107,15 @@
         return at
 
     @torch.no_grad()
-    def _text_embed(self, prompt, tokenizer, text_enc, clip_skip):
+    def _text_embed(self, prompt: str, tokenizer: Callable, text_enc: Callable, clip_skip: bool):
+        """
+        Text embedding helper for SDXL.
+        args:
+            prompt (str): text prompt
+            tokenizer (Callable): loaded tokenizer
+            text_enc (Callable): loaded text encoder
+            clip_skip (bool): whether to skip the final CLIP layer(s) when taking hidden states
+        """
         text_inputs = tokenizer(
             prompt,
             padding='max_length',
@@ -88,10 +134,25 @@
         return prompt_embeds, pool_prompt_embeds
 
     @torch.no_grad()
-    def get_text_embed(self, null_prompt_1, prompt_1, null_prompt_2=None, prompt_2=None, clip_skip=None):
+    def get_text_embed(self,
+                       null_prompt_1: str,
+                       prompt_1: str,
+                       null_prompt_2: Optional[str]=None,
+                       prompt_2: Optional[str]=None,
+                       clip_skip: Optional[bool]=None):
         '''
+        Get text embedding.
+
+        TODO: at this time, we assume batch_size = 1; extend the code to batch_size > 1.
+
+        args:
+            null_prompt_1 (str): null text for text encoder 1
+            prompt_1 (str): guidance text for text encoder 1
+            null_prompt_2 (Optional[str]): null prompt for text encoder 2. None: use null_prompt_1.
+            prompt_2 (Optional[str]): prompt for text encoder 2. None: use prompt_1.
+            clip_skip (Optional[bool]): see _text_embed.
         '''
         # Encode the prompts
         # if prompt_2 is None, set same as prompt_1
@@ -143,11 +204,21 @@
             self.vae.decoder.mid_block.to(dtype)
 
     @torch.no_grad()
-    def encode(self, x):
+    def encode(self, x: torch.Tensor):
+        """
+        Encode image to latent features.
+        args:
+            x (torch.Tensor): image
+        """
         return self.vae.encode(x).latent_dist.sample() * self.vae.config.scaling_factor
 
     # @torch.no_grad()
-    def decode(self, zt):
+    def decode(self, zt: torch.Tensor):
+        """
+        Decode latent features to image.
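+        Note: SDXL's VAE is prone to overflow in float16; see the float32
+        comments below and upcast_vae above.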
+        args:
+            zt (torch.Tensor): latent
+        """
         # make sure the VAE is in float32 mode, as it overflows in float16
         # needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
@@ -159,7 +230,13 @@
         return image
 
-    def predict_noise(self, zt, t, uc, c, added_cond_kwargs):
+    def predict_noise(self,
+                      zt: torch.Tensor,
+                      t: torch.Tensor,
+                      uc: Optional[torch.Tensor],
+                      c: Optional[torch.Tensor],
+                      added_cond_kwargs: Optional[Dict]):
+
         t_in = t.unsqueeze(0)
         if uc is None:
             noise_c = self.unet(zt, t_in, encoder_hidden_states=c,
@@ -179,7 +256,22 @@
         return noise_uc, noise_c
 
-    def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim):
+    def _get_add_time_ids(self,
+                          original_size: Tuple[int, int],
+                          crops_coords_top_left: Tuple[int, int],
+                          target_size: Tuple[int, int],
+                          dtype: torch.dtype,
+                          text_encoder_projection_dim: int):
+        """
+        Create additional kwargs for SDXL.
+        Taken from the diffusers pipeline.
+        args:
+            original_size (Tuple[int, int]): original size of the image
+            crops_coords_top_left (Tuple[int, int]): top-left coordinates of the crop
+            target_size (Tuple[int, int]): target size of the image
+            dtype (torch.dtype): data type of the returned time-id tensor
+            text_encoder_projection_dim (int): projection dimension of the text encoder
+        """
         add_time_ids = list(original_size+crops_coords_top_left+target_size)
         passed_add_embed_dim = (
             self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
@@ -263,8 +355,18 @@
     def initialize_latent(self,
                           method: str='random',
                           src_img: Optional[torch.Tensor]=None,
+                          latent_size: tuple=(1, 4, 128, 128),
                           add_cond_kwargs: Optional[dict]=None,
                           **kwargs):
+        """
+        Initialize latent features.
+        Either sample from a Gaussian distribution or run DDIM inversion.
+        args:
+            method (str): initialization method
+            src_img (torch.Tensor): source image
+            latent_size (tuple): shape of the latent sample
+            add_cond_kwargs (dict): additional conditional arguments
+            **kwargs: additional arguments
+        """
         if method == 'ddim':
             assert src_img is not None, "src_img must be provided for inversion"
             z = self.inversion(self.encode(src_img.to(self.dtype).to(self.device)),
@@ -280,8 +382,11 @@
                                1.0,
                                add_cond_kwargs)
         elif method == 'random':
-            size = kwargs.get('size', (1, 4, 128, 128))
-            z = torch.randn(size).to(self.device)
+            z = torch.randn(latent_size).to(self.device)
+        elif method == "random_kdiffusion":
+            sigmas = kwargs.get('sigmas', [14.6146])
+            z = torch.randn(latent_size).to(self.device)
+            z = z * (sigmas[0] ** 2 + 1) ** 0.5
         else:
             raise NotImplementedError
 
@@ -311,6 +416,15 @@
     def reverse_process(self, *args, **kwargs):
         raise NotImplementedError
 
+    def calculate_denoised(self, x: torch.Tensor, model_pred: torch.Tensor, alpha: torch.FloatTensor):
+        """
+        Compute Tweedie's formula in VP sampling.
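+        Tweedie under VP: x0_hat = (x - sqrt(1 - alpha) * model_pred) / sqrt(alpha).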
+        args:
+            x (torch.Tensor): noisy sample
+            model_pred (torch.Tensor): estimated noise
+            alpha (torch.FloatTensor): alpha
+        """
+        return (x - (1-alpha).sqrt() * model_pred) / alpha.sqrt()
 
 class SDXLLightning(SDXL):
     def __init__(self,
@@ -359,12 +473,101 @@
         self.scheduler.alphas_cumprod = torch.cat([torch.tensor([1.0]), self.scheduler.alphas_cumprod]).to(device)
 
+class Kdiffusion(SDXL):
+    def __init__(self,
+                 solver_config: Dict,
+                 model_key:str="stabilityai/stable-diffusion-xl-base-1.0",
+                 dtype=torch.float16,
+                 device: Optional[torch.device]=None,
+                 **kwargs):
+        """
+        Base SDXL class of LDM-based solvers built on Karras et al. (2022).
+        Contains methods for leveraging a VP diffusion model for VE sampling,
+        for solvers like DPM and DPM++.
+
+        args:
+            solver_config (Dict): solver configurations (e.g. NFE)
+            model_key (str): model key for loading pre-trained models
+            device (torch.device): device
+            **kwargs: additional arguments
+        """
+        super().__init__(solver_config, model_key, dtype, device, **kwargs)
+
+        # load scheduler once again, not saved to self
+        scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
+        # convert alphas to sigmas (VP -> VE)
+        total_sigmas = (1-scheduler.alphas_cumprod).sqrt() / scheduler.alphas_cumprod.sqrt()
+        self.log_sigmas = total_sigmas.log()
+        self.sigma_min, self.sigma_max = total_sigmas.min(), total_sigmas.max()
+
+        # get karras sigmas
+        self.k_sigmas = self.get_sigmas_karras(len(self.scheduler.timesteps), self.sigma_min, self.sigma_max)
+
+    def get_sigmas_karras(self, n: int, sigma_min: float, sigma_max: float, rho: float=7., device: str='cpu'):
+        """Constructs the noise schedule of Karras et al. (2022)."""
+        ramp = torch.linspace(0, 1, n, device=device)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return append_zero(sigmas).to(device)
+
+    def sigma_to_t(self, sigma: torch.FloatTensor):
+        """Convert sigma to timestep. (find the closest index)"""
+        log_sigma = sigma.log()
+        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
+        return dists.abs().argmin(dim=0).view(sigma.shape)
+
+    def to_d(self, x: torch.Tensor, sigma: torch.FloatTensor, denoised: torch.Tensor):
+        '''
+        converts a denoiser output to a Karras ODE derivative
+        args:
+            x (torch.Tensor): noisy sample
+            sigma (torch.FloatTensor): noise level
+            denoised (torch.Tensor): denoised sample
+        '''
+        return (x - denoised) / sigma.item()
+
+    def calculate_input(self, x: torch.Tensor, sigma: float):
+        return x / (sigma ** 2 + 1) ** 0.5
+
+    def calculate_denoised(self, x: torch.Tensor, model_pred: torch.Tensor, sigma: torch.FloatTensor):
+        """
+        Compute Tweedie's formula in VE sampling.
+        args:
+            x (torch.Tensor): noisy sample
+            model_pred (torch.Tensor): estimated noise
+            sigma (torch.FloatTensor): noise level
+        """
+        return x - model_pred * sigma
+
+    def x_to_denoised(self, x, sigma, uc, c, cfg_guidance, t, added_cond_kwargs):
+        """
+        Compute denoised estimates (guided and unconditional) from the noisy sample.
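+        The input is first rescaled by 1 / sqrt(sigma^2 + 1) (calculate_input) so
+        the VP-trained U-Net can be evaluated at a VE noise level sigma.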
+        args:
+            x (torch.Tensor): noisy sample
+            sigma (float): noise level
+            uc (torch.Tensor): null-text embedding
+            c (torch.Tensor): text embedding
+            cfg_guidance (float): guidance scale
+            t (torch.Tensor): timestep
+            added_cond_kwargs (Dict): additional conditioning kwargs for SDXL
+        """
+        xc = self.calculate_input(x, sigma)
+        noise_uc, noise_c = self.predict_noise(xc, t, uc, c, added_cond_kwargs)
+        noise_pred = noise_uc + cfg_guidance * (noise_c - noise_uc)
+        denoised = self.calculate_denoised(x, noise_pred, sigma)
+        uncond_denoised = self.calculate_denoised(x, noise_uc, sigma)
+        return denoised, uncond_denoised
+
 ###########################################
-# Base version
+# VP version samplers
 ###########################################
 
 @register_solver('ddim')
-class BaseDDIM(SDXL):
+class DDIM(SDXL):
+    """
+    Basic DDIM solver for SDXL.
+    VP sampling.
+    """
     def reverse_process(self,
                         null_prompt_embeds,
                         prompt_embeds,
@@ -407,30 +610,9 @@
         # for the last stpe, do not add noise
         return z0t
 
-@register_solver('ddim_lightning')
-class BaseDDIMLight(BaseDDIM, SDXLLightning):
-    def __init__(self, **kwargs):
-        SDXLLightning.__init__(self, **kwargs)
-
-    def reverse_process(self,
-                        null_prompt_embeds,
-                        prompt_embeds,
-                        cfg_guidance,
-                        add_cond_kwargs,
-                        shape=(1024, 1024),
-                        callback_fn=None,
-                        **kwargs):
-        assert cfg_guidance == 1.0, "CFG should be turned off in the lightning version"
-        return super().reverse_process(null_prompt_embeds,
-                                       prompt_embeds,
-                                       cfg_guidance,
-                                       add_cond_kwargs,
-                                       shape,
-                                       callback_fn,
-                                       **kwargs)
 
 @register_solver("ddim_edit")
-class EditWardSwapDDIM(BaseDDIM):
+class EditWardSwapDDIM(DDIM):
     @torch.autocast(device_type='cuda', dtype=torch.float16)
     def sample(self,
               prompt1 = ["", "", ""],
@@ -568,13 +750,41 @@ def reverse_process(self,
         # for the last stpe, do not add noise
         return z0t
 
+###############################################
+# SDXL Lightning
+###############################################
+@register_solver('ddim_lightning')
+class BaseDDIMLight(DDIM, SDXLLightning):
+    def __init__(self, **kwargs):
+        SDXLLightning.__init__(self, **kwargs)
+
+    def reverse_process(self,
+                        null_prompt_embeds,
+                        prompt_embeds,
+                        cfg_guidance,
+                        add_cond_kwargs,
+                        shape=(1024, 1024),
+                        callback_fn=None,
+                        **kwargs):
+        assert cfg_guidance == 1.0, "CFG should be turned off in the lightning version"
+        return super().reverse_process(null_prompt_embeds,
+                                       prompt_embeds,
+                                       cfg_guidance,
+                                       add_cond_kwargs,
+                                       shape,
+                                       callback_fn,
+                                       **kwargs)
 
 ###########################################
-# CFG++ version
+# VE version samplers (K-diffusion)
 ###########################################
 
-@register_solver("ddim_cfg++")
-class BaseDDIMCFGpp(SDXL):
+@register_solver("euler")
+class EulerCFGSolver(Kdiffusion):
+    """
+    Karras Euler (VE casted)
+    """
+    @torch.autocast(device_type='cuda', dtype=torch.float16)
     def reverse_process(self,
                         null_prompt_embeds,
                         prompt_embeds,
@@ -583,45 +793,46 @@ def reverse_process(self,
                         shape=(1024, 1024),
                         callback_fn=None,
                         **kwargs):
-        #################################
-        # Sample region - where to change
-        #################################
-        # initialize zT
-        zt = self.initialize_latent(size=(1, 4, shape[1] // self.vae_scale_factor, shape[0] // self.vae_scale_factor))
+        x = self.initialize_latent(method="random_kdiffusion",
+                                   latent_size=(1, 4, shape[1] // self.vae_scale_factor, shape[0] // self.vae_scale_factor),
+                                   sigmas=self.k_sigmas)
 
         # sampling
-        pbar = tqdm(self.scheduler.timesteps.int(), desc='SDXL')
-        for step, t in enumerate(pbar):
-            next_t = t - self.skip
-            at = self.scheduler.alphas_cumprod[t]
-            at_next = self.scheduler.alphas_cumprod[next_t]
+        pbar = tqdm(self.scheduler.timesteps.int(), desc='SDXL-Euler')
+        for step, _ in enumerate(pbar):
+            sigma = self.k_sigmas[step]
+            t = self.sigma_to_t(sigma).to(self.device)
 
             with torch.no_grad():
-                noise_uc, noise_c = self.predict_noise(zt, t, null_prompt_embeds, prompt_embeds, add_cond_kwargs)
-                noise_pred = noise_uc + cfg_guidance * (noise_c - noise_uc)
-
-            # tweedie
-            z0t = (zt - (1-at).sqrt() * noise_pred) / at.sqrt()
-
-            # add noise
-            zt = at_next.sqrt() * z0t + (1-at_next).sqrt() * noise_uc
+                denoised, _ = self.x_to_denoised(x,
+                                                 sigma,
+                                                 null_prompt_embeds,
+                                                 prompt_embeds,
+                                                 cfg_guidance,
+                                                 t,
+                                                 add_cond_kwargs)
+
+            d = self.to_d(x, sigma, denoised)
+            # Euler method
+            x = denoised + d * self.k_sigmas[step+1]
 
             if callback_fn is not None:
-                callback_kwargs = { 'z0t': z0t.detach(),
-                                    'zt': zt.detach(),
+                callback_kwargs = { 'z0t': denoised.detach(),
+                                    'zt': x.detach(),
                                     'decode': self.decode}
                 callback_kwargs = callback_fn(step, t, callback_kwargs)
-                z0t = callback_kwargs["z0t"]
-                zt = callback_kwargs["zt"]
+                denoised = callback_kwargs["z0t"]
+                x = callback_kwargs["zt"]
 
         # for the last stpe, do not add noise
-        return z0t
+        return x
 
-@register_solver('ddim_cfg++_lightning')
-class BaseDDIMCFGppLight(BaseDDIMCFGpp, SDXLLightning):
-    def __init__(self, **kwargs):
-        SDXLLightning.__init__(self, **kwargs)
-
+
+@register_solver("euler_a")
+class EulerAncestralCFGSolver(Kdiffusion):
+    """
+    Karras Euler (VE casted) + Ancestral sampling
+    """
+    @torch.autocast(device_type='cuda', dtype=torch.float16)
     def reverse_process(self,
                         null_prompt_embeds,
                         prompt_embeds,
@@ -630,39 +841,45 @@ def reverse_process(self,
                         shape=(1024, 1024),
                         callback_fn=None,
                         **kwargs):
-        assert cfg_guidance == 1.0, "CFG should be turned off in the lightning version"
-        return super().reverse_process(null_prompt_embeds,
-                                       prompt_embeds,
-                                       cfg_guidance,
-                                       add_cond_kwargs,
-                                       shape,
-                                       callback_fn,
-                                       **kwargs)
+        x = self.initialize_latent(method="random_kdiffusion",
+                                   latent_size=(1, 4, shape[1] // self.vae_scale_factor, shape[0] // self.vae_scale_factor),
+                                   sigmas=self.k_sigmas)
+
+        # sampling
+        pbar = tqdm(self.scheduler.timesteps.int(), desc='SDXL-Euler_a')
+        for step, _ in enumerate(pbar):
+            sigma = self.k_sigmas[step]
+            t = self.sigma_to_t(sigma).to(self.device)
+            sigma_down, sigma_up = get_ancestral_step(self.k_sigmas[step], self.k_sigmas[step+1])
+            with torch.no_grad():
+                denoised, _ = self.x_to_denoised(x,
+                                                 sigma,
+                                                 null_prompt_embeds,
+                                                 prompt_embeds,
+                                                 cfg_guidance,
+                                                 t,
+                                                 add_cond_kwargs)
+            # Euler method
+            d = self.to_d(x, sigma, denoised)
+            x = denoised + d * sigma_down
+
+            if self.k_sigmas[step + 1] > 0:
+                x = x + torch.randn_like(x) * sigma_up
+
+            if callback_fn is not None:
+                callback_kwargs = { 'z0t': denoised.detach(),
+                                    'zt': x.detach(),
+                                    'decode': self.decode}
+                callback_kwargs = callback_fn(step, t, callback_kwargs)
+                denoised = callback_kwargs["z0t"]
+                x = callback_kwargs["zt"]
+
+        # for the last step, do not add noise
+        return x
 
-@register_solver('dpm++_2m_cfgpp')
-class DPMpp2mCFGppSolver(SDXL):
-    quantize = True
-
-    def sigma_to_t(self, sigma, quantize=None):
-        '''Taken from k_diffusion/external.py'''
-        quantize = self.quantize if quantize is None else quantize
-        total_sigmas = (1-self.total_alphas).sqrt() / self.total_alphas.sqrt()
-        dists = sigma - total_sigmas[:, None]
-        if quantize:
-            return dists.abs().argmin(dim=0).view(sigma.shape)
-        low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=total_sigmas.shape[0] - 2)
-        high_idx = low_idx + 1
-        low, high = total_sigmas[low_idx], total_sigmas[high_idx]
-@register_solver('dpm++_2m_cfgpp')
-class DPMpp2mCFGppSolver(SDXL):
-    quantize = True
-
-    def sigma_to_t(self, sigma, quantize=None):
-        '''Taken from k_diffusion/external.py'''
-        quantize = self.quantize if quantize is None else quantize
-        total_sigmas = (1-self.total_alphas).sqrt() / self.total_alphas.sqrt()
-        dists = sigma - total_sigmas[:, None]
-        if quantize:
-            return dists.abs().argmin(dim=0).view(sigma.shape)
-        low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=total_sigmas.shape[0] - 2)
-        high_idx = low_idx + 1
-        low, high = total_sigmas[low_idx], total_sigmas[high_idx]
-        w = (low - sigma) / (low - high)
-        w = w.clamp(0, 1)
-        t = (1 - w) * low_idx + w * high_idx
-        return t.view(sigma.shape)
-
-    def to_d(self, x, sigma, denoised):
-        '''converts a denoiser output to a Karras ODE derivative'''
-        return (x - denoised) / sigma.item()
-
-    @torch.autocast("cuda")
+@register_solver("dpm++_2s_a")
+class DPMpp2AncestralCFGSolver(Kdiffusion):
+    @torch.autocast(device_type='cuda', dtype=torch.float16)
     def reverse_process(self,
                         null_prompt_embeds,
                         prompt_embeds,
@@ -671,70 +888,131 @@ def reverse_process(self,
                         shape=(1024, 1024),
                         callback_fn=None,
                         **kwargs):
-        #################################
-        # Sample region - where to change
-        #################################
-
-        # prepare alphas and sigmas
-        alphas = self.scheduler.alphas_cumprod[self.scheduler.timesteps.int().cpu()].cpu()
-        sigmas = (1-alphas).sqrt() / alphas.sqrt()
-
-        # initialize
-        x = self.initialize_latent(method='random',
-                                   size=(1, 4, shape[1] // self.vae_scale_factor, shape[0] // self.vae_scale_factor)).to(torch.float16)
-        x = x * sigmas[0]
-        t_fn = lambda sigma: sigma.log().neg()
-        old_denoised = None # initial value
+        t_fn = lambda sigma: sigma.log().neg()
+        sigma_fn = lambda t: t.neg().exp()
+        x = self.initialize_latent(method="random_kdiffusion",
+                                   size=(1, 4, shape[1] // self.vae_scale_factor, shape[0] // self.vae_scale_factor))
+
         # sampling
-        pbar = tqdm(self.scheduler.timesteps[:-1].int(), desc='SDXL')
-        for i, _ in enumerate(pbar):
-            at = alphas[i]
-            sigma = sigmas[i]
+        pbar = tqdm(self.scheduler.timesteps.int(), desc='SDXL-DPM++2s_a')
+        for step, _ in enumerate(pbar):
+            sigma = self.k_sigmas[step]
+            t_1 = self.sigma_to_t(sigma).to(self.device)
+            with torch.no_grad():
+                denoised, _ = self.x_to_denoised(x,
+                                                 sigma,
+                                                 null_prompt_embeds,
+                                                 prompt_embeds,
+                                                 cfg_guidance,
+                                                 t_1,
+                                                 add_cond_kwargs)
+
+            sigma_down, sigma_up = get_ancestral_step(self.k_sigmas[step], self.k_sigmas[step + 1])
+            if sigma_down == 0:
+                # Euler method
+                d = self.to_d(x, self.k_sigmas[step], denoised)
+                x = denoised + d * sigma_down
+            else:
+                # DPM-Solver++(2S)
+                t, t_next = t_fn(self.k_sigmas[step]), t_fn(sigma_down)
+                r = 1 / 2
+                h = t_next - t
+                s = t + r * h
+                x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
+
+                with torch.no_grad():
+                    sigma_s = sigma_fn(s)
+                    t_2 = self.sigma_to_t(sigma_s).to(self.device)
+                    denoised_2, _ = self.x_to_denoised(x_2,
+                                                       sigma_s,
+                                                       null_prompt_embeds,
+                                                       prompt_embeds,
+                                                       cfg_guidance,
+                                                       t_2,
+                                                       add_cond_kwargs)
+
+                x = denoised_2 - torch.exp(-h) * denoised_2 + (sigma_fn(t_next) / sigma_fn(t)) * x
+
+            # Noise addition
+            if self.k_sigmas[step + 1] > 0:
+                x = x + torch.randn_like(x) * sigma_up
+
+            if callback_fn is not None:
+                callback_kwargs = { 'z0t': denoised.detach(),
+                                    'zt': x.detach(),
+                                    'decode': self.decode}
+                callback_kwargs = callback_fn(step, t_1, callback_kwargs)
+                denoised = callback_kwargs["z0t"]
+                x = callback_kwargs["zt"]
 
-            c_in = at.clone().sqrt()
-            c_out = -sigma.clone()
+        # for the last step, do not add noise
+        return x
 
-            new_t = self.sigma_to_t(sigma).to(self.device)
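+# Note: DPM-Solver++(2S) integrates in t = -log(sigma); t_fn and sigma_fn
+# above are inverses of each other. With h = t_next - t and midpoint
+# s = t + h/2, the two updates above are the exponential-integrator form
+#   x_2 = (sigma(s)/sigma(t)) * x - (exp(-h/2) - 1) * denoised
+#   x   = (sigma(t_next)/sigma(t)) * x + (1 - exp(-h)) * denoised_2,
+# matching sample_dpmpp_2s_ancestral in k-diffusion.
+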
+@register_solver("dpm++_2m")
+class DPMpp2mCFGSolver(Kdiffusion):
+    @torch.autocast(device_type='cuda', dtype=torch.float16)
+    def reverse_process(self,
+                        null_prompt_embeds,
+                        prompt_embeds,
+                        cfg_guidance,
+                        add_cond_kwargs,
+                        shape=(1024, 1024),
+                        callback_fn=None,
+                        **kwargs):
+        t_fn = lambda sigma: sigma.log().neg()
+        sigma_fn = lambda t: t.neg().exp()
+        x = self.initialize_latent(method="random_kdiffusion",
+                                   size=(1, 4, shape[1] // self.vae_scale_factor, shape[0] // self.vae_scale_factor))
+        old_denoised = None  # buffer for the previous denoised estimate (multistep)
 
+        # sampling
+        pbar = tqdm(self.scheduler.timesteps.int(), desc='SDXL-DPM++2m')
+        for step, _ in enumerate(pbar):
+            sigma = self.k_sigmas[step]
+            t_1 = self.sigma_to_t(sigma).to(self.device)
             with torch.no_grad():
-                noise_uc, noise_c = self.predict_noise(x * c_in, new_t, null_prompt_embeds, prompt_embeds, add_cond_kwargs)
-                noise_pred = noise_uc + cfg_guidance * (noise_c - noise_uc)
-
-            # tweedie, VE version
-            denoised = x + c_out * noise_pred
-            uncond_denoised = x + c_out * noise_uc
-
+                denoised, _ = self.x_to_denoised(x,
+                                                 sigma,
+                                                 null_prompt_embeds,
+                                                 prompt_embeds,
+                                                 cfg_guidance,
+                                                 t_1,
+                                                 add_cond_kwargs)
            # solve ODE one step
-            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i+1])
+            t, t_next = t_fn(self.k_sigmas[step]), t_fn(self.k_sigmas[step+1])
             h = t_next - t
-            if old_denoised is None or sigmas[i+1] == 0:
-                x = denoised + self.to_d(x, sigmas[i], uncond_denoised) * sigmas[i+1]
+            if old_denoised is None or self.k_sigmas[step+1] == 0:
+                x = denoised + self.to_d(x, self.k_sigmas[step], denoised) * self.k_sigmas[step+1]
             else:
-                h_last = t - t_fn(sigmas[i-1])
+                h_last = t - t_fn(self.k_sigmas[step-1])
                 r = h_last / h
-                extra1 = -torch.exp(-h) * uncond_denoised - (-h).expm1() * (uncond_denoised - old_denoised) / (2*r)
+                extra1 = -torch.exp(-h) * denoised - (-h).expm1() * (denoised - old_denoised) / (2*r)
                 extra2 = torch.exp(-h) * x
                 x = denoised + extra1 + extra2
-            old_denoised = uncond_denoised
+            old_denoised = denoised
 
             if callback_fn is not None:
                 callback_kwargs = { 'z0t': denoised.detach(),
                                     'zt': x.detach(),
                                     'decode': self.decode}
-                callback_kwargs = callback_fn(i, new_t, callback_kwargs)
+                callback_kwargs = callback_fn(step, t_1, callback_kwargs)
                 denoised = callback_kwargs["z0t"]
                 x = callback_kwargs["zt"]
 
         # for the last stpe, do not add noise
         return x
 
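+# Note: DPM-Solver++(2M) is a multistep method: it reuses old_denoised from
+# the previous step instead of making a second model call. With r = h_last / h,
+# the update above is equivalent to
+#   x = exp(-h) * x + (1 - exp(-h)) * (denoised + (denoised - old_denoised) / (2*r)),
+# falling back to the Euler step on the first iteration and at sigma = 0.
+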
-@register_solver('dpm++_2m_cfgpp_lightning')
-class DPMpp2mCFGppLightningSolver(DPMpp2mCFGppSolver, SDXLLightning):
-    def __init__(self, **kwargs):
-        SDXLLightning.__init__(self, **kwargs)
-
+###########################################
+# VP version samplers with CFG++
+###########################################
+
+@register_solver("ddim_cfg++")
+class DDIMCFGpp(SDXL):
+    """
+    DDIM solver for SDXL with CFG++.
+    """
     def reverse_process(self,
                         null_prompt_embeds,
                         prompt_embeds,
@@ -743,14 +1021,39 @@ def reverse_process(self,
                         shape=(1024, 1024),
                         callback_fn=None,
                         **kwargs):
-        assert cfg_guidance == 1.0, "CFG should be turned off in the lightning version"
-        return super().reverse_process(null_prompt_embeds,
-                                       prompt_embeds,
-                                       cfg_guidance,
-                                       add_cond_kwargs,
-                                       shape,
-                                       callback_fn,
-                                       **kwargs)
+        #################################
+        # Sample region - where to change
+        #################################
+        # initialize zT
+        zt = self.initialize_latent(size=(1, 4, shape[1] // self.vae_scale_factor, shape[0] // self.vae_scale_factor))
+
+        # sampling
+        pbar = tqdm(self.scheduler.timesteps.int(), desc='SDXL')
+        for step, t in enumerate(pbar):
+            next_t = t - self.skip
+            at = self.scheduler.alphas_cumprod[t]
+            at_next = self.scheduler.alphas_cumprod[next_t]
+
+            with torch.no_grad():
+                noise_uc, noise_c = self.predict_noise(zt, t, null_prompt_embeds, prompt_embeds, add_cond_kwargs)
+                noise_pred = noise_uc + cfg_guidance * (noise_c - noise_uc)
+
+            # tweedie
+            z0t = (zt - (1-at).sqrt() * noise_pred) / at.sqrt()
+
+            # add noise (CFG++: re-noise with the unconditional prediction)
+            zt = at_next.sqrt() * z0t + (1-at_next).sqrt() * noise_uc
+
+            if callback_fn is not None:
+                callback_kwargs = { 'z0t': z0t.detach(),
+                                    'zt': zt.detach(),
+                                    'decode': self.decode}
+                callback_kwargs = callback_fn(step, t, callback_kwargs)
+                z0t = callback_kwargs["z0t"]
+                zt = callback_kwargs["zt"]
+
+        # for the last step, do not add noise
+        return z0t
 
 @register_solver("ddim_edit_cfg++")
 class EditWardSwapDDIMCFGpp(EditWardSwapDDIM):
@@ -824,6 +1127,280 @@ def reverse_process(self,
         # for the last stpe, do not add noise
         return z0t
+
+
+###############################################
+# SDXL Lightning with CFG++
+###############################################
+
+@register_solver('ddim_cfg++_lightning')
+class BaseDDIMCFGppLight(DDIMCFGpp, SDXLLightning):
+    def __init__(self, **kwargs):
+        SDXLLightning.__init__(self, **kwargs)
+
+    def reverse_process(self,
+                        null_prompt_embeds,
+                        prompt_embeds,
+                        cfg_guidance,
+                        add_cond_kwargs,
+                        shape=(1024, 1024),
+                        callback_fn=None,
+                        **kwargs):
+        assert cfg_guidance == 1.0, "CFG should be turned off in the lightning version"
+        return super().reverse_process(null_prompt_embeds,
+                                       prompt_embeds,
+                                       cfg_guidance,
+                                       add_cond_kwargs,
+                                       shape,
+                                       callback_fn,
+                                       **kwargs)
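+
+# Note: CFG++ keeps the guided denoised estimate as the anchor of each step
+# but re-noises / extrapolates with the *unconditional* prediction (noise_uc
+# in the VP DDIM solver above, uncond_denoised in the VE solvers below),
+# where standard CFG would reuse the guided prediction for both.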
+
+
+###############################################
+# VE version samplers (K-diffusion) with CFG++
+###############################################
+
+@register_solver("euler_cfg++")
+class EulerCFGppSolver(Kdiffusion):
+    @torch.autocast(device_type='cuda', dtype=torch.float16)
+    def reverse_process(self,
+                        null_prompt_embeds,
+                        prompt_embeds,
+                        cfg_guidance,
+                        add_cond_kwargs,
+                        shape=(1024, 1024),
+                        callback_fn=None,
+                        **kwargs):
+        x = self.initialize_latent(method="random_kdiffusion",
+                                   size=(1, 4, shape[1] // self.vae_scale_factor, shape[0] // self.vae_scale_factor))
+
+        # sampling
+        pbar = tqdm(self.scheduler.timesteps.int(), desc='SDXL-Euler-CFG++')
+        for step, _ in enumerate(pbar):
+            sigma = self.k_sigmas[step]
+            t = self.sigma_to_t(sigma).to(self.device)
+
+            with torch.no_grad():
+                denoised, uncond_denoised = self.x_to_denoised(x,
+                                                               sigma,
+                                                               null_prompt_embeds,
+                                                               prompt_embeds,
+                                                               cfg_guidance,
+                                                               t,
+                                                               add_cond_kwargs)
+
+            d = self.to_d(x, sigma, uncond_denoised)
+            # Euler method
+            x = denoised + d * self.k_sigmas[step+1]
+
+            if callback_fn is not None:
+                callback_kwargs = { 'z0t': denoised.detach(),
+                                    'zt': x.detach(),
+                                    'decode': self.decode}
+                callback_kwargs = callback_fn(step, t, callback_kwargs)
+                denoised = callback_kwargs["z0t"]
+                x = callback_kwargs["zt"]
+
+        # for the last step, do not add noise
+        return x
+
+
+@register_solver("euler_a_cfg++")
+class EulerAncestralCFGppSolver(Kdiffusion):
+    @torch.autocast(device_type='cuda', dtype=torch.float16)
+    def reverse_process(self,
+                        null_prompt_embeds,
+                        prompt_embeds,
+                        cfg_guidance,
+                        add_cond_kwargs,
+                        shape=(1024, 1024),
+                        callback_fn=None,
+                        **kwargs):
+        x = self.initialize_latent(method="random_kdiffusion",
+                                   size=(1, 4, shape[1] // self.vae_scale_factor, shape[0] // self.vae_scale_factor))
+
+        # sampling
+        pbar = tqdm(self.scheduler.timesteps.int(), desc='SDXL-Euler_a-CFG++')
+        for step, _ in enumerate(pbar):
+            sigma = self.k_sigmas[step]
+            t = self.sigma_to_t(sigma).to(self.device)
+            sigma_down, sigma_up = get_ancestral_step(self.k_sigmas[step], self.k_sigmas[step+1])
+            with torch.no_grad():
+                denoised, uncond_denoised = self.x_to_denoised(x,
+                                                               sigma,
+                                                               null_prompt_embeds,
+                                                               prompt_embeds,
+                                                               cfg_guidance,
+                                                               t,
+                                                               add_cond_kwargs)
+            # Euler method
+            d = self.to_d(x, sigma, uncond_denoised)
+            x = denoised + d * sigma_down
+
+            if self.k_sigmas[step + 1] > 0:
+                x = x + torch.randn_like(x) * sigma_up
+
+            if callback_fn is not None:
+                callback_kwargs = { 'z0t': denoised.detach(),
+                                    'zt': x.detach(),
+                                    'decode': self.decode}
+                callback_kwargs = callback_fn(step, t, callback_kwargs)
+                denoised = callback_kwargs["z0t"]
+                x = callback_kwargs["zt"]
+
+        # for the last step, do not add noise
+        return x
+
+@register_solver("dpm++_2s_a_cfg++")
+class DPMpp2sAncestralCFGppSolver(Kdiffusion):
+    @torch.autocast(device_type='cuda', dtype=torch.float16)
+    def reverse_process(self,
+                        null_prompt_embeds,
+                        prompt_embeds,
+                        cfg_guidance,
+                        add_cond_kwargs,
+                        shape=(1024, 1024),
+                        callback_fn=None,
+                        **kwargs):
+        t_fn = lambda sigma: sigma.log().neg()
+        sigma_fn = lambda t: t.neg().exp()
+
+        x = self.initialize_latent(method="random_kdiffusion",
+                                   size=(1, 4, shape[1] // self.vae_scale_factor, shape[0] // self.vae_scale_factor))
+
+        # sampling
+        pbar = tqdm(self.scheduler.timesteps.int(), desc='SDXL-DPM++2s_a-CFG++')
+        for step, _ in enumerate(pbar):
+            sigma = self.k_sigmas[step]
+            t_1 = self.sigma_to_t(sigma).to(self.device)
+            with torch.no_grad():
+                denoised, uncond_denoised = self.x_to_denoised(x,
+                                                               sigma,
+                                                               null_prompt_embeds,
+                                                               prompt_embeds,
+                                                               cfg_guidance,
+                                                               t_1,
+                                                               add_cond_kwargs)
+
+            sigma_down, sigma_up = get_ancestral_step(self.k_sigmas[step], self.k_sigmas[step + 1])
+            if sigma_down == 0:
+                # Euler method
+                d = self.to_d(x, self.k_sigmas[step], uncond_denoised)
+                x = denoised + d * sigma_down
+            else:
+                # DPM-Solver++(2S)
+                t, t_next = t_fn(self.k_sigmas[step]), t_fn(sigma_down)
+                r = 1 / 2
+                h = t_next - t
+                s = t + r * h
+                x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * uncond_denoised
+
+                with torch.no_grad():
+                    sigma_s = sigma_fn(s)
+                    t_2 = self.sigma_to_t(sigma_s).to(self.device)
+                    denoised_2, uncond_denoised_2 = self.x_to_denoised(x_2,
+                                                                       sigma_s,
+                                                                       null_prompt_embeds,
+                                                                       prompt_embeds,
+                                                                       cfg_guidance,
+                                                                       t_2,
+                                                                       add_cond_kwargs)
+
+                x = denoised_2 - torch.exp(-h) * uncond_denoised_2 + (sigma_fn(t_next) / sigma_fn(t)) * x
+
+            # Noise addition
+            if self.k_sigmas[step + 1] > 0:
+                x = x + torch.randn_like(x) * sigma_up
+
+            if callback_fn is not None:
+                callback_kwargs = { 'z0t': denoised.detach(),
+                                    'zt': x.detach(),
+                                    'decode': self.decode}
+                callback_kwargs = callback_fn(step, t_1, callback_kwargs)
+                denoised = callback_kwargs["z0t"]
+                x = callback_kwargs["zt"]
+
+        # for the last step, do not add noise
+        return x
+
+
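+# Note: the CFG++ K-diffusion variants feed uncond_denoised to to_d() and the
+# second-order terms while anchoring each step at the guided denoised -- the
+# VE analogue of re-noising with noise_uc in DDIMCFGpp above.
+# Usage sketch (argument names follow examples/text_to_img.py; the exact
+# sample() signature may differ per solver):
+#   solver = get_solver("dpm++_2m_cfg++",
+#                       solver_config=munchify({'num_sampling': args.NFE}),
+#                       device=device)
+#   result = solver.sample(..., cfg_guidance=cfg_scale, callback_fn=callback)
+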
+@register_solver("dpm++_2m_cfg++")
+class DPMpp2mCFGppSolver(Kdiffusion):
+    @torch.autocast(device_type='cuda', dtype=torch.float16)
+    def reverse_process(self,
+                        null_prompt_embeds,
+                        prompt_embeds,
+                        cfg_guidance,
+                        add_cond_kwargs,
+                        shape=(1024, 1024),
+                        callback_fn=None,
+                        **kwargs):
+        t_fn = lambda sigma: sigma.log().neg()
+        sigma_fn = lambda t: t.neg().exp()
+
+        x = self.initialize_latent(method="random_kdiffusion",
+                                   size=(1, 4, shape[1] // self.vae_scale_factor, shape[0] // self.vae_scale_factor))
+        old_denoised = None  # buffer for the previous unconditional denoised estimate
+
+        # sampling
+        pbar = tqdm(self.scheduler.timesteps.int(), desc='SDXL-DPM++2m-CFG++')
+        for step, _ in enumerate(pbar):
+            sigma = self.k_sigmas[step]
+            t_1 = self.sigma_to_t(sigma).to(self.device)
+            with torch.no_grad():
+                denoised, uncond_denoised = self.x_to_denoised(x,
+                                                               sigma,
+                                                               null_prompt_embeds,
+                                                               prompt_embeds,
+                                                               cfg_guidance,
+                                                               t_1,
+                                                               add_cond_kwargs)
+            # solve ODE one step
+            t, t_next = t_fn(self.k_sigmas[step]), t_fn(self.k_sigmas[step+1])
+            h = t_next - t
+            if old_denoised is None or self.k_sigmas[step+1] == 0:
+                x = denoised + self.to_d(x, self.k_sigmas[step], uncond_denoised) * self.k_sigmas[step+1]
+            else:
+                h_last = t - t_fn(self.k_sigmas[step-1])
+                r = h_last / h
+                extra1 = -torch.exp(-h) * uncond_denoised - (-h).expm1() * (denoised - old_denoised) / (2*r)
+                extra2 = torch.exp(-h) * x
+                x = denoised + extra1 + extra2
+            old_denoised = uncond_denoised
+
+            if callback_fn is not None:
+                callback_kwargs = { 'z0t': denoised.detach(),
+                                    'zt': x.detach(),
+                                    'decode': self.decode}
+                callback_kwargs = callback_fn(step, t_1, callback_kwargs)
+                denoised = callback_kwargs["z0t"]
+                x = callback_kwargs["zt"]
+
+        # for the last step, do not add noise
+        return x
+
+@register_solver('dpm++_2m_cfg++_lightning')
+class DPMpp2mCFGppLightningSolver(DPMpp2mCFGppSolver, SDXLLightning):
+    def __init__(self, **kwargs):
+        SDXLLightning.__init__(self, **kwargs)
+
+    def reverse_process(self,
+                        null_prompt_embeds,
+                        prompt_embeds,
+                        cfg_guidance,
+                        add_cond_kwargs,
+                        shape=(1024, 1024),
+                        callback_fn=None,
+                        **kwargs):
+        assert cfg_guidance == 1.0, "CFG should be turned off in the lightning version"
+        return super().reverse_process(null_prompt_embeds,
+                                       prompt_embeds,
+                                       cfg_guidance,
+                                       add_cond_kwargs,
+                                       shape,
+                                       callback_fn,
+                                       **kwargs)
+
 #############################
 if __name__ == "__main__":