Skip to content

Commit 26f0ca1

Browse files
committed
Refine type hints and docstrings in scheduling_ddim.py
- Update parameter types to use Literal for specific string options
- Enhance docstring descriptions for clarity and consistency
- Ensure all parameters have appropriate type annotations and defaults
1 parent 9375844 commit 26f0ca1

File tree

1 file changed

+22
-20
lines changed

1 file changed

+22
-20
lines changed

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 22 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717

1818
import math
1919
from dataclasses import dataclass
20-
from typing import List, Optional, Tuple, Union
20+
from typing import List, Literal, Optional, Tuple, Union
2121

2222
import numpy as np
2323
import torch
@@ -51,7 +51,7 @@ class DDIMSchedulerOutput(BaseOutput):
5151
def betas_for_alpha_bar(
5252
num_diffusion_timesteps: int,
5353
max_beta: float = 0.999,
54-
alpha_transform_type: str = "cosine",
54+
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
5555
) -> torch.Tensor:
5656
"""
5757
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -61,14 +61,15 @@ def betas_for_alpha_bar(
6161
to that part of the diffusion process.
6262
6363
Args:
64-
num_diffusion_timesteps (`int`): the number of betas to produce.
65-
max_beta (`float`): the maximum beta to use; use values lower than 1 to
66-
prevent singularities.
67-
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
68-
Choose from `cosine` or `exp`
64+
num_diffusion_timesteps (`int`):
65+
The number of betas to produce.
66+
max_beta (`float`, defaults to 0.999):
67+
The maximum beta to use; use values lower than 1 to prevent singularities.
68+
alpha_transform_type (`Literal["cosine", "exp"]`, defaults to `"cosine"`):
69+
The type of noise schedule for alpha_bar. Must be one of `"cosine"` or `"exp"`.
6970
7071
Returns:
71-
betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
72+
`torch.Tensor`: The betas used by the scheduler to step the model outputs.
7273
"""
7374
if alpha_transform_type == "cosine":
7475

@@ -141,9 +142,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
141142
The starting `beta` value of inference.
142143
beta_end (`float`, defaults to 0.02):
143144
The final `beta` value.
144-
beta_schedule (`str`, defaults to `"linear"`):
145-
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
146-
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
145+
beta_schedule (`Literal["linear", "scaled_linear", "squaredcos_cap_v2"]`, defaults to `"linear"`):
146+
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Must be one
147+
of `"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`.
147148
trained_betas (`np.ndarray`, *optional*):
148149
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
149150
clip_sample (`bool`, defaults to `True`):
@@ -156,9 +157,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
156157
otherwise it uses the alpha value at step 0.
157158
steps_offset (`int`, defaults to 0):
158159
An offset added to the inference steps, as required by some model families.
159-
prediction_type (`str`, defaults to `epsilon`, *optional*):
160-
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
161-
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
160+
prediction_type (`Literal["epsilon", "sample", "v_prediction"]`, defaults to `"epsilon"`):
161+
Prediction type of the scheduler function. Must be one of `"epsilon"` (predicts the noise of the diffusion
162+
process), `"sample"` (directly predicts the noisy sample), or `"v_prediction"` (see section 2.4 of [Imagen
162163
Video](https://imagen.research.google/video/paper.pdf) paper).
163164
thresholding (`bool`, defaults to `False`):
164165
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
@@ -167,9 +168,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
167168
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
168169
sample_max_value (`float`, defaults to 1.0):
169170
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
170-
timestep_spacing (`str`, defaults to `"leading"`):
171-
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
172-
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
171+
timestep_spacing (`Literal["leading", "trailing", "linspace"]`, defaults to `"leading"`):
172+
The way the timesteps should be scaled. Must be one of `"leading"`, `"trailing"`, or `"linspace"`. Refer to
173+
Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are
174+
Flawed](https://huggingface.co/papers/2305.08891) for more information.
173175
rescale_betas_zero_snr (`bool`, defaults to `False`):
174176
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
175177
dark samples instead of limiting it to samples with medium brightness. Loosely related to
@@ -185,17 +187,17 @@ def __init__(
185187
num_train_timesteps: int = 1000,
186188
beta_start: float = 0.0001,
187189
beta_end: float = 0.02,
188-
beta_schedule: str = "linear",
190+
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
189191
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
190192
clip_sample: bool = True,
191193
set_alpha_to_one: bool = True,
192194
steps_offset: int = 0,
193-
prediction_type: str = "epsilon",
195+
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
194196
thresholding: bool = False,
195197
dynamic_thresholding_ratio: float = 0.995,
196198
clip_sample_range: float = 1.0,
197199
sample_max_value: float = 1.0,
198-
timestep_spacing: str = "leading",
200+
timestep_spacing: Literal["leading", "trailing", "linspace"] = "leading",
199201
rescale_betas_zero_snr: bool = False,
200202
):
201203
if trained_betas is not None:

0 commit comments

Comments
 (0)