diff --git a/scripts/txt2img.py b/scripts/txt2img.py index dc377525e..e66958870 100644 --- a/scripts/txt2img.py +++ b/scripts/txt2img.py @@ -3,7 +3,7 @@ import numpy as np from omegaconf import OmegaConf from PIL import Image -from tqdm import tqdm, trange +from tqdm.auto import tqdm, trange from einops import rearrange from torchvision.utils import make_grid @@ -13,7 +13,7 @@ def load_model_from_config(config, ckpt, verbose=False): print(f"Loading model from {ckpt}") - pl_sd = torch.load(ckpt, map_location="cpu") + pl_sd = torch.load(ckpt, map_location="cuda") sd = pl_sd["state_dict"] model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd, strict=False)