diff --git a/improved_diffusion/dist_util.py b/improved_diffusion/dist_util.py index f665604d6b..d2c509be64 100644 --- a/improved_diffusion/dist_util.py +++ b/improved_diffusion/dist_util.py @@ -69,7 +69,7 @@ def sync_params(params): """ for p in params: with th.no_grad(): - dist.broadcast(p, 0) + dist.broadcast(p.detach(), 0) def _find_free_port():