Skip to content

Commit 9a2b2ac

Browse files
Element Pan authored and facebook-github-bot committed
Add TRANSFORMER warmup policy for learning rate scheduling (#3548)
Summary: This diff implements the TRANSFORMER warmup policy from "Attention is All You Need" (Vaswani et al., 2017) for learning rate scheduling in torchrec, and updates a model configuration to use it. ## Implementation Added `WarmupPolicy.TRANSFORMER` to `fbcode/torchrec/optim/warmup.py` which implements the formula: ``` lr = base_lr * min(step^(-0.5), step * warm_steps^(-1.5)) * lr_scale ``` This schedule provides: - **Warmup phase**: LR increases from near-zero to peak at `warm_steps` - **Decay phase**: LR decreases via inverse square root after `warm_steps` The `max_iters` parameter serves as `warm_steps` in the formula. The schedule converges at step = warm_steps where both terms in the min() function become equal. ## Testing Added comprehensive unit tests in `fbcode/torchrec/optim/tests/test_warmup.py`: - Formula correctness at key milestones (step 1, warmup completion, post-warmup) - Monotonic increase during warmup phase - Monotonic decrease during decay phase - Proper application of lr_scale multiplier - Integration tests with WarmupOptimizer - Uses `none_throws()` from `pyre_extensions` for type-safe Optional handling Updated `fbcode/torchrec/optim/tests/BUCK` to include `pyre-extensions` dependency. ## Configuration Update Updated `fbcode/minimal_viable_ai/models/gysj/gysj_esr_roo/conf/model_roo_config.py` to use TRANSFORMER warmup: - Changed both sparse and dense optimizers from LINEAR warmup to TRANSFORMER - Set warm_steps to 80,000 for both optimizers - Updated optimizer hyperparameters (lr=0.001, eps=1e-6, beta values) to work with TRANSFORMER schedule - Sparse optimizer changed from ROWWISE_ADAGRAD to ADAM for consistency with dense optimizer Differential Revision: D87127589
1 parent 0677e23 commit 9a2b2ac

File tree

2 files changed

+316
-0
lines changed

2 files changed

+316
-0
lines changed

torchrec/optim/tests/test_warmup.py

Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,160 @@ def step(self, closure: Any) -> None:
2626
pass # Override NotImplementedError.
2727

2828

29+
class TestGetMultiplier(unittest.TestCase):
    """Tests for the _get_multiplier function with TRANSFORMER policy."""

    @staticmethod
    def _transformer_stage(warm_steps: int, scale: float = 1.0) -> WarmupStage:
        # Build a TRANSFORMER stage; max_iters doubles as warm_steps in the formula.
        return WarmupStage(
            policy=WarmupPolicy.TRANSFORMER,
            max_iters=warm_steps,
            lr_scale=scale,
        )

    def test_transformer_warmup_at_step_one(self) -> None:
        # At internal step 1 the warmup term dominates and the multiplier is tiny.
        from torchrec.optim.warmup import _get_multiplier

        stage = self._transformer_stage(4000)

        # iteration 0 is step 1 internally
        result = _get_multiplier(stage, iter=0)

        # step^(-0.5) = 1^(-0.5) = 1.0
        # step * warm_steps^(-1.5) = 1 * 4000^(-1.5) ≈ 0.0000158
        self.assertAlmostEqual(result, min(1.0, 1 * (4000 ** (-1.5))), places=8)
        self.assertLess(result, 0.00002)

    def test_transformer_warmup_at_warmup_steps(self) -> None:
        # At step == warm_steps the two formula terms coincide.
        from torchrec.optim.warmup import _get_multiplier

        stage = self._transformer_stage(4000)

        # iteration 3999 is step 4000 internally
        result = _get_multiplier(stage, iter=3999)

        # step^(-0.5) = 4000^(-0.5) ≈ 0.0158
        # step * warm_steps^(-1.5) = 4000 * 4000^(-1.5) ≈ 0.0158
        at_boundary = 4000
        self.assertAlmostEqual(
            result,
            min(at_boundary ** (-0.5), at_boundary * (4000 ** (-1.5))),
            places=8,
        )
        self.assertAlmostEqual(result, 0.0158114, places=6)

    def test_transformer_warmup_after_warmup_steps(self) -> None:
        # Past warm_steps the inverse-square-root term is the smaller one.
        from torchrec.optim.warmup import _get_multiplier

        stage = self._transformer_stage(4000)

        # iteration 7999 is step 8000 internally
        result = _get_multiplier(stage, iter=7999)

        # step^(-0.5) = 8000^(-0.5) ≈ 0.0112
        # step * warm_steps^(-1.5) = 8000 * 4000^(-1.5) ≈ 0.0316
        decay_step = 8000
        inv_sqrt = decay_step ** (-0.5)
        warmup_term = decay_step * (4000 ** (-1.5))
        self.assertAlmostEqual(result, inv_sqrt, places=8)
        self.assertLess(inv_sqrt, warmup_term)
        self.assertAlmostEqual(result, 0.0111803, places=6)

    def test_transformer_warmup_with_lr_scale(self) -> None:
        # lr_scale is applied multiplicatively on top of the schedule.
        from torchrec.optim.warmup import _get_multiplier

        stage = self._transformer_stage(4000, scale=2.0)

        # iteration 3999 is step 4000 internally
        result = _get_multiplier(stage, iter=3999)

        at_boundary = 4000
        unscaled = min(at_boundary ** (-0.5), at_boundary * (4000 ** (-1.5)))
        self.assertAlmostEqual(result, unscaled * 2.0, places=8)

    def test_transformer_warmup_formula_correctness(self) -> None:
        # Spot-check the formula across warmup, the boundary, and the decay phase.
        from torchrec.optim.warmup import _get_multiplier

        stage = self._transformer_stage(1000)

        # steps 1, 100, 500, 1000, 2000
        for iter_val in (0, 99, 499, 999, 1999):
            step = iter_val + 1
            self.assertAlmostEqual(
                _get_multiplier(stage, iter=iter_val),
                min(step ** (-0.5), step * (1000 ** (-1.5))),
                places=8,
                msg=f"Failed at iteration {iter_val} (step {step})",
            )

    def test_transformer_warmup_monotonic_increase_during_warmup(self) -> None:
        # During warmup the multiplier must be strictly increasing.
        from torchrec.optim.warmup import _get_multiplier

        stage = self._transformer_stage(1000)
        values = [_get_multiplier(stage, iter=i) for i in range(0, 1000)]

        for idx, (earlier, later) in enumerate(zip(values, values[1:])):
            self.assertLess(
                earlier,
                later,
                msg=f"Multiplier should increase at iteration {idx}",
            )

    def test_transformer_warmup_monotonic_decrease_after_warmup(self) -> None:
        # After warmup the multiplier must be strictly decreasing.
        from torchrec.optim.warmup import _get_multiplier

        stage = self._transformer_stage(1000)
        values = [_get_multiplier(stage, iter=i) for i in range(1000, 2000)]

        for i, (earlier, later) in enumerate(zip(values, values[1:])):
            self.assertGreater(
                earlier,
                later,
                msg=f"Multiplier should decrease at iteration {i + 1000}",
            )
182+
29183
class TestWarmupOptimizer(unittest.TestCase):
30184
def test_load_state_dict(self) -> None:
31185
def get_optimizer() -> WarmupOptimizer:
@@ -72,3 +226,155 @@ def get_optimizer() -> WarmupOptimizer:
72226
warmup_optimizer_1.state_dict()["state"]["__warmup"],
73227
warmup_optimizer_2.state_dict()["state"]["__warmup"],
74228
)
229+
230+
def test_transformer_warmup_integration(self) -> None:
    """WarmupOptimizer's lr should follow the Transformer schedule during warmup."""
    weight = Variable(torch.tensor([1.0, 2.0]))
    inner = DummyKeyedOptimizer(
        {"param": weight}, defaultdict(dict), [{"params": [weight]}]
    )

    base_lr = 0.001
    warm_steps = 100

    optimizer = WarmupOptimizer(
        inner,
        stages=[
            WarmupStage(
                policy=WarmupPolicy.TRANSFORMER,
                max_iters=100,  # Stage ends at iteration 100
                lr_scale=1.0,
            ),
        ],
        lr=base_lr,
    )

    def expected_lr(step: int) -> float:
        # lr = base_lr * min(step^(-0.5), step * warm_steps^(-1.5))
        return base_lr * min(step ** (-0.5), step * (warm_steps ** (-1.5)))

    # Record the lr exposed by the optimizer before each step call.
    observed = []
    for _ in range(100):  # Only iterate through the TRANSFORMER stage
        lr = 0.0
        for group in optimizer.param_groups:
            lr = group["lr"]
        observed.append(lr)
        optimizer.step()

    # Spot-check step 1, mid-warmup step 50, and warmup completion at step 100.
    self.assertAlmostEqual(observed[0], expected_lr(1), places=10)
    self.assertAlmostEqual(observed[49], expected_lr(50), places=10)
    self.assertAlmostEqual(observed[99], expected_lr(100), places=10)

    # The lr must rise monotonically and match the formula at every warmup step.
    for idx in range(warm_steps - 1):
        self.assertLess(
            observed[idx],
            observed[idx + 1],
            msg=f"LR should increase during warmup at step {idx + 1}",
        )
        step = idx + 1
        self.assertAlmostEqual(
            observed[idx],
            expected_lr(step),
            places=10,
            msg=f"LR mismatch at step {step}",
        )
