From b73871ed9ce0c9aa44878b301221479e1d4b1c07 Mon Sep 17 00:00:00 2001 From: Mannat Singh Date: Tue, 23 Mar 2021 21:06:23 -0700 Subject: [PATCH] Model state should support PyTorch API (#727) Summary: Pull Request resolved: https://github.com/facebookresearch/ClassyVision/pull/727 Classy Models should work like regular PyTorch models. The `{get, set}_classy_state` functions for state are the only blockers which this diff fixes by moving over to `state_dict` and `load_state_dict`. `{get, set}_classy_state` will still work for backwards compatibility, but will call the PyTorch functions directly. Differential Revision: D25213283 fbshipit-source-id: 3cd64f530de83574174884d3b8848a1f6003f854 --- classy_vision/models/classy_model.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/classy_vision/models/classy_model.py b/classy_vision/models/classy_model.py index fb18250ba0..b4b1eab516 100644 --- a/classy_vision/models/classy_model.py +++ b/classy_vision/models/classy_model.py @@ -201,7 +201,7 @@ def from_checkpoint(cls, checkpoint): model.set_classy_state(checkpoint["classy_state_dict"]["base_model"]) return model - def get_classy_state(self, deep_copy=False): + def state_dict(self, deep_copy=False): """Get the state of the ClassyModel. The returned state is used for checkpointing. @@ -222,7 +222,7 @@ def get_classy_state(self, deep_copy=False): # as the trunk state. If the model doesn't have heads attached, all of the # model's state lives in the trunk. self.clear_heads() - trunk_state_dict = self.state_dict() + trunk_state_dict = super().state_dict() self.set_heads(attached_heads) head_state_dict = {} @@ -252,7 +252,7 @@ def load_head_states(self, state, strict=True): for head_name, head_state in head_states.items(): self._heads[block_name][head_name].load_state_dict(head_state, strict) - def set_classy_state(self, state, strict=True): + def load_state_dict(self, state, strict=True): """Set the state of the ClassyModel. Args: @@ -270,11 +270,17 @@ def set_classy_state(self, state, strict=True): # fetched / set when there are no blocks attached. attached_heads = self.get_heads() self.clear_heads() - self.load_state_dict(state["model"]["trunk"], strict) + super().load_state_dict(state["model"]["trunk"], strict) # set the heads back again self.set_heads(attached_heads) + def get_classy_state(self, deep_copy=False): + return self.state_dict(deep_copy=deep_copy) + + def set_classy_state(self, state, strict=True): + self.load_state_dict(state, strict=strict) + def forward(self, x): """ Perform computation of blocks in the order define in get_blocks.