[TRAINIUM] improve support #421
base: main
Changes from all commits
```diff
@@ -40,6 +40,7 @@
     log_rank_0,
     set_seed,
     setup_tf32,
+    string_to_torch_dtype,
 )
```

```diff
@@ -692,7 +693,17 @@ def main(args_class: type[DistillationArgs | TrainingArgs] = TrainingArgs) -> None:
     experiments_tracker.log_args(args, **model_container[0].calculate_num_parameters(return_dict=True))

     # main training loop
-    with disable_generation_cache(), enable_kernels(args.kernel_args.kernels):
+    with (
+        disable_generation_cache(),
+        enable_kernels(args.kernel_args.kernels),
+        (
+            torch.autocast(
+                device_type=Accelerator.get_device_type(), dtype=string_to_torch_dtype(args.mixed_precision_args.dtype)
+            )
+            if args.distributed_args.fsdp_algorithm is None
+            else nullcontext()
+        ),
+    ):
         train(
             args,
             model_container=model_container,
```
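A side note on the pattern above: every item in a `with` statement must be a context-manager *instance*, so the fallback branch needs `nullcontext()` rather than the bare `nullcontext` class. Here is a minimal sketch of the same conditional-context pattern; the `autocast_or_noop` helper and its arguments are illustrative, not part of the PR:

```python
from contextlib import nullcontext

import torch


def autocast_or_noop(enabled: bool, device_type: str, dtype: torch.dtype):
    # Return an autocast context when mixed precision is managed here,
    # otherwise a no-op context. Note nullcontext() is instantiated:
    # passing the bare class to `with` raises a TypeError/AttributeError.
    return torch.autocast(device_type=device_type, dtype=dtype) if enabled else nullcontext()


with autocast_or_noop(enabled=True, device_type="cpu", dtype=torch.bfloat16):
    x = torch.randn(4, 4) @ torch.randn(4, 4)
    print(x.dtype)  # torch.bfloat16 under CPU autocast
```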
```diff
@@ -7,7 +7,7 @@
 import torch

 from .accelerator import Accelerator
-from .packages import is_torch_xla_available
+from .packages import is_torch_neuronx_available, is_torch_xla_available
 from .parallel import ProcessGroupManager
```
```diff
@@ -16,6 +16,10 @@
     from torch_xla.debug.profiler import stop_trace as xla_stop_trace


+if is_torch_neuronx_available():
+    from torch_neuronx.profiling import NeuronConfig, NeuronProfiler, ProfileMode
+
+
 class TorchProfiler:
     def __init__(self, path: str | None, wait: int = 5, active: int = 1, warmup: int = 5) -> TorchProfiler:
         self.path = path
```
```diff
@@ -30,17 +34,33 @@ def __init__(self, path: str | None, wait: int = 5, active: int = 1, warmup: int
         self.accelerator = Accelerator.get_accelerator()
         self._step = 0

+        experimental_config = None
+        if self.accelerator == Accelerator.trainium:
+            experimental_config = NeuronConfig(
+                modes=[ProfileMode.DEVICE, ProfileMode.RUNTIME],
+                max_events_per_nc=100000,
+                profile_output_dir=path,
+                capture_enabled_for_nc="0",
+            )
+
+            exporter = NeuronProfiler(experimental_config)
+
         self._profiler = None
         if self.accelerator != Accelerator.tpu:
             self._profiler = torch.profiler.profile(
-                activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
+                activities=[torch.profiler.ProfilerActivity.CPU, Accelerator.get_profiler_activity()],
                 schedule=torch.profiler.schedule(
                     wait=wait if ProcessGroupManager.get_global_rank() == 0 else 150000,
                     warmup=warmup,
                     active=active,
                     repeat=1,
                 ),
-                on_trace_ready=torch.profiler.tensorboard_trace_handler(path),
+                experimental_config=experimental_config,
+                on_trace_ready=(
+                    exporter.export_trace
+                    if self.accelerator == Accelerator.trainium
+                    else torch.profiler.tensorboard_trace_handler(path)
+                ),
                 record_shapes=True,
                 profile_memory=True,
             )
```
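For context on the schedule above: giving non-zero ranks an enormous `wait` keeps their profilers permanently idle, so only global rank 0 ever records a trace. A minimal sketch of that rank-gating trick outside the repo's wrapper follows; the rank value and trace path are illustrative:

```python
import torch

rank = 0  # illustrative; in the repo this comes from ProcessGroupManager.get_global_rank()

profiler = torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU],
    schedule=torch.profiler.schedule(
        wait=5 if rank == 0 else 150000,  # non-zero ranks never leave the wait phase
        warmup=5,
        active=1,
        repeat=1,
    ),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./traces"),
    record_shapes=True,
    profile_memory=True,
)

profiler.start()
for _ in range(12):  # covers wait (5) + warmup (5) + active (1), plus one extra step
    torch.randn(64, 64) @ torch.randn(64, 64)  # stand-in for a training step
    profiler.step()
profiler.stop()
```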
```diff
@@ -0,0 +1 @@
+ssh -i ~/Desktop/mayank-melbourne.pem -L 8001:localhost:3001 -L 8002:localhost:3002 trainium-melbourne -fN
```

(`-L` forwards a local port to a port on the remote host; `-f` backgrounds ssh and `-N` runs no remote command, so this just holds the tunnels open.)

Contributor: The script contains a hardcoded absolute path to a personal PEM file (`~/Desktop/mayank-melbourne.pem`), so it will not work for other users as written.
```diff
@@ -0,0 +1 @@
+neuron-explorer view -v 2 --data-path ./parquet_files
```
```diff
@@ -1,7 +1,7 @@
 TOKENIZERS_PARALLELISM=false \
 torchrun --nnodes=1 \
 --node_rank=0 \
---nproc_per_node=2 \
+--nproc_per_node=4 \
 --rdzv_id=101 \
 -m lm_engine.pretrain \
 --config ${1}
```
Contributor: `torch.autocast` does not natively support `device_type='xla'` or `'neuron'` in standard PyTorch. This will cause a `RuntimeError` when running on TPU or Trainium unless the environment has a specifically patched version of PyTorch. For these accelerators, it is generally recommended to use `device_type='cpu'` (which is how Neuron AMP is typically triggered) or the accelerator-specific autocast context (e.g., `torch_xla.amp.autocast`).
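A sketch of what the reviewer suggests, with a hypothetical `pick_autocast` helper; the device-type strings follow standard PyTorch/XLA conventions, and `torch_xla.amp.autocast` is torch_xla's own wrapper:

```python
from contextlib import nullcontext

import torch


def pick_autocast(device_type: str, dtype: torch.dtype):
    # Hypothetical helper: choose an autocast context the backend actually supports.
    if device_type in ("cuda", "cpu"):
        return torch.autocast(device_type=device_type, dtype=dtype)
    if device_type == "xla":
        # TPU/Trainium via torch_xla: use the XLA-specific autocast wrapper.
        import torch_xla.amp
        import torch_xla.core.xla_model as xm

        return torch_xla.amp.autocast(xm.xla_device(), dtype=dtype)
    # Neuron AMP is commonly driven through CPU autocast as a fallback.
    return torch.autocast(device_type="cpu", dtype=dtype)


# Usage: wraps the forward/backward pass, mirroring the training loop in this PR.
with pick_autocast("cpu", torch.bfloat16):
    _ = torch.randn(4, 4) @ torch.randn(4, 4)
```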