299+
300+
def test_transformer_warmup_with_extended_stage(self) -> None:
    """A long TRANSFORMER stage uses max_iters as warm_steps in the formula."""
    weight = Variable(torch.tensor([1.0, 2.0]))
    inner = DummyKeyedOptimizer(
        {"param": weight}, defaultdict(dict), [{"params": [weight]}]
    )

    base_lr = 0.001
    # In the TRANSFORMER policy, max_iters acts as warm_steps in the formula
    max_iters = 8000  # Stage runs for 8000 iterations

    optimizer = WarmupOptimizer(
        inner,
        stages=[
            WarmupStage(
                policy=WarmupPolicy.TRANSFORMER,
                max_iters=max_iters,  # Stage runs for 8000 iterations
                lr_scale=1.0,
            ),
        ],
        lr=base_lr,
    )

    # Record the lr exposed by the optimizer before each step call.
    observed = []
    for _ in range(max_iters):
        lr = 0.0
        for group in optimizer.param_groups:
            lr = group["lr"]
        observed.append(lr)
        optimizer.step()

    # Step 1 obeys min(step^(-0.5), step * max_iters^(-1.5)).
    self.assertAlmostEqual(
        observed[0],
        base_lr * min(1 ** (-0.5), 1 * (max_iters ** (-1.5))),
        places=10,
        msg=f"LR at step 1 should match formula with warm_steps={max_iters}",
    )

    # Mid-warmup step 4000 with warm_steps=8000.
    self.assertAlmostEqual(
        observed[3999],
        base_lr * min(4000 ** (-0.5), 4000 * (max_iters ** (-1.5))),
        places=10,
        msg=f"LR at step 4000 should match formula with warm_steps={max_iters}",
    )

    # At step max_iters (8000) the two formula terms coincide.
    inv_sqrt = max_iters ** (-0.5)
    warmup_term = max_iters * (max_iters ** (-1.5))
    self.assertAlmostEqual(
        inv_sqrt,
        warmup_term,
        places=10,
        msg=f"At step={max_iters}, both formula terms should be equal",
    )

    self.assertAlmostEqual(
        observed[max_iters - 1],
        base_lr * min(inv_sqrt, warmup_term),
        places=10,
        msg=f"LR at step {max_iters} should match formula",
    )

    # The lr increases monotonically all the way up to max_iters.
    for idx in range(max_iters - 1):
        self.assertLess(
            observed[idx],
            observed[idx + 1],
            msg=f"LR should increase at step {idx + 1} (before max_iters={max_iters})",
        )

