File tree Expand file tree Collapse file tree 1 file changed +28
-3
lines changed
Expand file tree Collapse file tree 1 file changed +28
-3
lines changed Original file line number Diff line number Diff line change 1+ import logging
2+
3+ logger = logging .getLogger (__name__ )
4+
15"""
26Allows the user to switch between using NumPy and JAX for linear algebra operations.
37
48If USE_JAX=true in general.yaml then JAX's NumPy is used, otherwise vanilla NumPy is used.
59"""
6- import jax
7- import os
8-
910from autoconf import conf
1011
1112use_jax = conf .instance ["general" ]["jax" ]["use_jax" ]
1213
1314if use_jax :
1415
16+ import os
17+
18+ xla_env = os .environ .get ("XLA_FLAGS" )
19+
20+ if "--xla_disable_hlo_passes=constant_folding" not in xla_env :
21+ logger .info (
22+ """
23+ For fast JAX compile times, the envirment variable XLA_FLAGS must be set to "--xla_disable_hlo_passes=constant_folding",
24+ which is currently not.
25+
26+ In Python, to do this manually, use the code:
27+
28+ import os
29+ os.environ["XLA_FLAGS"] = "--xla_disable_hlo_passes=constant_folding"
30+
31+ The environment variable has been set automatically for you now, however if JAX has already been imported,
32+ this change will not take effect and JAX function compiling times may be slow.
33+
34+ Therefore, it is recommended to set this environment variable before running your script, e.g. in your terminal.
35+ """ )
36+
37+ os .environ ['XLA_FLAGS' ] = "--xla_disable_hlo_passes=constant_folding"
38+
39+ import jax
1540 from jax import numpy
1641
1742 print (
You can’t perform that action at this time.
0 commit comments