@@ -26,6 +26,160 @@ def step(self, closure: Any) -> None:
         pass  # Override NotImplementedError.


+class TestGetMultiplier(unittest.TestCase):
+    """Tests for the _get_multiplier function with TRANSFORMER policy."""
+
+    def test_transformer_warmup_at_step_one(self) -> None:
+        # Setup: Create TRANSFORMER warmup stage with warm_steps=4000
+        stage = WarmupStage(
+            policy=WarmupPolicy.TRANSFORMER,
+            max_iters=4000,
+            lr_scale=1.0,
+        )
+
+        # Execute: Get multiplier at iteration 0 (step 1 internally)
+        from torchrec.optim.warmup import _get_multiplier
+
+        multiplier = _get_multiplier(stage, iter=0)
+
+        # Assert: At step 1, multiplier should be min(1, 1/4000^1.5) ≈ 0.00000395
+        # step^(-0.5) = 1^(-0.5) = 1.0
+        # step * warm_steps^(-1.5) = 1 * 4000^(-1.5) ≈ 0.00000395
+        expected = min(1.0, 1 * (4000 ** (-1.5)))
+        self.assertAlmostEqual(multiplier, expected, places=8)
+        self.assertLess(multiplier, 0.00002)
+
+    def test_transformer_warmup_at_warmup_steps(self) -> None:
+        # Setup: Create TRANSFORMER warmup stage with warm_steps=4000
+        stage = WarmupStage(
+            policy=WarmupPolicy.TRANSFORMER,
+            max_iters=4000,
+            lr_scale=1.0,
+        )
+
+        # Execute: Get multiplier at iteration 3999 (step 4000 internally)
+        from torchrec.optim.warmup import _get_multiplier
+
+        multiplier = _get_multiplier(stage, iter=3999)
+
+        # Assert: At step=warm_steps, both terms are equal
+        # step^(-0.5) = 4000^(-0.5) ≈ 0.0158
+        # step * warm_steps^(-1.5) = 4000 * 4000^(-1.5) ≈ 0.0158
+        step = 4000
+        expected = min(step ** (-0.5), step * (4000 ** (-1.5)))
+        self.assertAlmostEqual(multiplier, expected, places=8)
+        self.assertAlmostEqual(multiplier, 0.0158114, places=6)
+
+    def test_transformer_warmup_after_warmup_steps(self) -> None:
+        # Setup: Create TRANSFORMER warmup stage with warm_steps=4000
+        stage = WarmupStage(
+            policy=WarmupPolicy.TRANSFORMER,
+            max_iters=4000,
+            lr_scale=1.0,
+        )
+
+        # Execute: Get multiplier at iteration 7999 (step 8000 internally)
+        from torchrec.optim.warmup import _get_multiplier
+
+        multiplier = _get_multiplier(stage, iter=7999)
+
+        # Assert: After warmup, step^(-0.5) dominates (is smaller)
+        # step^(-0.5) = 8000^(-0.5) ≈ 0.0112
+        # step * warm_steps^(-1.5) = 8000 * 4000^(-1.5) ≈ 0.0316
+        step = 8000
+        inv_sqrt = step ** (-0.5)
+        warmup_term = step * (4000 ** (-1.5))
+        self.assertAlmostEqual(multiplier, inv_sqrt, places=8)
+        self.assertLess(inv_sqrt, warmup_term)
+        self.assertAlmostEqual(multiplier, 0.0111803, places=6)
+
+    def test_transformer_warmup_with_lr_scale(self) -> None:
+        # Setup: Create TRANSFORMER warmup stage with lr_scale=2.0
+        stage = WarmupStage(
+            policy=WarmupPolicy.TRANSFORMER,
+            max_iters=4000,
+            lr_scale=2.0,
+        )
+
+        # Execute: Get multiplier at iteration 3999 (step 4000 internally)
+        from torchrec.optim.warmup import _get_multiplier
+
+        multiplier = _get_multiplier(stage, iter=3999)
+
+        # Assert: lr_scale is applied as a multiplier
+        step = 4000
+        base_multiplier = min(step ** (-0.5), step * (4000 ** (-1.5)))
+        expected = base_multiplier * 2.0
+        self.assertAlmostEqual(multiplier, expected, places=8)
+
+    def test_transformer_warmup_formula_correctness(self) -> None:
+        # Setup: Create TRANSFORMER warmup stage with warm_steps=1000
+        stage = WarmupStage(
+            policy=WarmupPolicy.TRANSFORMER,
+            max_iters=1000,
+            lr_scale=1.0,
+        )
+
+        # Execute: Test multiple iterations to verify the formula
+        from torchrec.optim.warmup import _get_multiplier
+
+        test_iters = [0, 99, 499, 999, 1999]  # steps 1, 100, 500, 1000, 2000
+        for iter_val in test_iters:
+            multiplier = _get_multiplier(stage, iter=iter_val)
+            step = iter_val + 1
+
+            # Assert: Multiplier matches the Transformer formula
+            expected = min(step ** (-0.5), step * (1000 ** (-1.5)))
+            self.assertAlmostEqual(
+                multiplier,
+                expected,
+                places=8,
+                msg=f"Failed at iteration {iter_val} (step {step})",
+            )
+
+    def test_transformer_warmup_monotonic_increase_during_warmup(self) -> None:
+        # Setup: Create TRANSFORMER warmup stage with warm_steps=1000
+        stage = WarmupStage(
+            policy=WarmupPolicy.TRANSFORMER,
+            max_iters=1000,
+            lr_scale=1.0,
+        )
+
+        # Execute: Get multipliers during the warmup phase
+        from torchrec.optim.warmup import _get_multiplier
+
+        multipliers = [_get_multiplier(stage, iter=i) for i in range(0, 1000)]
+
+        # Assert: Multipliers should increase monotonically during warmup
+        for idx in range(len(multipliers) - 1):
+            self.assertLess(
+                multipliers[idx],
+                multipliers[idx + 1],
+                msg=f"Multiplier should increase at iteration {idx}",
+            )
+
+    def test_transformer_warmup_monotonic_decrease_after_warmup(self) -> None:
+        # Setup: Create TRANSFORMER warmup stage with warm_steps=1000
+        stage = WarmupStage(
+            policy=WarmupPolicy.TRANSFORMER,
+            max_iters=1000,
+            lr_scale=1.0,
+        )
+
+        # Execute: Get multipliers after the warmup phase
+        from torchrec.optim.warmup import _get_multiplier
+
+        multipliers = [_get_multiplier(stage, iter=i) for i in range(1000, 2000)]
+
+        # Assert: Multipliers should decrease monotonically after warmup
+        for i in range(len(multipliers) - 1):
+            self.assertGreater(
+                multipliers[i],
+                multipliers[i + 1],
+                msg=f"Multiplier should decrease at iteration {i + 1000}",
+            )
+
+
 class TestWarmupOptimizer(unittest.TestCase):
     def test_load_state_dict(self) -> None:
         def get_optimizer() -> WarmupOptimizer:
