fix(reproducibility): add opt-in strict determinism across trainers#61
fix(reproducibility): add opt-in strict determinism across trainers#61
Conversation
g-braeunlich
left a comment
There was a problem hiding this comment.
Just 2 small suggestions
| if torch.cuda.is_available(): | ||
| torch.backends.cuda.matmul.allow_tf32 = False |
There was a problem hiding this comment.
Not exactly the same logic, but maybe safer?
| if torch.cuda.is_available(): | |
| torch.backends.cuda.matmul.allow_tf32 = False | |
| torch.backends.cuda.matmul.allow_tf32 = not torch.cuda.is_available() |
There was a problem hiding this comment.
I kept the current logic intentionally for strict mode semantics. In strict mode we should avoid enabling TF32; not torch.cuda.is_available() would set this flag to True on CPU/MPS runs. Current behavior only disables CUDA matmul TF32 when CUDA exists, while CPU/MPS remain unaffected. If we simplify, I’d still keep the assigned value False and only guard backend availability.
There was a problem hiding this comment.
Follow-up: I agree the cleaner framing is backend-capability based. For strict mode, we should keep TF32 disabled (False) whenever the CUDA matmul backend is present, rather than deriving the value from torch.cuda.is_available() (which can be False on CPU/MPS runs and would imply True). Happy to switch to that form for clarity if you prefer.
Description
Adds an opt-in strict reproducibility path for all training entrypoints while preserving current default behavior.
engiopt/reproducibility.pywith shared helpers:seed_training(seed)enable_strict_determinism(warn_only=True)make_dataloader_generator(seed)strict_determinism: bool = Falseto all targeted trainingArgsdataclasses.seed_training(args.seed)in each training script.--strict-determinismis passed.shuffle=Trueloaders.--strict-determinism).Fixes SOH-14 (Linear)
Type of change
Screenshots
N/A
Checklist:
Code Quality
pre-commitchecks withpre-commit run --all-filesruff check .andruff formatmypy .CleanRL Philosophy (for new/modified algorithms)
tyro--trackflag support--save-modelflag)Algorithm Completeness (for new algorithms)
algorithm.py) and evaluation script (evaluate_algorithm.py) are providedProbleminterfaceDocumentation
Validation
cgan_cnn_2d) run twice on same machine with strict mode and fixed seed; resulting checkpoint tensor hashes matched for both generator and discriminator.warn_only=Truefor nondeterministic ops to warn and continue.