@@ -142,26 +142,22 @@ def setup(model, optimizer='Adam', learning_rate=1.e-4,
         for p in params:
             p.requires_grad = True

-        # Learning rates
-        if isinstance(learning_rate, dict):
-            eta = learning_rate[network_key]
-        else:
-            eta = learning_rate
+        def extract_value(dict_or_value, default=None):
+            if isinstance(dict_or_value, dict):
+                return dict_or_value.get(network_key, default)
+            return dict_or_value

+        # Learning rates
+        network_lr = extract_value(learning_rate)
         # Weight decay
-        if isinstance(weight_decay, dict):
-            wd = weight_decay.get(network_key, 0)
-        else:
-            wd = weight_decay
-
-        if isinstance(clipping, dict):
-            cl = clipping.get(network_key, None)
-        else:
-            cl = clipping
+        network_wd = extract_value(weight_decay, 0)
+        # Gradient clipping
+        network_cl = extract_value(clipping)

         # Update the optimizer options
         optimizer_options_ = dict((k, v) for k, v in optimizer_options.items())
-        optimizer_options_.update(weight_decay=wd, clipping=cl, lr=eta)
+        optimizer_options_.update(
+            weight_decay=network_wd, clipping=network_cl, lr=network_lr)

         if network_key in model_optimizer_options:
             optimizer_options_.update(
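For context, a minimal standalone sketch (not part of the diff) of the dict-or-scalar resolution pattern that the new extract_value helper implements: a hyperparameter may be given either as a single value applied to every network, or as a dict keyed by network name with an optional fallback default. In the PR the helper captures network_key as a closure variable; here it is passed explicitly, and the network names are hypothetical.

    def extract_value(dict_or_value, network_key, default=None):
        """Return the per-network value if a dict was given, else the value itself."""
        if isinstance(dict_or_value, dict):
            return dict_or_value.get(network_key, default)
        return dict_or_value

    # A scalar applies uniformly to every network.
    assert extract_value(1e-4, 'encoder') == 1e-4
    # A dict selects the per-network value, falling back to the default when the key is missing.
    assert extract_value({'encoder': 1e-3}, 'encoder') == 1e-3
    assert extract_value({'encoder': 1e-3}, 'decoder', default=0) == 0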