@@ -146,26 +146,22 @@ def setup(model, optimizer='Adam', learning_rate=1.e-4,
         for p in params:
             p.requires_grad = True
 
-        # Learning rates
-        if isinstance(learning_rate, dict):
-            eta = learning_rate[network_key]
-        else:
-            eta = learning_rate
+        def extract_value(dict_or_value, default=None):
+            if isinstance(dict_or_value, dict):
+                return dict_or_value.get(network_key, default)
+            return dict_or_value
 
+        # Learning rates
+        network_lr = extract_value(learning_rate)
         # Weight decay
-        if isinstance(weight_decay, dict):
-            wd = weight_decay.get(network_key, 0)
-        else:
-            wd = weight_decay
-
-        if isinstance(clipping, dict):
-            cl = clipping.get(network_key, None)
-        else:
-            cl = clipping
+        network_wd = extract_value(weight_decay, 0)
+        # Gradient clipping
+        network_cl = extract_value(clipping)
 
         # Update the optimizer options
         optimizer_options_ = dict((k, v) for k, v in optimizer_options.items())
-        optimizer_options_.update(weight_decay=wd, clipping=cl, lr=eta)
+        optimizer_options_.update(
+            weight_decay=network_wd, clipping=network_cl, lr=network_lr)
 
         if network_key in model_optimizer_options:
             optimizer_options_.update(
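
For readers skimming the diff, here is a minimal, self-contained sketch of the extract_value pattern it introduces: a dict-valued option is resolved per network key, while a scalar applies to every network. The network keys ('encoder', 'decoder') and sample values below are hypothetical, and the helper is reproduced outside the setup() loop purely for illustration.

    # Standalone sketch of the extract_value pattern from the diff above.
    # The network keys ('encoder', 'decoder') and values are hypothetical.

    def make_extract_value(network_key):
        """Build a per-network option resolver, mirroring the closure in setup()."""
        def extract_value(dict_or_value, default=None):
            # A dict maps network names to values; anything else applies globally.
            if isinstance(dict_or_value, dict):
                return dict_or_value.get(network_key, default)
            return dict_or_value
        return extract_value

    extract_value = make_extract_value('encoder')

    extract_value(1.e-4)                  # -> 0.0001 (scalar applies to all networks)
    extract_value({'encoder': 1.e-3})     # -> 0.001  (per-network override)
    extract_value({'decoder': 1.e-3})     # -> None   (no entry, default None)
    extract_value({'decoder': 1.e-3}, 0)  # -> 0      (no entry, explicit default)

Defining the helper as a closure over network_key keeps the three option lookups (learning rate, weight decay, clipping) on one line each instead of repeating the isinstance branch three times.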