Commit fef84a0
committed
Optimize Orbax checkpoint for JAX backend with compatibility check
- Preserve JAX arrays during saving when jax.monitoring.record_scalar is available
- Fall back to numpy conversion for older JAX versions that don't have record_scalar
- Maintain cross-backend compatibility while avoiding unnecessary conversions
- Update async waiting to use CheckpointManager.wait_until_finished()
- Implement AlwaysSavePolicy for reliable save decisions
- Add expected failures for sklearn tests due to neural network non-determinism1 parent b7a0dff commit fef84a0
1 file changed
+10
-3
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
16 | 16 | | |
17 | 17 | | |
18 | 18 | | |
19 | | - | |
20 | | - | |
| 19 | + | |
| 20 | + | |
21 | 21 | | |
22 | | - | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
23 | 30 | | |
24 | 31 | | |
25 | 32 | | |
| |||
0 commit comments