Skip to content

Commit 4d5bf24

Browse files
Refactor
1 parent 42a38f7 commit 4d5bf24

File tree

1 file changed

+11
-15
lines changed

1 file changed

+11
-15
lines changed

cortex/_lib/optimizer.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -146,26 +146,22 @@ def setup(model, optimizer='Adam', learning_rate=1.e-4,
146146
for p in params:
147147
p.requires_grad = True
148148

149-
# Learning rates
150-
if isinstance(learning_rate, dict):
151-
eta = learning_rate[network_key]
152-
else:
153-
eta = learning_rate
149+
def extract_value(dict_or_value, default=None):
150+
if isinstance(dict_or_value, dict):
151+
return dict_or_value.get(network_key, default)
152+
return dict_or_value
154153

154+
# Learning rates
155+
network_lr = extract_value(learning_rate)
155156
# Weight decay
156-
if isinstance(weight_decay, dict):
157-
wd = weight_decay.get(network_key, 0)
158-
else:
159-
wd = weight_decay
160-
161-
if isinstance(clipping, dict):
162-
cl = clipping.get(network_key, None)
163-
else:
164-
cl = clipping
157+
network_wd = extract_value(weight_decay, 0)
158+
# Gradient clipping
159+
network_cl = extract_value(clipping)
165160

166161
# Update the optimizer options
167162
optimizer_options_ = dict((k, v) for k, v in optimizer_options.items())
168-
optimizer_options_.update(weight_decay=wd, clipping=cl, lr=eta)
163+
optimizer_options_.update(
164+
weight_decay=network_wd, clipping=network_cl, lr=network_lr)
169165

170166
if network_key in model_optimizer_options:
171167
optimizer_options_.update(

0 commit comments

Comments
 (0)