[Pytorch] Pytorch only schedulers#534
Conversation
|
The documentation is not available anymore as the PR was closed or merged. |
@kashif, I hope it made sense. Please let me know if that's not what you expected me to check as part of the |
anton-l
left a comment
There was a problem hiding this comment.
I like the changes overall, and the tests are mostly running smoothly, thank you @kashif!
Getting a device mismatch due to self.sigmas always being on cpu here:
> noisy_samples = original_samples + noise * sigma
E RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
../src/diffusers/schedulers/scheduling_lms_discrete.py:209: RuntimeErrorSo scheduler.to(device) might have to be implemented. But I haven't thought about a solution too much, so maybe you have a workaround.
|
@patrickvonplaten @patil-suraj could you give this PR a quick review if you have time? It'll be easier to rebase #637 if this is merged first. |
We already have a move immediately before: I would suggest something like this in this case: sigmas = self.sigmas.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)Unless we want to do computation in CPU as I think we do in the other schedulers. |
|
@pcuenca fixed |
|
|
||
| def __init__(self, unet, scheduler): | ||
| super().__init__() | ||
| scheduler = scheduler.set_format("pt") |
patrickvonplaten
left a comment
There was a problem hiding this comment.
Once tests are green, let's merge this one as it's quite important :-)
|
Great work @kashif! |
* pytorch only schedulers * fix style * remove match_shape * pytorch only ddpm * remove SchedulerMixin * remove numpy from karras_ve * fix types * remove numpy from lms_discrete * remove numpy from pndm * fix typo * remove mixin and numpy from sde_vp and ve * remove remaining tensor_format * fix style * sigmas has to be torch tensor * removed set_format in readme * remove set format from docs * remove set_format from pipelines * update tests * fix typo * continue to use mixin * fix imports * removed unsed imports * match shape instead of assuming image shapes * remove import typo * update call to add_noise * use math instead of numpy * fix t_index * removed commented out numpy tests * timesteps needs to be discrete * cast timesteps to int in flax scheduler too * fix device mismatch issue * small fix * Update src/diffusers/schedulers/scheduling_pndm.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Remove numpy clauses from schedulers to make them pytorch only and fixed use of timesteps in pipelines