From e833da0f5652cd7964ce48680d8ad83772a3a674 Mon Sep 17 00:00:00 2001 From: Laura Gustafson Date: Thu, 25 Mar 2021 11:47:44 -0700 Subject: [PATCH] Support freezing model anywhere in fine tuning Summary: Add support for freezing the model anywhere in the fine tuning task. Users can specify a specific module to freeze the model until in finetuning. Functionality is useful for situations, like the FixRes paper, where both the head and the last batch norm layer are trained during fine tuning. Example fblearner run using freeze_until to freeze the trunk model: f259575699 Example fblearner run using freeze_until to unfreeze the last batchnorm and head: f259575306 - Adds new config option `freeze_until` to specify what point to freeze the model to. Options are `head` or the name of a module in the model. The model will be frozen until but not including that module and unfrozen at that point onwards. `freeze_until: 'head'` has the same functionality as `freeze_trunk: true`. - Adds documentation for fine tuning task Differential Revision: D27199092 fbshipit-source-id: b12dc00563da45806317f60e2abbcc4237bec94c --- classy_vision/tasks/fine_tuning_task.py | 87 +++++++++++++++++++---- test/generic/utils.py | 9 ++- test/tasks_fine_tuning_task_test.py | 91 ++++++++++++++++++++++--- 3 files changed, 164 insertions(+), 23 deletions(-) diff --git a/classy_vision/tasks/fine_tuning_task.py b/classy_vision/tasks/fine_tuning_task.py index bd30e29c7d..62eb53a5f4 100644 --- a/classy_vision/tasks/fine_tuning_task.py +++ b/classy_vision/tasks/fine_tuning_task.py @@ -4,7 +4,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Dict +import warnings +from enum import Enum +from typing import Any, Callable, Dict, Union from classy_vision.generic.util import ( load_and_broadcast_checkpoint, @@ -13,15 +15,42 @@ from classy_vision.tasks import ClassificationTask, register_task +class FreezeUntil(Enum): + """ + Enum for a pre-specified point to freeze the classy model unitl. + + Attributes: + HEAD (str): Freeze the model unitl the classy head + """ + + HEAD = "head" + + def __eq__(self, other: str): + return other.lower() == self.value + + @register_task("fine_tuning") class FineTuningTask(ClassificationTask): + """Finetuning training task. + + This task encapsultates all of the components and steps needed to + fine-tune a classifier using a :class:`classy_vision.trainer.ClassyTrainer`. + + :var pretrained_checkpoint_path: String path to pretrained model + :var reset_heads: bool. Whether or not to reset the model heads during finetuning. + :var freeze_until: optional string. If specified, must be a string name of a module within + the model. Finetuning will freeze the model up to this module. Model weights will + only be trainable from this modeule onwards, always including the head. To freeze the + trunk model, specify 'head' as the un-freeze point. + """ + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.pretrained_checkpoint_dict = None self.pretrained_checkpoint_path = None self.pretrained_checkpoint_load_strict = True self.reset_heads = False - self.freeze_trunk = False + self.freeze_until = None @classmethod def from_config(cls, config: Dict[str, Any]) -> "FineTuningTask": @@ -44,7 +73,13 @@ def from_config(cls, config: Dict[str, Any]) -> "FineTuningTask": ) task.set_reset_heads(config.get("reset_heads", False)) - task.set_freeze_trunk(config.get("freeze_trunk", False)) + assert ( + "freeze_trunk" not in config or "freeze_until" not in config + ), "Config options 'freeze_trunk' and 'freeze_until' cannot both be specified" + if "freeze_trunk" in config: + task.set_freeze_trunk(config.get("freeze_trunk", False)) + else: + task.set_freeze_until(config.get("freeze_until", None)) return task def set_pretrained_checkpoint(self, checkpoint_path: str) -> "FineTuningTask": @@ -68,22 +103,46 @@ def set_reset_heads(self, reset_heads: bool) -> "FineTuningTask": return self def set_freeze_trunk(self, freeze_trunk: bool) -> "FineTuningTask": - self.freeze_trunk = freeze_trunk + if freeze_trunk: + self.freeze_until = FreezeUntil.HEAD.value + warnings.warn( + "Congig option freeze_trunk has been deprecated. " + "Use \"freeze_until:'head'\" instead", + DeprecationWarning, + ) + + return self + + def set_freeze_until(self, freeze_until: Union[str, None]) -> "FineTuningTask": + self.freeze_until = freeze_until return self def _set_model_train_mode(self): phase = self.phases[self.phase_idx] self.loss.train(phase["train"]) - if self.freeze_trunk: + if self.freeze_until is not None: # convert all the sub-modules to the eval mode, except the heads self.base_model.eval() - for heads in self.base_model.get_heads().values(): - for h in heads: - h.train(phase["train"]) + self._apply_to_nonfrozen(lambda x: x.train(phase["train"])) else: self.base_model.train(phase["train"]) + def _apply_to_nonfrozen(self, callable: Callable[..., Any]) -> None: + for heads in self.base_model.get_heads().values(): + for h in heads: + callable(h) + if not self.freeze_until == FreezeUntil.HEAD: + unfrozen_module = False + for name, module in self.base_model.named_modules(): + if name == self.freeze_until: + unfrozen_module = True + if unfrozen_module: + callable(module) + assert ( + unfrozen_module + ), f"Freeze until point {self.freeze_until} not found in model" + def prepare(self) -> None: super().prepare() if self.checkpoint_dict is None: @@ -109,15 +168,17 @@ def prepare(self) -> None: state_load_success ), "Update classy state from pretrained checkpoint was unsuccessful." - if self.freeze_trunk: + if self.freeze_until is not None: # do not track gradients for all the parameters in the model except # for the parameters in the heads for param in self.base_model.parameters(): param.requires_grad = False - for heads in self.base_model.get_heads().values(): - for h in heads: - for param in h.parameters(): - param.requires_grad = True + + def _set_requires_grad_true(x): + for param in x.parameters(): + param.requires_grad = True + + self._apply_to_nonfrozen(_set_requires_grad_true) # re-create ddp model self.distributed_model = None self.init_distributed_data_parallel_model() diff --git a/test/generic/utils.py b/test/generic/utils.py index c05418cba9..d126748f65 100644 --- a/test/generic/utils.py +++ b/test/generic/utils.py @@ -215,8 +215,15 @@ def recursive_unpack(batch): raise TypeError("Unexpected type %s passed to unpack" % type(batch)) -def compare_model_state(test_fixture, state, state2, check_heads=True): +def compare_model_state( + test_fixture, state, state2, check_heads=True, state_changed_params=() +): for k in state["model"]["trunk"].keys(): + if k in state_changed_params: + test_fixture.assertFalse( + torch.allclose(state["model"]["trunk"][k], state2["model"]["trunk"][k]) + ) + continue if not torch.allclose(state["model"]["trunk"][k], state2["model"]["trunk"][k]): print(k, state["model"]["trunk"][k], state2["model"]["trunk"][k]) test_fixture.assertTrue( diff --git a/test/tasks_fine_tuning_task_test.py b/test/tasks_fine_tuning_task_test.py index b4a444d64b..7b8f48f795 100644 --- a/test/tasks_fine_tuning_task_test.py +++ b/test/tasks_fine_tuning_task_test.py @@ -40,13 +40,47 @@ def forward(self, x, target): class TestFineTuningTask(unittest.TestCase): - def _compare_model_state(self, state_1, state_2, check_heads=True): - return compare_model_state(self, state_1, state_2, check_heads=check_heads) + def _compare_model_state( + self, state_1, state_2, check_heads=True, state_changed_params=() + ): + return compare_model_state( + self, + state_1, + state_2, + check_heads=check_heads, + state_changed_params=state_changed_params, + ) def _compare_state_dict(self, state_1, state_2, check_heads=True): for k in state_1.keys(): self.assertTrue(torch.allclose(state_1[k].cpu(), state_2[k].cpu())) + def _get_unfrezee_points_to_unfrozen_params(self): + return { + "blocks.0.block0-0._module.downsample.1": ( + "blocks.0.block0-0.downsample.1.weight", + "blocks.0.block0-0.downsample.1.bias", + "blocks.0.block0-0.downsample.1.running_mean", + "blocks.0.block0-0.downsample.1.running_var", + "blocks.0.block0-0.downsample.1.num_batches_tracked", + ), + "blocks.0.block0-0._module.convolutional_block.6": ( + "blocks.0.block0-0.convolutional_block.6.weight", + "blocks.0.block0-0.bn.weight", + "blocks.0.block0-0.bn.bias", + "blocks.0.block0-0.bn.running_mean", + "blocks.0.block0-0.bn.running_var", + "blocks.0.block0-0.bn.num_batches_tracked", + "blocks.0.block0-0.downsample.0.weight", + "blocks.0.block0-0.downsample.1.weight", + "blocks.0.block0-0.downsample.1.bias", + "blocks.0.block0-0.downsample.1.running_mean", + "blocks.0.block0-0.downsample.1.running_var", + "blocks.0.block0-0.downsample.1.num_batches_tracked", + ), + "head": (), + } + def _get_fine_tuning_config( self, head_num_classes=100, pretrained_checkpoint=False ): @@ -152,19 +186,22 @@ def test_train(self): trainer = LocalTrainer() trainer.train(pre_train_task) checkpoint = get_checkpoint_dict(pre_train_task, {}) - + unfreeze_points = self._get_unfrezee_points_to_unfrozen_params() + unfreeze_options = list(unfreeze_points) + [None] for reset_heads, heads_num_classes in [(False, 100), (True, 20)]: - for freeze_trunk in [True, False]: - fine_tuning_config = self._get_fine_tuning_config( - head_num_classes=heads_num_classes + for unfreeze_point in unfreeze_options: + fine_tuning_config = copy.deepcopy( + self._get_fine_tuning_config(head_num_classes=heads_num_classes) ) + # Extra epochs helps ensure that unfrozen parameters change value + fine_tuning_config["num_epochs"] = 4 fine_tuning_task = build_task(fine_tuning_config) fine_tuning_task = ( fine_tuning_task._set_pretrained_checkpoint_dict( copy.deepcopy(checkpoint) ) .set_reset_heads(reset_heads) - .set_freeze_trunk(freeze_trunk) + .set_freeze_until(unfreeze_point) ) # run in test mode to compare the model state fine_tuning_task.set_test_only(True) @@ -177,12 +214,14 @@ def test_train(self): # run in train mode to check accuracy fine_tuning_task.set_test_only(False) trainer.train(fine_tuning_task) - if freeze_trunk: - # if trunk is frozen the states should be the same + if unfreeze_point is not None: + # check that expected part of model is frozen + # and unfrozen part isn't frozen self._compare_model_state( pre_train_task.model.get_classy_state(), fine_tuning_task.model.get_classy_state(), check_heads=False, + state_changed_params=unfreeze_points[unfreeze_point], ) else: # trunk isn't frozen, the states should be different @@ -196,6 +235,40 @@ def test_train(self): accuracy = fine_tuning_task.meters[0].value["top_1"] self.assertAlmostEqual(accuracy, 1.0) + def test_freeze_trunk_backwards_compatability(self): + pre_train_config = self._get_pre_train_config(head_num_classes=100) + pre_train_task = build_task(pre_train_config) + trainer = LocalTrainer() + trainer.train(pre_train_task) + checkpoint = get_checkpoint_dict(pre_train_task, {}) + for reset_heads, heads_num_classes in [(False, 100), (True, 20)]: + fine_tuning_config = copy.deepcopy( + self._get_fine_tuning_config(head_num_classes=heads_num_classes) + ) + fine_tuning_config["freeze_trunk"] = True + with self.assertWarns(DeprecationWarning): + fine_tuning_task = build_task(fine_tuning_config) + fine_tuning_task = fine_tuning_task._set_pretrained_checkpoint_dict( + copy.deepcopy(checkpoint) + ).set_reset_heads(reset_heads) + fine_tuning_task.set_test_only(True) + trainer.train(fine_tuning_task) + self._compare_model_state( + pre_train_task.model.get_classy_state(), + fine_tuning_task.model.get_classy_state(), + check_heads=not reset_heads, + ) + # run in train mode to check accuracy + fine_tuning_task.set_test_only(False) + trainer.train(fine_tuning_task) + self._compare_model_state( + pre_train_task.model.get_classy_state(), + fine_tuning_task.model.get_classy_state(), + check_heads=False, + ) + accuracy = fine_tuning_task.meters[0].value["top_1"] + self.assertAlmostEqual(accuracy, 1.0) + def test_train_parametric_loss(self): heads_num_classes = 100 pre_train_config = self._get_pre_train_config(