Jit cache#23
Conversation
| from axlearn.experiments.trainer_config_utils import TrainerConfigFn | ||
|
|
||
| import jax | ||
| import jax_neuronx |
There was a problem hiding this comment.
jax_neuronx might not be present when running on platforms other than TRN?
There was a problem hiding this comment.
Are these already imported in neuron_attention? Will cache work given those imports?
There was a problem hiding this comment.
Where is the cache path set?
There was a problem hiding this comment.
I use the axlearn scripts to set it via an env variable:
srun /shared/amithrm/axlearn.run -e JAX_COMPILATION_CACHE_DIR="/mnt/jax_cache" -m /shared/amithrm/jax_cache_axlearn:/mnt/jax_cache -e TCMALLOC_TRANSFER_NUM_OBJ=32768 -e TCMALLOC_MAX_TOTAL_THREAD_CACHE_BYTES=4294967296 --rw
There was a problem hiding this comment.
jax_neuronx change should go away once we use the latest jax. This is a temporary change that is monkey patching compiler cache.
There was a problem hiding this comment.
@amithrm can you list the PR to jax that enables "jax_neuronx change should go away once we use the latest jax"
There was a problem hiding this comment.
there is no PR...we won;t be needing these lines that patch compiler cache if we go to latest jax
No description provided.