@@ -72,3 +226,155 @@ def get_optimizer() -> WarmupOptimizer:
                 warmup_optimizer_1.state_dict()["state"]["__warmup"],
                 warmup_optimizer_2.state_dict()["state"]["__warmup"],
             )
+
+    def test_transformer_warmup_integration(self) -> None:
+        # Setup: Create optimizer with TRANSFORMER warmup policy
+        param = Variable(torch.tensor([1.0, 2.0]))
+        keyed_optimizer = DummyKeyedOptimizer(
+            {"param": param}, defaultdict(dict), [{"params": [param]}]
+        )
+
+        base_lr = 0.001
+        warm_steps = 100
+
+        warmup_optimizer = WarmupOptimizer(
+            keyed_optimizer,
+            stages=[
+                WarmupStage(
+                    policy=WarmupPolicy.TRANSFORMER,
+                    max_iters=100,  # Stage ends at iteration 100
+                    lr_scale=1.0,
+                ),
+            ],
+            lr=base_lr,
+        )
+
+        # Execute: Run optimizer through warmup steps
+        learning_rates = []
+        current_lr = 0.0
+        for _ in range(100):  # Only iterate through the TRANSFORMER stage
+            for param_group in warmup_optimizer.param_groups:
+                current_lr = param_group["lr"]
+            learning_rates.append(current_lr)
+            warmup_optimizer.step()
+
+        # Assert: Verify learning rate follows Transformer schedule during warmup
+        # At step 1 (iteration 0)
+        step_1 = 1
+        expected_lr_1 = base_lr * min(step_1 ** (-0.5), step_1 * (warm_steps ** (-1.5)))
+        self.assertAlmostEqual(learning_rates[0], expected_lr_1, places=10)
+
+        # At step 50 (iteration 49) - mid-warmup
+        step_50 = 50
+        expected_lr_50 = base_lr * min(
+            step_50 ** (-0.5), step_50 * (warm_steps ** (-1.5))
+        )
+        self.assertAlmostEqual(learning_rates[49], expected_lr_50, places=10)
+
+        # At step 100 (iteration 99) - warmup completion
+        step_100 = 100
+        expected_lr_100 = base_lr * min(
+            step_100 ** (-0.5), step_100 * (warm_steps ** (-1.5))
+        )
+        self.assertAlmostEqual(learning_rates[99], expected_lr_100, places=10)
+
+        # Verify learning rate increases monotonically during warmup
+        for idx in range(warm_steps - 1):
+            self.assertLess(
+                learning_rates[idx],
+                learning_rates[idx + 1],
+                msg=f"LR should increase during warmup at step {idx + 1}",
+            )
+            # Verify formula correctness at this step
+            step = idx + 1
+            expected_lr_at_idx = base_lr * min(
+                step ** (-0.5), step * (warm_steps ** (-1.5))
+            )
+            self.assertAlmostEqual(
+                learning_rates[idx],
+                expected_lr_at_idx,
+                places=10,
+                msg=f"LR mismatch at step {step}",
+            )
+
+    def test_transformer_warmup_with_extended_stage(self) -> None:
+        # Setup: Create optimizer with an extended TRANSFORMER stage
+        param = Variable(torch.tensor([1.0, 2.0]))
+        keyed_optimizer = DummyKeyedOptimizer(
+            {"param": param}, defaultdict(dict), [{"params": [param]}]
+        )
+
+        base_lr = 0.001
+        # In the TRANSFORMER policy, max_iters acts as warm_steps in the formula
+        max_iters = 8000  # Stage runs for 8000 iterations
+
+        warmup_optimizer = WarmupOptimizer(
+            keyed_optimizer,
+            stages=[
+                WarmupStage(
+                    policy=WarmupPolicy.TRANSFORMER,
+                    max_iters=max_iters,  # Stage runs for 8000 iterations
+                    lr_scale=1.0,
+                ),
+            ],
+            lr=base_lr,
+        )
+
+        # Execute: Run optimizer through the full warmup phase
+        current_lr = 0.0
+        learning_rates = []
+        for _ in range(max_iters):
+            for param_group in warmup_optimizer.param_groups:
+                current_lr = param_group["lr"]
+            learning_rates.append(current_lr)
+            warmup_optimizer.step()
+
+        # Assert: Verify the formula uses max_iters as warm_steps
+        # At step 1, verify the formula: min(step^(-0.5), step * max_iters^(-1.5))
+        step_1 = 1
+        expected_lr_1 = base_lr * min(step_1 ** (-0.5), step_1 * (max_iters ** (-1.5)))
+        self.assertAlmostEqual(
+            learning_rates[0],
+            expected_lr_1,
+            places=10,
+            msg=f"LR at step 1 should match formula with warm_steps={max_iters}",
+        )
+
+        # At step 4000, verify with max_iters=8000
+        step_4000 = 4000
+        expected_lr_4000 = base_lr * min(
+            step_4000 ** (-0.5), step_4000 * (max_iters ** (-1.5))
+        )
+        self.assertAlmostEqual(
+            learning_rates[3999],
+            expected_lr_4000,
+            places=10,
+            msg=f"LR at step 4000 should match formula with warm_steps={max_iters}",
+        )
+
+        # At step max_iters (8000), both terms should be equal
+        step_max = max_iters
+        inv_sqrt = step_max ** (-0.5)
+        warmup_term = step_max * (max_iters ** (-1.5))
+        self.assertAlmostEqual(
+            inv_sqrt,
+            warmup_term,
+            places=10,
+            msg=f"At step={max_iters}, both formula terms should be equal",
+        )
+
+        expected_lr_max = base_lr * min(inv_sqrt, warmup_term)
+        self.assertAlmostEqual(
+            learning_rates[max_iters - 1],
+            expected_lr_max,
+            places=10,
+            msg=f"LR at step {max_iters} should match formula",
+        )
+
+        # Verify learning rate increases before max_iters
+        for idx in range(max_iters - 1):
+            self.assertLess(
+                learning_rates[idx],
+                learning_rates[idx + 1],
+                msg=f"LR should increase at step {idx + 1} (before max_iters={max_iters})",
+            )