Skip to content

Commit 00087f7

Browse files
Jammy2211Jammy2211
authored andcommitted
enviroment variable JAX fast time check
1 parent 416bf1a commit 00087f7

File tree

1 file changed

+28
-3
lines changed

1 file changed

+28
-3
lines changed

autofit/jax_wrapper.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,42 @@
1+
import logging
2+
3+
logger = logging.getLogger(__name__)
4+
15
"""
26
Allows the user to switch between using NumPy and JAX for linear algebra operations.
37
48
If USE_JAX=true in general.yaml then JAX's NumPy is used, otherwise vanilla NumPy is used.
59
"""
6-
import jax
7-
import os
8-
910
from autoconf import conf
1011

1112
use_jax = conf.instance["general"]["jax"]["use_jax"]
1213

1314
if use_jax:
1415

16+
import os
17+
18+
xla_env = os.environ.get("XLA_FLAGS")
19+
20+
if "--xla_disable_hlo_passes=constant_folding" not in xla_env:
21+
logger.info(
22+
"""
23+
For fast JAX compile times, the envirment variable XLA_FLAGS must be set to "--xla_disable_hlo_passes=constant_folding",
24+
which is currently not.
25+
26+
In Python, to do this manually, use the code:
27+
28+
import os
29+
os.environ["XLA_FLAGS"] = "--xla_disable_hlo_passes=constant_folding"
30+
31+
The environment variable has been set automatically for you now, however if JAX has already been imported,
32+
this change will not take effect and JAX function compiling times may be slow.
33+
34+
Therefore, it is recommended to set this environment variable before running your script, e.g. in your terminal.
35+
""")
36+
37+
os.environ['XLA_FLAGS'] = "--xla_disable_hlo_passes=constant_folding"
38+
39+
import jax
1540
from jax import numpy
1641

1742
print(

0 commit comments

Comments
 (0)