Skip to content

Commit 2ac80d0

Browse files
fix tests for tf 2.11 (#2783)
1 parent 1f14395 commit 2ac80d0

File tree

4 files changed

+24
-15
lines changed

4 files changed

+24
-15
lines changed

tensorflow_addons/callbacks/tests/avg_model_checkpoint_test.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,16 @@
1313
EPOCHS = 5
1414

1515

16+
def get_legacy_sgd(learning_rate):
17+
if hasattr(tf.keras.optimizers, "legacy"):
18+
return tf.keras.optimizers.legacy.SGD(learning_rate)
19+
return tf.keras.optimizers.SGD(learning_rate)
20+
21+
1622
def get_data_and_model(optimizer="moving_avg"):
1723
x = tf.random.normal([TRAIN_SAMPLES, INPUT_DIM])
1824
y = tf.random.normal([TRAIN_SAMPLES, NUM_CLASSES])
19-
moving_avg = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5)
25+
moving_avg = MovingAverage(get_legacy_sgd(2.0), average_decay=0.5)
2026
if optimizer == "moving_avg":
2127
optimizer = moving_avg
2228
inputs = keras.layers.Input(INPUT_DIM)
@@ -199,7 +205,7 @@ def test_invalid_save_freq(tmp_path):
199205

200206
def test_loss_scale_optimizer(tmp_path):
201207
test_model_filepath = str(tmp_path / "test_model.{epoch:02d}.h5")
202-
moving_avg = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5)
208+
moving_avg = MovingAverage(get_legacy_sgd(2.0), average_decay=0.5)
203209
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(moving_avg)
204210
x, y, model = get_data_and_model(optimizer)
205211
save_freq = "epoch"

tensorflow_addons/optimizers/tests/moving_average_test.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,9 +222,12 @@ def test_start_step():
222222
grads0 = tf.constant([0.1, 0.1])
223223
grads_and_vars = [(grads0, var0)]
224224

225-
opt = MovingAverage(
226-
tf.keras.optimizers.SGD(lr=1.0), average_decay=0.5, start_step=1
227-
)
225+
if hasattr(tf.keras.optimizers, "legacy"):
226+
sgd_opt = tf.keras.optimizers.legacy.SGD(lr=1.0)
227+
else:
228+
sgd_opt = tf.keras.optimizers.SGD(lr=1.0)
229+
230+
opt = MovingAverage(sgd_opt, average_decay=0.5, start_step=1)
228231

229232
opt.apply_gradients(grads_and_vars)
230233

tensorflow_addons/optimizers/tests/weight_decay_optimizers_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -431,9 +431,7 @@ def test_optimizer_basic(dtype, optimizer):
431431
"optimizer",
432432
[
433433
weight_decay_optimizers.SGDW,
434-
weight_decay_optimizers.extend_with_decoupled_weight_decay(
435-
tf.keras.optimizers.SGD
436-
),
434+
weight_decay_optimizers.extend_with_decoupled_weight_decay(optimizer_class),
437435
],
438436
)
439437
@pytest.mark.parametrize("dtype", [tf.half, tf.float32, tf.float64])

tensorflow_addons/optimizers/weight_decay_optimizers.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -371,8 +371,16 @@ def __init__(
371371
return OptimizerWithDecoupledWeightDecay
372372

373373

374+
if hasattr(tf.keras.optimizers, "legacy"):
375+
ADAM_CLASS = tf.keras.optimizers.legacy.Adam
376+
SGD_CLASS = tf.keras.optimizers.legacy.SGD
377+
else:
378+
ADAM_CLASS = tf.keras.optimizers.Adam
379+
SGD_CLASS = tf.keras.optimizers.SGD
380+
381+
374382
@tf.keras.utils.register_keras_serializable(package="Addons")
375-
class SGDW(DecoupledWeightDecayExtension, tf.keras.optimizers.SGD):
383+
class SGDW(DecoupledWeightDecayExtension, SGD_CLASS):
376384
"""Optimizer that implements the Momentum algorithm with weight_decay.
377385
378386
This is an implementation of the SGDW optimizer described in "Decoupled
@@ -450,12 +458,6 @@ def __init__(
450458
)
451459

452460

453-
if hasattr(tf.keras.optimizers, "legacy"):
454-
ADAM_CLASS = tf.keras.optimizers.legacy.Adam
455-
else:
456-
ADAM_CLASS = tf.keras.optimizers.Adam
457-
458-
459461
@tf.keras.utils.register_keras_serializable(package="Addons")
460462
class AdamW(DecoupledWeightDecayExtension, ADAM_CLASS):
461463
"""Optimizer that implements the Adam algorithm with weight decay.

0 commit comments

Comments (0)