Minor bug in setting the precision #185

@prockenschaub

Description

Problem

An always-true conditional in `train_common` forces `torch.set_float32_matmul_precision("medium")` regardless of the configured precision: in the check `precision == 16 or "16-mixed"`, the second operand is a non-empty string and therefore always truthy, so the branch is taken for every precision setting.

```python
if precision == 16 or "16-mixed":
    torch.set_float32_matmul_precision("medium")
```

Solution

Compare `precision` against each intended value explicitly (e.g. `precision in (16, "16-mixed")`), and consider passing `"high"` instead of `"medium"`: `"high"` still permits TF32 but keeps fp32 accumulations where possible.
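A minimal sketch of the corrected check, assuming `precision` holds either the int `16` or the string `"16-mixed"` (the hypothetical helper name is illustrative, not from the codebase; the guarded `torch` call is shown in a comment):

```python
def should_lower_matmul_precision(precision):
    """Return True only for half-precision settings (hypothetical helper)."""
    # Buggy original: `precision == 16 or "16-mixed"` — the right-hand
    # operand is a non-empty string, so the expression is always truthy.
    # The fix compares `precision` against each intended value:
    return precision in (16, "16-mixed")

# The buggy expression misfires even for full fp32 training:
buggy = (32 == 16 or "16-mixed")
assert bool(buggy) is True                         # always truthy
assert should_lower_matmul_precision(32) is False  # fixed check is not
assert should_lower_matmul_precision("16-mixed") is True

# In train_common the guard would then read:
# if should_lower_matmul_precision(precision):
#     torch.set_float32_matmul_precision("high")   # "high" keeps fp32 accumulation
```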
