1717
1818import math
1919from dataclasses import dataclass
20- from typing import List , Optional , Tuple , Union
20+ from typing import List , Literal , Optional , Tuple , Union
2121
2222import numpy as np
2323import torch
@@ -51,7 +51,7 @@ class DDIMSchedulerOutput(BaseOutput):
5151def betas_for_alpha_bar (
5252 num_diffusion_timesteps : int ,
5353 max_beta : float = 0.999 ,
54- alpha_transform_type : str = "cosine" ,
54+ alpha_transform_type : Literal [ "cosine" , "exp" ] = "cosine" ,
5555) -> torch .Tensor :
5656 """
5757 Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -61,14 +61,15 @@ def betas_for_alpha_bar(
6161 to that part of the diffusion process.
6262
6363 Args:
64- num_diffusion_timesteps (`int`): the number of betas to produce.
65- max_beta (`float`): the maximum beta to use; use values lower than 1 to
66- prevent singularities.
67- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
68- Choose from `cosine` or `exp`
64+ num_diffusion_timesteps (`int`):
65+ The number of betas to produce.
66+ max_beta (`float`, defaults to 0.999):
67+ The maximum beta to use; use values lower than 1 to prevent singularities.
68+ alpha_transform_type (`Literal["cosine", "exp"]`, defaults to `"cosine"`):
69+ The type of noise schedule for alpha_bar. Must be one of `"cosine"` or `"exp"`.
6970
7071 Returns:
71- betas ( `torch.Tensor`): the betas used by the scheduler to step the model outputs
72+ `torch.Tensor`: The betas used by the scheduler to step the model outputs.
7273 """
7374 if alpha_transform_type == "cosine" :
7475
@@ -141,9 +142,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
141142 The starting `beta` value of inference.
142143 beta_end (`float`, defaults to 0.02):
143144 The final `beta` value.
144- beta_schedule (`str `, defaults to `"linear"`):
145- The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
146- ` linear`, `scaled_linear`, or `squaredcos_cap_v2`.
145+ beta_schedule (`Literal["linear", "scaled_linear", "squaredcos_cap_v2"]`, defaults to `"linear"`):
146+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Must be one
147+ of `"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`.
147148 trained_betas (`np.ndarray`, *optional*):
148149 Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
149150 clip_sample (`bool`, defaults to `True`):
@@ -156,9 +157,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
156157 otherwise it uses the alpha value at step 0.
157158 steps_offset (`int`, defaults to 0):
158159 An offset added to the inference steps, as required by some model families.
159- prediction_type (`str `, defaults to `epsilon`, *optional* ):
160- Prediction type of the scheduler function; can be ` epsilon` (predicts the noise of the diffusion process),
161- ` sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
160+ prediction_type (`Literal["epsilon", "sample", "v_prediction"]`, defaults to `"epsilon"`):
161+ Prediction type of the scheduler function. Must be one of `"epsilon"` (predicts the noise of the diffusion
162+ process), `"sample"` (directly predicts the noisy sample), or `"v_prediction"` (see section 2.4 of [Imagen
162163 Video](https://imagen.research.google/video/paper.pdf) paper).
163164 thresholding (`bool`, defaults to `False`):
164165 Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
@@ -167,9 +168,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
167168 The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
168169 sample_max_value (`float`, defaults to 1.0):
169170 The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
170- timestep_spacing (`str`, defaults to `"leading"`):
171- The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
172- Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
171+ timestep_spacing (`Literal["leading", "trailing", "linspace"]`, defaults to `"leading"`):
172+ The way the timesteps should be scaled. Must be one of `"leading"`, `"trailing"`, or `"linspace"`. Refer to
173+ Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are
174+ Flawed](https://huggingface.co/papers/2305.08891) for more information.
173175 rescale_betas_zero_snr (`bool`, defaults to `False`):
174176 Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
175177 dark samples instead of limiting it to samples with medium brightness. Loosely related to
@@ -185,17 +187,17 @@ def __init__(
185187 num_train_timesteps : int = 1000 ,
186188 beta_start : float = 0.0001 ,
187189 beta_end : float = 0.02 ,
188- beta_schedule : str = "linear" ,
190+ beta_schedule : Literal [ "linear" , "scaled_linear" , "squaredcos_cap_v2" ] = "linear" ,
189191 trained_betas : Optional [Union [np .ndarray , List [float ]]] = None ,
190192 clip_sample : bool = True ,
191193 set_alpha_to_one : bool = True ,
192194 steps_offset : int = 0 ,
193- prediction_type : str = "epsilon" ,
195+ prediction_type : Literal [ "epsilon" , "sample" , "v_prediction" ] = "epsilon" ,
194196 thresholding : bool = False ,
195197 dynamic_thresholding_ratio : float = 0.995 ,
196198 clip_sample_range : float = 1.0 ,
197199 sample_max_value : float = 1.0 ,
198- timestep_spacing : str = "leading" ,
200+ timestep_spacing : Literal [ "leading" , "trailing" , "linspace" ] = "leading" ,
199201 rescale_betas_zero_snr : bool = False ,
200202 ):
201203 if trained_betas is not None :
0 commit comments