imported os and added ckpt scripts #41

Open
dgourab-aws wants to merge 4 commits into accuracy_workstream_trn from ckpt_restore_dataloader_fix

Conversation

@dgourab-aws
Collaborator

No description provided.

@dgourab-aws
Collaborator Author

Tested it locally using: pytest seed_test.py

Collaborator

@HahTK HahTK left a comment


Most of the stuff we need is missing.

  1. Most features are missing (setting the seed, parameterizing fuji.py, all the features from GPU).
  2. We only have GPU training scripts in the TRN repo? Where is the TRN script?
  3. Likely completely untested. Did we run a TRN job?

Comment thread run_trainer.sh

# export JAX_PLATFORMS=cpu

#Perf Tuning Guideline here : https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/docs/PGLE.md
Collaborator


why are GPU flags in TRN runs?

Collaborator Author


This is just the checkpointing script, I did not do any cleanup here.

Comment thread run_trainer.sh
###export NCCL_DEBUG_SUBSYS=COLL

#HAH quick fix
export XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=${HLO_DUMP_PATH} --xla_dump_hlo_pass_re='.*' --xla_dump_hlo_as_proto --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_multi_streamed_windowed_einsum=true --xla_gpu_enable_custom_fusions=true" # --xla_gpu_enable_address_computation_fusion=true"
Collaborator


why are GPU flags in TRN runs?

Collaborator Author


This is just the checkpointing script.
@HahTK I need a walkthrough of this script to actually clean it up. I haven't used it to launch TRN jobs; I was using Apoorv's launch script.

Comment thread run_trainer.sh
echo "ERROR : ${TEST_SETUP} for ${N_EXPECTED_NODES} was launched with ${num_nodes}"
exit 1
fi
MESH_SELECTOR="gpu-${num_nodes}node-baseline"
Collaborator


Again, everything here is GPU.

Comment thread run_trainer.sh
@@ -0,0 +1,150 @@
#!/usr/bin/env bash
Collaborator


This needs to be fixed. This is a GPU script being used to run TRN. We need to use the TRN script and just add checkpoint resume to it.

import importlib


class SeedTest(test_utils.TestCase):
Collaborator


where do we actually set the seed?

Collaborator Author


The seed has to be set as an environment variable from any launch script, for example:
export DATA_SEED=42
The launch script has not been added to this PR.
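For context, a minimal sketch of how training code could pick that variable up. The helper name `get_data_seed` and the fallback default are illustrative assumptions, not code from this PR:

```python
import os


def get_data_seed(default: int = 0) -> int:
    """Return the data seed exported by the launch script (export DATA_SEED=42).

    Falls back to `default` when DATA_SEED is unset; the fallback is an
    assumption for illustration, not behavior defined in this PR.
    """
    return int(os.environ.get("DATA_SEED", default))


if __name__ == "__main__":
    os.environ["DATA_SEED"] = "42"
    print(get_data_seed())  # 42
```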

@HahTK
Collaborator

HahTK commented Dec 27, 2024

Also the branch was created from the wrong commit id. It should have been
33ec152

but it seems to be branched from this instead
c20387c

@dgourab-aws
Collaborator Author

Also the branch was created from the wrong commit id. It should have been 33ec152

but it seems to be branched from this instead c20387c

The GPU branch was created from 33ec152, the TRN branch was to be created from the AXLearn upstream branch.
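As an aside, a branch's fork point can be checked mechanically with `git merge-base --is-ancestor`. The sketch below builds a scratch repo purely so it is self-contained; in the real repo you would run only the final command with the actual commit id (e.g. 33ec152) and branch name:

```shell
set -e
# Scratch repo for illustration only; substitute the real repo, branch, and
# commit id (e.g. 33ec152) in practice.
tmp=$(mktemp -d) && cd "$tmp" && git init -q
git -c user.name=demo -c user.email=demo@example.com commit -q --allow-empty -m "base"
base=$(git rev-parse --short HEAD)
git checkout -q -b ckpt_restore_dataloader_fix
git -c user.name=demo -c user.email=demo@example.com commit -q --allow-empty -m "ckpt work"
# Exits 0 only if $base is an ancestor of the branch head, i.e. the branch
# was cut from (or after) that commit.
git merge-base --is-ancestor "$base" ckpt_restore_dataloader_fix && echo "branch contains $base"
```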
