[TRAINIUM] improve support #421
Conversation
Code Review
This pull request introduces comprehensive support for AWS Trainium accelerators, including specialized data type handling (int32), dynamic selection of compilation backends, and integration with the Neuron profiler. It also implements efficient model initialization and refactors positional embedding logic within the dense model mixins. Feedback focuses on correcting the usage of nullcontext, addressing potential RuntimeError and TypeError exceptions in the autocast and profiling logic due to platform-specific arguments, and improving the portability of utility scripts by removing hardcoded file paths.
        device_type=Accelerator.get_device_type(), dtype=string_to_torch_dtype(args.mixed_precision_args.dtype)
    )
    if args.distributed_args.fsdp_algorithm is None
    else nullcontext
nullcontext is a class and must be instantiated (i.e., nullcontext()) to be used as a context manager. Using the class itself in a with statement raises a TypeError because __enter__/__exit__ are instance methods: the with statement looks up the context manager protocol on the object's type, and the nullcontext class object does not satisfy it.
Suggested change:
-    else nullcontext
+    else nullcontext()
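The difference is easy to demonstrate in isolation; the snippet below (stdlib only, independent of this PR's code) shows the class failing and the instance working:

```python
from contextlib import nullcontext

# Passing the class itself fails: `with` looks up __enter__/__exit__ on the
# object's *type*, and the nullcontext class object does not satisfy the
# protocol. Python 3.11+ raises TypeError; older versions raise AttributeError.
try:
    with nullcontext:
        pass
    failed = False
except (TypeError, AttributeError):
    failed = True

# An instance works as a no-op context manager.
with nullcontext():
    result = "entered"

print(failed, result)
```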
    enable_kernels(args.kernel_args.kernels),
    (
        torch.autocast(
            device_type=Accelerator.get_device_type(), dtype=string_to_torch_dtype(args.mixed_precision_args.dtype)
torch.autocast does not natively support device_type='xla' or 'neuron' in standard PyTorch. This will cause a RuntimeError when running on TPU or Trainium if the environment does not have a specifically patched version of PyTorch. For these accelerators, it is generally recommended to use device_type='cpu' (which is how Neuron AMP is typically triggered) or the accelerator-specific autocast context (e.g., torch_xla.amp.autocast).
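One way to address this is to select the autocast context per backend. A minimal sketch; the `get_autocast_context` helper, its `device_type` strings, and the no-op fallback policy are illustrative assumptions, not this PR's actual API:

```python
from contextlib import nullcontext

import torch


def get_autocast_context(device_type: str, dtype: torch.dtype):
    # Standard PyTorch autocast natively supports "cuda" and "cpu";
    # Neuron AMP is typically driven through the CPU autocast path.
    if device_type in ("cuda", "cpu"):
        return torch.autocast(device_type=device_type, dtype=dtype)
    if device_type == "xla":
        try:
            # torch_xla ships its own autocast context for TPU/Trainium.
            import torch_xla.core.xla_model as xm
            from torch_xla.amp import autocast as xla_autocast

            return xla_autocast(xm.xla_device(), dtype=dtype)
        except ImportError:
            pass
    # No applicable autocast backend: fall back to a no-op context.
    return nullcontext()
```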
        repeat=1,
    ),
    on_trace_ready=torch.profiler.tensorboard_trace_handler(path),
    experimental_config=experimental_config,
The experimental_config parameter is not part of the standard torch.profiler.profile signature in PyTorch. While it is supported by torch-neuronx, passing it (even as None) will cause a TypeError on standard PyTorch installations (e.g., when running on CUDA or CPU). To maintain cross-platform compatibility, consider using a conditional approach or dictionary unpacking to call profile without this argument on non-Trainium devices.
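The conditional dictionary-unpacking pattern suggested here can be sketched as follows; the `build_profiler_kwargs` helper and its `is_trainium` flag are illustrative assumptions, not this PR's actual code:

```python
import torch


def build_profiler_kwargs(path: str, is_trainium: bool, experimental_config=None) -> dict:
    # Assemble kwargs for torch.profiler.profile, adding the Trainium-specific
    # experimental_config entry only when running on a Neuron device.
    kwargs = {
        "schedule": torch.profiler.schedule(wait=1, warmup=1, active=1, repeat=1),
        "on_trace_ready": torch.profiler.tensorboard_trace_handler(path),
    }
    if is_trainium:
        kwargs["experimental_config"] = experimental_config
    return kwargs


# Usage: profile(**kwargs) never sees the extra argument on non-Trainium devices.
# profiler = torch.profiler.profile(**build_profiler_kwargs("logs", is_trainium=False))
```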
@@ -0,0 +1 @@
+ssh -i ~/Desktop/mayank-melbourne.pem -L 8001:localhost:3001 -L 8002:localhost:3002 trainium-melbourne -fN
The script contains a hardcoded absolute path to a personal PEM file (~/Desktop/mayank-melbourne.pem). This makes the script non-portable and potentially exposes details about your local file system. It is recommended to use an environment variable or a generic placeholder.
Suggested change:
-ssh -i ~/Desktop/mayank-melbourne.pem -L 8001:localhost:3001 -L 8002:localhost:3002 trainium-melbourne -fN
+ssh -i ${PEM_FILE:-/path/to/your/key.pem} -L 8001:localhost:3001 -L 8002:localhost:3002 ${REMOTE_HOST:-trainium-melbourne} -fN