Problem
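The conditional in `train_common` is always true, so it calls `torch.set_float32_matmul_precision("medium")` regardless of the configured precision. The expression `precision == 16 or "16-mixed"` is evaluated as `(precision == 16) or "16-mixed"`, and the non-empty string `"16-mixed"` is always truthy.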
```python
if precision == 16 or "16-mixed":
    torch.set_float32_matmul_precision("medium")
```
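Evaluating the condition for a few precision values shows the problem. The function names below are hypothetical and used only for illustration. `fixed_condition` assumes that 16 and "16-mixed" are the only intended half-precision values.

```python
def buggy_condition(precision):
    # Parsed as (precision == 16) or "16-mixed"; the non-empty string is always truthy
    return bool(precision == 16 or "16-mixed")


def fixed_condition(precision):
    # Compare precision against each intended value explicitly
    return precision in (16, "16-mixed")


for p in (32, "32-true", "bf16-mixed", 16, "16-mixed"):
    print(p, buggy_condition(p), fixed_condition(p))
```

Every row prints `True` for the buggy condition. The fixed condition is `True` only for 16 and "16-mixed".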
Solution
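Compare `precision` against each intended value explicitly, e.g. `precision in (16, "16-mixed")` or `precision == 16 or precision == "16-mixed"`. Also consider "high" instead of "medium". "high" still permits TF32 for float32 matmuls. "medium" additionally allows bfloat16 internal computation, which has fewer mantissa bits, wherever a fast kernel is available.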
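The following is a minimal sketch of the fix. It assumes that `precision` holds the Lightning-style value passed into `train_common`, and that only 16 and "16-mixed" should change matmul precision. `HALF_PRECISION_SETTINGS`, `matmul_precision_for`, and `configure_matmul_precision` are hypothetical names and are not part of YAIB.

```python
# Hypothetical: the precision settings that should relax float32 matmul precision
HALF_PRECISION_SETTINGS = (16, "16-mixed")


def matmul_precision_for(precision):
    """Return the float32 matmul precision to request, or None to keep PyTorch's default."""
    if precision in HALF_PRECISION_SETTINGS:
        return "high"  # still allows TF32, unlike "medium" which also allows bfloat16
    return None


def configure_matmul_precision(precision):
    # torch is imported here so the selection logic above can be used without torch installed
    import torch

    setting = matmul_precision_for(precision)
    if setting is not None:
        torch.set_float32_matmul_precision(setting)
```

Separating the selection logic from the `torch` call lets you test the branch without a GPU. For any other precision, PyTorch's default ("highest") is left unchanged.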
Location: YAIB/icu_benchmarks/models/train.py, lines 157 to 158 at commit ef999f7.