diff --git a/classy_vision/models/classy_model.py b/classy_vision/models/classy_model.py index fb18250ba0..b4b1eab516 100644 --- a/classy_vision/models/classy_model.py +++ b/classy_vision/models/classy_model.py @@ -201,7 +201,7 @@ def from_checkpoint(cls, checkpoint): model.set_classy_state(checkpoint["classy_state_dict"]["base_model"]) return model - def get_classy_state(self, deep_copy=False): + def state_dict(self, deep_copy=False): """Get the state of the ClassyModel. The returned state is used for checkpointing. @@ -222,7 +222,7 @@ def get_classy_state(self, deep_copy=False): # as the trunk state. If the model doesn't have heads attached, all of the # model's state lives in the trunk. self.clear_heads() - trunk_state_dict = self.state_dict() + trunk_state_dict = super().state_dict() self.set_heads(attached_heads) head_state_dict = {} @@ -252,7 +252,7 @@ def load_head_states(self, state, strict=True): for head_name, head_state in head_states.items(): self._heads[block_name][head_name].load_state_dict(head_state, strict) - def set_classy_state(self, state, strict=True): + def load_state_dict(self, state, strict=True): """Set the state of the ClassyModel. Args: @@ -270,11 +270,17 @@ def set_classy_state(self, state, strict=True): # fetched / set when there are no blocks attached. attached_heads = self.get_heads() self.clear_heads() - self.load_state_dict(state["model"]["trunk"], strict) + super().load_state_dict(state["model"]["trunk"], strict) # set the heads back again self.set_heads(attached_heads) + def get_classy_state(self, deep_copy=False): + return self.state_dict(deep_copy=deep_copy) + + def set_classy_state(self, state, strict=True): + self.load_state_dict(state, strict=strict) + def forward(self, x): """ Perform computation of blocks in the order define in get_blocks.