diff --git a/sam3/model_builder.py b/sam3/model_builder.py index 103b324d..246266d0 100644 --- a/sam3/model_builder.py +++ b/sam3/model_builder.py @@ -540,6 +540,13 @@ def _load_checkpoint(model, checkpoint_path): if "tracker" in k } ) + sam3_image_ckpt.update( + { + k.replace("detector.backbone.", "inst_interactive_predictor.model.backbone."): v + for k, v in ckpt.items() + if k.startswith("detector.backbone.") + } + ) missing_keys, _ = model.load_state_dict(sam3_image_ckpt, strict=False) if len(missing_keys) > 0: print( @@ -615,7 +622,12 @@ def build_sam3_image_model( # Create geometry encoder input_geometry_encoder = _create_geometry_encoder() if enable_inst_interactivity: - sam3_pvs_base = build_tracker(apply_temporal_disambiguation=False) + # Build with backbone for point-prompting + sam3_pvs_base = build_tracker( + apply_temporal_disambiguation=False, + with_backbone=True, + compile_mode=compile_mode, + ) inst_predictor = SAM3InteractiveImagePredictor(sam3_pvs_base) else: inst_predictor = None