From 4d30ff11e3ecc5eab9d10fc603b090d950d49613 Mon Sep 17 00:00:00 2001 From: Vishaal Udandarao Date: Wed, 10 May 2023 13:04:06 +0530 Subject: [PATCH] Add cache_dir impl --- models/imagebind_model.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/models/imagebind_model.py b/models/imagebind_model.py index 395aabf4..a93aeaf1 100644 --- a/models/imagebind_model.py +++ b/models/imagebind_model.py @@ -487,7 +487,7 @@ def forward(self, inputs): return outputs -def imagebind_huge(pretrained=False): +def imagebind_huge(pretrained=False, cache_dir=None): model = ImageBindModel( vision_embed_dim=1280, vision_num_blocks=32, @@ -501,17 +501,20 @@ def imagebind_huge(pretrained=False): ) if pretrained: - if not os.path.exists(".checkpoints/imagebind_huge.pth"): + cache_dir = cache_dir if cache_dir is not None else ".checkpoints" + ckpt_path = "{}/imagebind_huge.pth".format(cache_dir) + + if not os.path.exists(ckpt_path): print( - "Downloading imagebind weights to .checkpoints/imagebind_huge.pth ..." + "Downloading imagebind weights to {} ...".format(ckpt_path) ) - os.makedirs(".checkpoints", exist_ok=True) + os.makedirs(cache_dir, exist_ok=True) torch.hub.download_url_to_file( "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth", - ".checkpoints/imagebind_huge.pth", + ckpt_path, progress=True, ) - model.load_state_dict(torch.load(".checkpoints/imagebind_huge.pth")) + model.load_state_dict(torch.load(ckpt_path)) return model