diff --git a/imagebind/models/imagebind_model.py b/imagebind/models/imagebind_model.py index c560945f..e564fdea 100644 --- a/imagebind/models/imagebind_model.py +++ b/imagebind/models/imagebind_model.py @@ -490,17 +490,18 @@ def imagebind_huge(pretrained=False): ) if pretrained: - if not os.path.exists(".checkpoints/imagebind_huge.pth"): + cache_dir = os.path.expanduser("~/.cache/imagebind_checkpoints") + if not os.path.exists("%s/imagebind_huge.pth" % (cache_dir)): print( - "Downloading imagebind weights to .checkpoints/imagebind_huge.pth ..." + "Downloading imagebind weights to %s/imagebind_huge.pth ..." % (cache_dir) ) - os.makedirs(".checkpoints", exist_ok=True) + os.makedirs(cache_dir, exist_ok=True) torch.hub.download_url_to_file( "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth", - ".checkpoints/imagebind_huge.pth", + "%s/imagebind_huge.pth" % (cache_dir), progress=True, ) - model.load_state_dict(torch.load(".checkpoints/imagebind_huge.pth")) + model.load_state_dict(torch.load("%s/imagebind_huge.pth" % (cache_dir))) return model