Skip to content

Commit c5c8711

Browse files
Refactor
1 parent ecdbd05 commit c5c8711

File tree

1 file changed

+11
-15
lines changed

1 file changed

+11
-15
lines changed

cortex/_lib/optimizer.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -142,26 +142,22 @@ def setup(model, optimizer='Adam', learning_rate=1.e-4,
142142
for p in params:
143143
p.requires_grad = True
144144

145-
# Learning rates
146-
if isinstance(learning_rate, dict):
147-
eta = learning_rate[network_key]
148-
else:
149-
eta = learning_rate
145+
def extract_value(dict_or_value, default=None):
146+
if isinstance(dict_or_value, dict):
147+
return dict_or_value.get(network_key, default)
148+
return dict_or_value
150149

150+
# Learning rates
151+
network_lr = extract_value(learning_rate)
151152
# Weight decay
152-
if isinstance(weight_decay, dict):
153-
wd = weight_decay.get(network_key, 0)
154-
else:
155-
wd = weight_decay
156-
157-
if isinstance(clipping, dict):
158-
cl = clipping.get(network_key, None)
159-
else:
160-
cl = clipping
153+
network_wd = extract_value(weight_decay, 0)
154+
# Gradient clipping
155+
network_cl = extract_value(clipping)
161156

162157
# Update the optimizer options
163158
optimizer_options_ = dict((k, v) for k, v in optimizer_options.items())
164-
optimizer_options_.update(weight_decay=wd, clipping=cl, lr=eta)
159+
optimizer_options_.update(
160+
weight_decay=network_wd, clipping=network_cl, lr=network_lr)
165161

166162
if network_key in model_optimizer_options:
167163
optimizer_options_.update(

0 commit comments

Comments
 (0)