Commit 9375844

Enhance docstrings and type hints in scheduling_ddim.py
- Update parameter types and descriptions for clarity
- Improve explanations in method docstrings to align with project standards
- Add optional annotations for parameters where applicable
1 parent 67f931b commit 9375844

1 file changed: +29 -20 lines changed

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 29 additions & 20 deletions
@@ -38,7 +38,7 @@ class DDIMSchedulerOutput(BaseOutput):
         prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
             denoising loop.
-        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images, *optional*):
             The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
             `pred_original_sample` can be used to preview progress or for guidance.
     """
@@ -375,7 +375,7 @@ def step(
         sample: torch.Tensor,
         eta: float = 0.0,
         use_clipped_model_output: bool = False,
-        generator=None,
+        generator: Optional[torch.Generator] = None,
         variance_noise: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[DDIMSchedulerOutput, Tuple]:
@@ -386,20 +386,21 @@ def step(
         Args:
             model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
-            timestep (`float`):
+            timestep (`int`):
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
-            eta (`float`):
-                The weight of noise for added noise in diffusion step.
-            use_clipped_model_output (`bool`, defaults to `False`):
+            eta (`float`, *optional*, defaults to 0.0):
+                The weight of noise for added noise in diffusion step. A value of 0 corresponds to DDIM (deterministic)
+                and 1 corresponds to DDPM (fully stochastic).
+            use_clipped_model_output (`bool`, *optional*, defaults to `False`):
                 If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
                 because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
                 clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
                 `use_clipped_model_output` has no effect.
             generator (`torch.Generator`, *optional*):
-                A random number generator.
-            variance_noise (`torch.Tensor`):
+                A random number generator for reproducible sampling.
+            variance_noise (`torch.Tensor`, *optional*):
                 Alternative to generating noise with `generator` by directly providing the noise for the variance
                 itself. Useful for methods such as [`CycleDiffusion`].
             return_dict (`bool`, *optional*, defaults to `True`):
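The new `eta` description in this hunk can be illustrated with a short sketch. It assumes the noise scale from the DDIM formulation, where the standard deviation of the added noise is `eta` times the square root of `(1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)`. The helper name `ddim_sigma` and the sample values are hypothetical and are not part of the scheduler's API.

```python
import math


def ddim_sigma(alpha_prod_t: float, alpha_prod_t_prev: float, eta: float) -> float:
    """Standard deviation of the noise added in one step (illustrative helper)."""
    # Variance term of the DDIM update, written with cumulative alpha products.
    variance = (1.0 - alpha_prod_t_prev) / (1.0 - alpha_prod_t) * (1.0 - alpha_prod_t / alpha_prod_t_prev)
    # eta scales the standard deviation linearly.
    return eta * math.sqrt(variance)


# Hypothetical cumulative alpha products for the current and previous timestep.
alpha_prod_t, alpha_prod_t_prev = 0.5, 0.8

sigma_ddim = ddim_sigma(alpha_prod_t, alpha_prod_t_prev, eta=0.0)  # no noise added
sigma_half = ddim_sigma(alpha_prod_t, alpha_prod_t_prev, eta=0.5)  # partially stochastic
sigma_ddpm = ddim_sigma(alpha_prod_t, alpha_prod_t_prev, eta=1.0)  # fully stochastic
```

With `eta=0.0` the noise term vanishes, so neither `generator` nor `variance_noise` affects the result of the step; any nonzero `eta` makes the step depend on the sampled or supplied noise.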
@@ -507,19 +508,22 @@ def add_noise(
         timesteps: torch.IntTensor,
     ) -> torch.Tensor:
         """
-        Adds noise to the original samples.
+        Add noise to the original samples according to the noise magnitude at each timestep.
+
+        This implements the forward diffusion process using the formula: `noisy_sample = sqrt(alpha_prod) *
+        original_sample + sqrt(1 - alpha_prod) * noise`

         Args:
             original_samples (`torch.Tensor`):
-                The original samples to add noise to.
+                The original clean samples to which noise will be added.
             noise (`torch.Tensor`):
-                The noise to add to the original samples.
+                The noise tensor to add, typically sampled from a Gaussian distribution.
             timesteps (`torch.IntTensor`):
-                The timesteps to add noise to.
+                The timesteps indicating the noise level from the diffusion schedule.

         Returns:
             `torch.Tensor`:
-                The noisy samples.
+                The noisy samples with noise added according to the timestep schedule.
         """
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
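The forward-diffusion formula quoted in the new `add_noise` docstring can be tried out without the scheduler. The following is a minimal sketch in plain Python (lists instead of tensors), assuming a linear beta schedule from `1e-4` to `0.02` over 1000 steps; the helper `linear_alphas_cumprod` and the sample values are hypothetical and only mirror the formula, not the actual implementation.

```python
import math
import random


def linear_alphas_cumprod(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative product of (1 - beta_t) for a linear beta schedule (assumed values)."""
    alphas_cumprod, running = [], 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        running *= 1.0 - beta
        alphas_cumprod.append(running)
    return alphas_cumprod


def add_noise(original_sample, noise, alpha_prod):
    """noisy_sample = sqrt(alpha_prod) * original_sample + sqrt(1 - alpha_prod) * noise"""
    sqrt_alpha_prod = math.sqrt(alpha_prod)
    sqrt_one_minus_alpha_prod = math.sqrt(1.0 - alpha_prod)
    return [sqrt_alpha_prod * x + sqrt_one_minus_alpha_prod * n for x, n in zip(original_sample, noise)]


alphas_cumprod = linear_alphas_cumprod()
rng = random.Random(0)  # seeded so the sketch is reproducible
original = [1.0, -0.5, 0.25]
noise = [rng.gauss(0.0, 1.0) for _ in original]

slightly_noisy = add_noise(original, noise, alphas_cumprod[0])  # early timestep: mostly signal
mostly_noise = add_noise(original, noise, alphas_cumprod[-1])   # late timestep: mostly noise
```

Because `alpha_prod` decreases monotonically along the schedule, the same `noise` contributes more at later timesteps, which is what the `timesteps` argument selects in the real method.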
@@ -544,20 +548,25 @@ def add_noise(
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
     def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
         """
-        Computes the velocity of the sample. The velocity is defined as the difference between the original sample and
-        the noisy sample. See https://huggingface.co/papers/2010.02502
+        Compute the velocity prediction for v-prediction models.
+
+        The velocity is computed using the formula: `velocity = sqrt(alpha_prod) * noise - sqrt(1 - alpha_prod) *
+        sample`
+
+        This is used in v-prediction models where the model directly predicts the velocity instead of the noise or the
+        sample. See section 2.4 of Imagen Video paper: https://imagen.research.google/video/paper.pdf

         Args:
             sample (`torch.Tensor`):
-                The sample to compute the velocity of.
+                The input sample (x_t) at the current timestep.
             noise (`torch.Tensor`):
-                The noise to compute the velocity of.
+                The noise tensor corresponding to the sample.
             timesteps (`torch.IntTensor`):
-                The timesteps to compute the velocity of.
+                The timesteps at which to compute the velocity.

         Returns:
             `torch.Tensor`:
-                The velocity of the sample.
+                The velocity prediction computed from the sample and noise at the given timesteps.
         """
         # Make sure alphas_cumprod and timestep have same device and dtype as sample
         self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
@@ -577,5 +586,5 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor
         velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
         return velocity

-    def __len__(self):
+    def __len__(self) -> int:
         return self.config.num_train_timesteps
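The velocity formula from the `get_velocity` docstring in this commit can be checked with a small plain-Python sketch. It passes the clean sample as `sample`, which is an assumption about how the function is used during training (the docstring in the diff labels the argument x_t), and it shows that the clean sample and the noise can both be recovered from the noisy sample and the velocity. The helper `recover_from_velocity` and all values are hypothetical.

```python
import math


def get_velocity(sample, noise, alpha_prod):
    """velocity = sqrt(alpha_prod) * noise - sqrt(1 - alpha_prod) * sample"""
    a, b = math.sqrt(alpha_prod), math.sqrt(1.0 - alpha_prod)
    return [a * n - b * x for x, n in zip(sample, noise)]


def recover_from_velocity(noisy_sample, velocity, alpha_prod):
    """Invert the v-parameterization using a**2 + b**2 == 1 (illustrative helper)."""
    a, b = math.sqrt(alpha_prod), math.sqrt(1.0 - alpha_prod)
    pred_original = [a * xt - b * v for xt, v in zip(noisy_sample, velocity)]
    pred_noise = [b * xt + a * v for xt, v in zip(noisy_sample, velocity)]
    return pred_original, pred_noise


alpha_prod = 0.64  # hypothetical value, so sqrt(alpha_prod) is about 0.8 and sqrt(1 - alpha_prod) about 0.6
clean = [1.0, -2.0]
noise = [0.5, 0.25]

a, b = math.sqrt(alpha_prod), math.sqrt(1.0 - alpha_prod)
noisy = [a * x + b * n for x, n in zip(clean, noise)]  # forward diffusion of the clean sample

velocity = get_velocity(clean, noise, alpha_prod)
pred_original, pred_noise = recover_from_velocity(noisy, velocity, alpha_prod)
```

The round trip works because the two coefficients are the cosine and sine of the same angle, so predicting the velocity carries the same information as predicting either the noise or the clean sample.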