torchrec/optim/warmup.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ class WarmupPolicy(Enum):
2828
STEP = "step"
2929
INVSQRT = "inv_sqrt" # inverse square root
3030
COSINE_ANNEALING_WARM_RESTARTS = "cosine_annealing_warm_restarts"
31+
TRANSFORMER = (
32+
"transformer" # Transformer warmup: min(step^(-0.5), step * warm_steps^(-1.5))
33+
)
3134

3235

3336
@dataclass
@@ -86,6 +89,13 @@ def _get_multiplier(stage: WarmupStage, iter: int) -> float:
8689
t_cur = iter % t_0
8790
cos_iter = 0.5 * (1 + math.cos(math.pi * t_cur / t_0))
8891
multiplier = eta_min + (1.0 - eta_min) * cos_iter
92+
elif stage.policy == WarmupPolicy.TRANSFORMER:
93+
# Transformer warmup from "Attention is All You Need" (Vaswani et al., 2017)
94+
# Formula: lr_scale = min(step^(-0.5), step * warm_steps^(-1.5))
95+
# where warm_steps = max_iters
96+
# Add 1 to iter to make it 1-indexed and avoid division by zero
97+
step = iter + 1
98+
multiplier = min(step ** (-0.5), step * stage.max_iters ** (-1.5))
8999
return multiplier * stage.lr_scale
90100

91101

0 commit comments

Comments
 (0)