From 2f3e2a99206ce97d9097633f6c4dc4a78547e51f Mon Sep 17 00:00:00 2001 From: Zhicheng Yan Date: Mon, 19 Oct 2020 21:59:04 -0700 Subject: [PATCH] support EMA hook in standalone trainer Summary: Currently, EMA hook is not used by the classy vision standalone trainer. Thus, add an argument `ema_decay` to use it when users sets it to a positive number. Differential Revision: D24382231 fbshipit-source-id: 76863305d063662f764dc0f9dd0371b8e2bf40e8 --- classy_train.py | 3 +++ classy_vision/generic/opts.py | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/classy_train.py b/classy_train.py index e693220057..7b036cbc80 100755 --- a/classy_train.py +++ b/classy_train.py @@ -49,6 +49,7 @@ from classy_vision.generic.util import load_checkpoint, load_json from classy_vision.hooks import ( CheckpointHook, + ExponentialMovingAverageModelHook, LossLrMeterLoggingHook, ModelComplexityHook, ProfilerHook, @@ -152,6 +153,8 @@ def configure_hooks(args, config): hooks.append(ProgressBarHook()) if args.visdom_server != "": hooks.append(VisdomHook(args.visdom_server, args.visdom_port)) + if args.ema_decay > 0: + hooks.append(ExponentialMovingAverageModelHook(args.ema_decay)) return hooks diff --git a/classy_vision/generic/opts.py b/classy_vision/generic/opts.py index e4d9cc5e0c..19d98b2461 100644 --- a/classy_vision/generic/opts.py +++ b/classy_vision/generic/opts.py @@ -121,6 +121,12 @@ def add_generic_args(parser): help="""Distributed backend: either 'none' (for non-distributed runs) or 'ddp' (for distributed runs). Default none.""", ) + parser.add_argument( + "--ema_decay", + default=0, + type=float, + help="""Decay rate of model Exponential Moving Averaging""", + ) return parser