SAM2UNet-MSD.pth chkpt file does not contain key model #47

Open
msudhanshu-nus opened this issue Apr 17, 2025 · 3 comments

Comments

@msudhanshu-nus

When I load this checkpoint in train.sh, running it produces the following error:
Traceback (most recent call last):
  File "/mnt/iMVR/sudhanshu/SAM2-UNet/train.py", line 87, in <module>
    main(args)
  File "/mnt/iMVR/sudhanshu/SAM2-UNet/train.py", line 46, in main
    model = SAM2UNet(args.hiera_path)
  File "/mnt/iMVR/sudhanshu/SAM2-UNet/SAM2UNet.py", line 129, in __init__
    model = build_sam2(model_cfg, checkpoint_path)
  File "/mnt/iMVR/sudhanshu/SAM2-UNet/sam2/build_sam.py", line 36, in build_sam2
    _load_checkpoint(model, ckpt_path)
  File "/mnt/iMVR/sudhanshu/SAM2-UNet/sam2/build_sam.py", line 88, in _load_checkpoint
    sd = torch.load(ckpt_path, map_location="cpu")["model"]
KeyError: 'model'

I checked the keys of this checkpoint but could not find the key "model" in it, hence the error.
Kindly check whether the uploaded checkpoint file is correct.

Regards,
Sudhanshu.
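The traceback shows `_load_checkpoint` in `sam2/build_sam.py` indexing the loaded object with `["model"]`, so that path only accepts files that wrap their weights under a `"model"` key, as SAM2's own checkpoints (such as sam2_hiera_large.pt) evidently do. The maintainer's replies below load SAM2UNet-MSD.pth straight into `load_state_dict`, which suggests it is a bare state dict without that wrapper. A minimal sketch for checking which layout a file has, assuming only that it loads as a mapping; `checkpoint_layout` and `inspect_file` are hypothetical helper names, not part of the repo:

```python
from collections.abc import Mapping


def checkpoint_layout(ckpt):
    """Classify a loaded checkpoint object.

    "sam2"       -> weights wrapped under a top-level "model" key, the layout
                    that _load_checkpoint in sam2/build_sam.py indexes into.
    "state_dict" -> a flat {parameter name: tensor} mapping with no wrapper.
    """
    if not isinstance(ckpt, Mapping):
        raise TypeError(f"expected a mapping, got {type(ckpt).__name__}")
    if isinstance(ckpt.get("model"), Mapping):
        return "sam2"
    return "state_dict"


def inspect_file(path, n_keys=5):
    """Load a checkpoint on CPU and print its layout and first few top-level keys."""
    import torch  # imported lazily so checkpoint_layout works without PyTorch

    ckpt = torch.load(path, map_location="cpu")
    print("layout:", checkpoint_layout(ckpt))
    print("first keys:", list(ckpt)[:n_keys])
```

For example, `inspect_file("./SAM2UNet-MSD.pth")` should report `state_dict` if the file is a bare state dict, while a SAM2 release checkpoint should report `sam2`.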

@xiongxyowo
Collaborator

Hi, if you want to train a new model starting from the original SAM2, "args.hiera_path" should be set to the checkpoint path of SAM2 itself, such as sam2_hiera_large.pt:

# train.py
args.hiera_path = "./sam2_hiera_large.pt"
model = SAM2UNet(args.hiera_path)

If you want to test a trained SAM2-UNet model, load the SAM2-UNet weights instead, such as "SAM2UNet-MSD.pth":

# test.py
args.checkpoint = "./SAM2UNet-MSD.pth"
model = SAM2UNet()
model.load_state_dict(torch.load(args.checkpoint), strict=True)

@msudhanshu-nus
Author

I want to use SAM2UNet-MSD.pth as the starting point for fine-tuning on my own dataset. The camouflaged features already learned in SAM2UNet-MSD.pth would be useful for that further fine-tuning. Can this be done?

In other words, I want to pass SAM2UNet-MSD.pth instead of sam2_hiera_large.pt, and that is what causes the error:

# train.py
args.hiera_path = "./SAM2UNet-MSD.pth"
model = SAM2UNet(args.hiera_path)

@xiongxyowo
Collaborator

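It is possible to start training from a pre-trained SAM2-UNet. Note that when fine-tuning SAM2-UNet further, the original SAM2 weights are no longer needed, because SAM2UNet-MSD.pth already contains every parameter of the model (which is why it loads with strict=True into a SAM2UNet built without a SAM2 checkpoint):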

# train.py
# args.hiera_path = "./sam2_hiera_large.pt"
args.checkpoint = "./SAM2UNet-MSD.pth"
model = SAM2UNet()
model.load_state_dict(torch.load(args.checkpoint), strict=True)
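Because `strict=True` makes `load_state_dict` raise on any missing or unexpected parameter name, it can help to list those names explicitly when adapting train.py for fine-tuning. A minimal sketch, assuming the repo's `SAM2UNet` class can be built without arguments as in the snippet above; `compare_keys` and `load_for_finetuning` are hypothetical names, and loading with `map_location="cpu"` before moving the model to a device is a choice made here, not something the repo prescribes:

```python
def compare_keys(model_keys, ckpt_keys):
    """Return the names that would make load_state_dict(..., strict=True) fail.

    Only parameter/buffer *names* are compared; a tensor-shape mismatch on a
    shared name still raises inside load_state_dict itself.
    """
    model_keys, ckpt_keys = set(model_keys), set(ckpt_keys)
    return {
        "missing": sorted(model_keys - ckpt_keys),     # expected by the model, absent from the file
        "unexpected": sorted(ckpt_keys - model_keys),  # present in the file, unknown to the model
    }


def load_for_finetuning(path):
    """Build SAM2UNet without SAM2 weights, then load a trained SAM2-UNet checkpoint."""
    import torch
    from SAM2UNet import SAM2UNet  # the repo's model class (SAM2UNet.py)

    model = SAM2UNet()
    state_dict = torch.load(path, map_location="cpu")
    report = compare_keys(model.state_dict().keys(), state_dict.keys())
    if report["missing"] or report["unexpected"]:
        examples = (report["missing"] + report["unexpected"])[:3]
        raise RuntimeError(
            f"{len(report['missing'])} missing / {len(report['unexpected'])} "
            f"unexpected keys, e.g. {examples}"
        )
    model.load_state_dict(state_dict, strict=True)
    return model
```

The returned model can then be moved to the training device and trained as usual.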
