Skip to content

Jit cache#23

Open
amithrm wants to merge 2 commits intolow_tp_stablefrom
jit_cache
Open

Jit cache#23
amithrm wants to merge 2 commits intolow_tp_stablefrom
jit_cache

Conversation

@amithrm
Copy link
Copy Markdown
Collaborator

@amithrm amithrm commented Nov 26, 2024

No description provided.

from axlearn.experiments.trainer_config_utils import TrainerConfigFn

import jax
import jax_neuronx
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jax_neuronx might not be present when running on platforms other than TRN?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these already imported in neuron_attention? Will cache work given those imports?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is the cache path set?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I use the axlearn scripts to set it via an env variable:
srun /shared/amithrm/axlearn.run -e JAX_COMPILATION_CACHE_DIR="/mnt/jax_cache" -m /shared/amithrm/jax_cache_axlearn:/mnt/jax_cache -e TCMALLOC_TRANSFER_NUM_OBJ=32768 -e TCMALLOC_MAX_TOTAL_THREAD_CACHE_BYTES=4294967296 --rw

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jax_neuronx change should go away once we use the latest jax. This is a temporary change that is monkey patching compiler cache.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@amithrm can you list the PR to jax that enables "jax_neuronx change should go away once we use the latest jax"

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is no PR...we won;t be needing these lines that patch compiler cache if we go to latest jax

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants