From 4f90813e285f49ed23bae96fddd6c0c4cda4a482 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 1 Feb 2026 01:01:08 -0800 Subject: [PATCH 01/28] init Signed-off-by: Mayank Mishra --- accelerated-model-architectures | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accelerated-model-architectures b/accelerated-model-architectures index 4d5cd9654..6432eda29 160000 --- a/accelerated-model-architectures +++ b/accelerated-model-architectures @@ -1 +1 @@ -Subproject commit 4d5cd96546cf62efa3b68d7cf44ad640b13e7ad1 +Subproject commit 6432eda2934e68183f8d7965b04e151e72314fdd From 0fb830ef2b0612949f2a06cea1c3ef545ff7d345 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 15 Apr 2026 18:31:36 -0700 Subject: [PATCH 02/28] drop muon Signed-off-by: Mayank Mishra --- accelerated-model-architectures | 2 +- lm_engine/hf_models/mixins/dense/base.py | 13 ++++++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/accelerated-model-architectures b/accelerated-model-architectures index 6432eda29..e99632015 160000 --- a/accelerated-model-architectures +++ b/accelerated-model-architectures @@ -1 +1 @@ -Subproject commit 6432eda2934e68183f8d7965b04e151e72314fdd +Subproject commit e996320150a371c2afcb4d1ad405ce0b0eb811fa diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index a16a4673d..06be904b7 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -20,6 +20,11 @@ class PreTrainedModelMixin(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + config_class = None layer_class = Block base_model_prefix = "transformer" @@ -186,7 +191,13 @@ def _get_position_ids( if past_length > 0: position_ids = position_ids[:, past_length:key_length:] else: - position_ids = torch.arange(past_length, key_length, dtype=torch.long, device=device) + position_ids = torch.arange( + past_length, + key_length, + dtype=torch.int32 if Accelerator.get_accelerator() == Accelerator.trainium else torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, query_length) return position_ids From 96a4b6881868d810dd3a1767d35d6aebc1bb3591 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 15 Apr 2026 20:58:30 -0700 Subject: [PATCH 03/28] drop muon Signed-off-by: Mayank Mishra --- scripts/aws-trainium/explorer-remote.sh | 1 + 1 file changed, 1 insertion(+) create mode 100644 scripts/aws-trainium/explorer-remote.sh diff --git a/scripts/aws-trainium/explorer-remote.sh b/scripts/aws-trainium/explorer-remote.sh new file mode 100644 index 000000000..3ce4c37c1 --- /dev/null +++ b/scripts/aws-trainium/explorer-remote.sh @@ -0,0 +1 @@ +neuron-explorer view -v 2 --data-path ./parquet_files From badd3bca9fefc69c1f5ff1d80a2d22c694825e40 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 15 Apr 2026 23:30:05 -0700 Subject: [PATCH 04/28] drop muon Signed-off-by: Mayank Mishra --- scripts/aws-trainium/explorer-local.sh | 1 + 1 file changed, 1 insertion(+) create mode 100644 scripts/aws-trainium/explorer-local.sh diff --git a/scripts/aws-trainium/explorer-local.sh b/scripts/aws-trainium/explorer-local.sh new file mode 100644 index 000000000..2d72c0239 --- /dev/null +++ b/scripts/aws-trainium/explorer-local.sh @@ -0,0 +1 @@ +ssh -i ~/Desktop/mayank-melbourne.pem -L 8001:localhost:3001 -L 8002:localhost:3002 ubuntu@ec2-16-50-57-175.ap-southeast-4.compute.amazonaws.com -fN From 15d10c56c21192accdaa226d892cbe4725eb2e5b Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 15 Apr 2026 23:31:04 -0700 Subject: [PATCH 05/28] drop muon Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index 6cc2d280a..6ca436d19 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -21,11 +21,6 @@ class PreTrainedModelMixin(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - config_class = None layer_class = Block base_model_prefix = "transformer" @@ -256,13 +251,7 @@ def _get_position_ids( if past_length > 0: position_ids = position_ids[:, past_length:key_length:] else: - position_ids = torch.arange( - past_length, - key_length, - dtype=torch.int32 if Accelerator.get_accelerator() == Accelerator.trainium else torch.long, - device=device, - ) - + position_ids = torch.arange(past_length, key_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0).view(-1, query_length) return position_ids From b0d6d489d99ad9574c0ff4da216fb499b07480b7 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sat, 25 Apr 2026 23:26:02 -0700 Subject: [PATCH 06/28] init Signed-off-by: Mayank Mishra --- lm_engine/utils/accelerator.py | 10 ++++++++++ lm_engine/utils/profiler.py | 16 ++++++++++++++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/lm_engine/utils/accelerator.py b/lm_engine/utils/accelerator.py index f0030ae1e..c6967a54b 100644 --- a/lm_engine/utils/accelerator.py +++ b/lm_engine/utils/accelerator.py @@ -9,6 +9,7 @@ from typing import Any import torch +from torch.profiler import ProfilerActivity from .packages import is_torch_neuronx_available, is_torch_xla_available @@ -127,3 +128,12 @@ def set_rng_state(state: Any) -> Any: raise ValueError(f"unexpected device ({accelerator})") return state + + @staticmethod + def get_profiler_activity(self) -> ProfilerActivity: + accelerator = Accelerator.get_accelerator() + + if accelerator == Accelerator.trainium: + return ProfilerActivity.PrivateUse1 + + return ProfilerActivity.CUDA diff --git a/lm_engine/utils/profiler.py b/lm_engine/utils/profiler.py index daf20200c..73cf3cfc7 100644 --- a/lm_engine/utils/profiler.py +++ b/lm_engine/utils/profiler.py @@ -7,7 +7,7 @@ import torch from .accelerator import Accelerator -from .packages import is_torch_xla_available +from .packages import is_torch_neuronx_available, is_torch_xla_available from .parallel import ProcessGroupManager @@ -16,6 +16,10 @@ from torch_xla.debug.profiler import stop_trace as xla_stop_trace +if is_torch_neuronx_available(): + from torch_neuronx.profiling import NeuronConfig, ProfileMode + + class TorchProfiler: def __init__(self, path: str | None, wait: int = 5, active: int = 1, warmup: int = 5) -> TorchProfiler: self.path = path @@ -30,16 +34,24 @@ def __init__(self, path: str | None, wait: int = 5, active: int = 1, warmup: int self.accelerator = Accelerator.get_accelerator() self._step = 0 + experimental_config = None + if self.accelerator == Accelerator.trainium: + experimental_config = NeuronConfig( + modes=[ProfileMode.DEVICE, ProfileMode.RUNTIME], + profile_output_dir=path, + ) + self._profiler = None if self.accelerator != Accelerator.tpu: self._profiler = torch.profiler.profile( - activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], + activities=[torch.profiler.ProfilerActivity.CPU, Accelerator.get_profiler_activity()], schedule=torch.profiler.schedule( wait=wait if ProcessGroupManager.get_global_rank() == 0 else 150000, warmup=warmup, active=active, repeat=1, ), + experimental_config=experimental_config, on_trace_ready=torch.profiler.tensorboard_trace_handler(path), record_shapes=True, profile_memory=True, From 2fe68ea532670b347f97ab119e0c9170fb5b9988 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sat, 25 Apr 2026 23:27:13 -0700 Subject: [PATCH 07/28] init Signed-off-by: Mayank Mishra --- lm_engine/utils/accelerator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/utils/accelerator.py b/lm_engine/utils/accelerator.py index c6967a54b..a3f5487cb 100644 --- a/lm_engine/utils/accelerator.py +++ b/lm_engine/utils/accelerator.py @@ -130,7 +130,7 @@ def set_rng_state(state: Any) -> Any: return state @staticmethod - def get_profiler_activity(self) -> ProfilerActivity: + def get_profiler_activity() -> ProfilerActivity: accelerator = Accelerator.get_accelerator() if accelerator == Accelerator.trainium: From 0afe9ab4f2182ec8334b2e84e670f062e618cbf1 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sat, 25 Apr 2026 23:48:01 -0700 Subject: [PATCH 08/28] init Signed-off-by: Mayank Mishra --- .../modeling_utils/activations/glu.py | 5 +-- lm_engine/hf_models/modeling_utils/chunk.py | 39 ------------------- .../sequence_mixer_blocks/attention.py | 7 +--- 3 files changed, 3 insertions(+), 48 deletions(-) delete mode 100644 lm_engine/hf_models/modeling_utils/chunk.py diff --git a/lm_engine/hf_models/modeling_utils/activations/glu.py b/lm_engine/hf_models/modeling_utils/activations/glu.py index 95c05f59c..cc0105a96 100644 --- a/lm_engine/hf_models/modeling_utils/activations/glu.py +++ b/lm_engine/hf_models/modeling_utils/activations/glu.py @@ -10,7 +10,6 @@ from ....kernels import Kernel, is_kernel_allowed, wait_for_ACT from ....utils import Accelerator, is_xma_available -from ..chunk import contiguous_chunk from .base import get_base_activation @@ -53,9 +52,7 @@ def forward(self, x: torch.Tensor, is_interleaved: bool) -> torch.Tensor: u = x[..., 1::2] g = x[..., ::2] else: - u, g = (contiguous_chunk if Accelerator.get_accelerator() == Accelerator.trainium else torch.chunk)( - x, 2, dim=-1 - ) + u, g = torch.chunk(x, 2, dim=-1) x = u * self.base_activation(g) diff --git a/lm_engine/hf_models/modeling_utils/chunk.py b/lm_engine/hf_models/modeling_utils/chunk.py deleted file mode 100644 index 42e2641ac..000000000 --- a/lm_engine/hf_models/modeling_utils/chunk.py +++ /dev/null @@ -1,39 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -import torch - - -class _ContiguousChunk(torch.autograd.Function): - @staticmethod - def forward(ctx, x: torch.Tensor, chunks: int, dim: int) -> tuple[torch.Tensor, ...]: - ctx.dim = dim - x = x.chunk(chunks, dim=dim) - return tuple(i.contiguous() for i in x) - - @staticmethod - def backward(ctx, *dy: tuple[torch.Tensor]) -> tuple[torch.Tensor, None, None]: - dy = tuple(i.contiguous() for i in dy) - return torch.cat(dy, dim=ctx.dim), None, None - - -def contiguous_chunk(x: torch.Tensor, chunks: int, dim: int = 0) -> tuple[torch.Tensor]: - return _ContiguousChunk.apply(x, chunks, dim) - - -class _ContiguousSplit(torch.autograd.Function): - @staticmethod - def forward(ctx, x: torch.Tensor, split_size: int | tuple[int, ...], dim: int) -> tuple[torch.Tensor, ...]: - ctx.dim = dim - x = x.split(split_size, dim=dim) - return tuple(i.contiguous() for i in x) - - @staticmethod - def backward(ctx, *dy: tuple[torch.Tensor]) -> tuple[torch.Tensor, None, None]: - dy = tuple(i.contiguous() for i in dy) - return torch.cat(dy, dim=ctx.dim), None, None - - -def contiguous_split(x: torch.Tensor, split_size: tuple[int, ...], dim: int = 0) -> tuple[torch.Tensor]: - return _ContiguousSplit.apply(x, split_size, dim) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index 80020f0af..a88d0fbd4 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -16,7 +16,6 @@ from ...config.sequence_mixer import ATTENTION_MULTIPLIER_INVERSE_METHOD, ATTENTION_MULTIPLIER_INVERSE_SQRT_METHOD from ...parameter import mark_parameter_as_mup_learning_rate from ..activations import sigmoid -from ..chunk import contiguous_split from ..dropout import Dropout from ..dtensor_module import DTensorModule from ..init_utils import _get_std_for_linear @@ -233,7 +232,7 @@ def forward( x = x.view(*input_shape) if self.attention_gate: - q, k, v, g = (contiguous_split if Accelerator.get_accelerator() == Accelerator.trainium else torch.split)( + q, k, v, g = torch.split( x, (self.num_groups * self.head_dim, self.head_dim, self.head_dim, self.num_groups * self.head_dim), dim=-1, @@ -241,9 +240,7 @@ def forward( g = g.reshape(*output_shape) else: - q, k, v = (contiguous_split if Accelerator.get_accelerator() == Accelerator.trainium else torch.split)( - x, (self.num_groups * self.head_dim, self.head_dim, self.head_dim), dim=-1 - ) + q, k, v = torch.split(x, (self.num_groups * self.head_dim, self.head_dim, self.head_dim), dim=-1) q = q.reshape(*output_shape) From 0b7b3df7a9c6e071d69619933ee5303a67b0a260 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sat, 25 Apr 2026 23:48:51 -0700 Subject: [PATCH 09/28] init Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/activations/glu.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/activations/glu.py b/lm_engine/hf_models/modeling_utils/activations/glu.py index cc0105a96..f34cba9a4 100644 --- a/lm_engine/hf_models/modeling_utils/activations/glu.py +++ b/lm_engine/hf_models/modeling_utils/activations/glu.py @@ -6,10 +6,9 @@ import torch import torch.nn as nn -import torch.nn.functional as F from ....kernels import Kernel, is_kernel_allowed, wait_for_ACT -from ....utils import Accelerator, is_xma_available +from ....utils import is_xma_available from .base import get_base_activation From 79dc7fdcf2683a615c9836eb1df465d32e2efc93 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 26 Apr 2026 01:08:36 -0700 Subject: [PATCH 10/28] init Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index 6ca436d19..2c926a481 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -206,7 +206,12 @@ def forward( ) query_length = key_length - past_length - position_ids = torch.arange(past_length, key_length, dtype=torch.long, device=hidden_states.device) + position_ids = torch.arange( + past_length, + key_length, + dtype=torch.int32 if Accelerator.get_accelerator() == Accelerator.trainium else torch.long, + device=hidden_states.device, + ) position_ids = position_ids.unsqueeze(0).view(-1, query_length) rope_cos_sin = self._get_rope_cos_sin(key_length, position_ids, dtype=hidden_states.dtype) From 2e7cb5682abfd3a1ee506343c0047edcdc22b9e3 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 26 Apr 2026 01:17:14 -0700 Subject: [PATCH 11/28] init Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index 2c926a481..46067ec47 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -206,12 +206,9 @@ def forward( ) query_length = key_length - past_length - position_ids = torch.arange( - past_length, - key_length, - dtype=torch.int32 if Accelerator.get_accelerator() == Accelerator.trainium else torch.long, - device=hidden_states.device, - ) + dtype = torch.int32 if Accelerator.get_accelerator() == Accelerator.trainium else torch.long + + position_ids = torch.arange(past_length, key_length, dtype=dtype, device=hidden_states.device) position_ids = position_ids.unsqueeze(0).view(-1, query_length) rope_cos_sin = self._get_rope_cos_sin(key_length, position_ids, dtype=hidden_states.dtype) From 3c6999d2b52a848182e02de980ad30f334c4fdde Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 26 Apr 2026 01:20:27 -0700 Subject: [PATCH 12/28] init Signed-off-by: Mayank Mishra --- lm_engine/finetune.py | 7 ++++++- lm_engine/hf_models/mixins/dense/base.py | 25 +++++++++++++++++++----- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/lm_engine/finetune.py b/lm_engine/finetune.py index a266fd66b..9a9827526 100644 --- a/lm_engine/finetune.py +++ b/lm_engine/finetune.py @@ -157,7 +157,12 @@ def evaluate( else: num_steps = 0 - num_steps = torch.tensor(num_steps, device=Accelerator.get_current_device(), dtype=torch.long) + num_steps = torch.tensor( + num_steps, + device=Accelerator.get_current_device(), + dtype=torch.int32 if Accelerator.get_accelerator() == Accelerator.trainium else torch.long, + ) + torch.distributed.all_reduce(num_steps, group=ProcessGroupManager.get_tensor_parallel_group()) num_steps = num_steps.item() else: diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index 46067ec47..915c04cd9 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -206,11 +206,14 @@ def forward( ) query_length = key_length - past_length - dtype = torch.int32 if Accelerator.get_accelerator() == Accelerator.trainium else torch.long + position_ids = torch.arange( + past_length, + key_length, + dtype=torch.int32 if Accelerator.get_accelerator() == Accelerator.trainium else torch.long, + device=hidden_states.device, + ) - position_ids = torch.arange(past_length, key_length, dtype=dtype, device=hidden_states.device) position_ids = position_ids.unsqueeze(0).view(-1, query_length) - rope_cos_sin = self._get_rope_cos_sin(key_length, position_ids, dtype=hidden_states.dtype) if is_generation_cache_enabled() and use_cache and cache_params is None: @@ -248,12 +251,24 @@ def _get_position_ids( ) -> torch.Tensor: if attention_mask is not None and len(attention_mask.shape) == 2: # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids = ( + attention_mask.to( + torch.int32 if Accelerator.get_accelerator() == Accelerator.trainium else torch.int64 + ).cumsum(-1) + - 1 + ) + position_ids.masked_fill_(attention_mask == 0, 0) if past_length > 0: position_ids = position_ids[:, past_length:key_length:] else: - position_ids = torch.arange(past_length, key_length, dtype=torch.long, device=device) + position_ids = torch.arange( + past_length, + key_length, + dtype=torch.int32 if Accelerator.get_accelerator() == Accelerator.trainium else torch.int64, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, query_length) return position_ids From a6897f2bec0365a356798165fabca10c0ec44523 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 26 Apr 2026 01:21:00 -0700 Subject: [PATCH 13/28] init Signed-off-by: Mayank Mishra --- scripts/aws-trainium/pretrain.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/aws-trainium/pretrain.sh b/scripts/aws-trainium/pretrain.sh index 6a90f4259..0f81fad6c 100644 --- a/scripts/aws-trainium/pretrain.sh +++ b/scripts/aws-trainium/pretrain.sh @@ -1,7 +1,7 @@ TOKENIZERS_PARALLELISM=false \ torchrun --nnodes=1 \ --node_rank=0 \ - --nproc_per_node=2 \ + --nproc_per_node=4 \ --rdzv_id=101 \ -m lm_engine.pretrain \ --config ${1} From 70cb9f50bcb1eeb2aa91e3cc23c3cc7f6cb853ed Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 26 Apr 2026 01:23:21 -0700 Subject: [PATCH 14/28] init Signed-off-by: Mayank Mishra --- lm_engine/distributed.py | 2 +- lm_engine/utils/accelerator.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/lm_engine/distributed.py b/lm_engine/distributed.py index d753840ff..15821bdfd 100644 --- a/lm_engine/distributed.py +++ b/lm_engine/distributed.py @@ -372,7 +372,7 @@ def _sharding_function(parameter: nn.Parameter) -> Shard: if torch_compile: for i, model in enumerate(model_container): - model_container[i] = torch.compile(model) + model_container[i] = torch.compile(model, backend=Accelerator.get_torch_compile_backend()) set_parameter_marker_maps( model_container, diff --git a/lm_engine/utils/accelerator.py b/lm_engine/utils/accelerator.py index a3f5487cb..c7830954a 100644 --- a/lm_engine/utils/accelerator.py +++ b/lm_engine/utils/accelerator.py @@ -137,3 +137,12 @@ def get_profiler_activity() -> ProfilerActivity: return ProfilerActivity.PrivateUse1 return ProfilerActivity.CUDA + + @staticmethod + def get_torch_compile_backend() -> str: + accelerator = Accelerator.get_accelerator() + + if accelerator == Accelerator.trainium: + return "neuron" + + return "inductor" From 13b11cb8259a84ae8cbe3fdfee97d0783ae05d4c Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 26 Apr 2026 01:23:57 -0700 Subject: [PATCH 15/28] init Signed-off-by: Mayank Mishra --- lm_engine/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/distributed.py b/lm_engine/distributed.py index 15821bdfd..78a32709a 100644 --- a/lm_engine/distributed.py +++ b/lm_engine/distributed.py @@ -157,7 +157,7 @@ def wrap_model_container_for_distributed_training( for i, model in enumerate(model_container): model = model.to(Accelerator.get_current_device()) if torch_compile: - model = torch.compile(model) + model = torch.compile(model, backend=Accelerator.get_torch_compile_backend()) model_container[i] = model From 141cc16d82387d114fa3cbdf1ecdb57a98717f9a Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 26 Apr 2026 01:44:16 -0700 Subject: [PATCH 16/28] init Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 53 +++++++++++++----------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index 915c04cd9..17606face 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -158,6 +158,9 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: ) self.position_embedding_type = config.position_embedding_type + self.use_rope = self.position_embedding_type == "rope" + self.use_learned_absolute = self.position_embedding_type == "learned_absolute" + self._setup_positional_encoding() def forward( @@ -214,7 +217,10 @@ def forward( ) position_ids = position_ids.unsqueeze(0).view(-1, query_length) - rope_cos_sin = self._get_rope_cos_sin(key_length, position_ids, dtype=hidden_states.dtype) + + rope_cos_sin = ( + self._get_rope_cos_sin(key_length, position_ids, dtype=hidden_states.dtype) if self.use_rope else None + ) if is_generation_cache_enabled() and use_cache and cache_params is None: cache_params = GenerationCache() @@ -276,11 +282,10 @@ def _get_position_ids( def _get_rope_cos_sin( self, key_length: int, position_ids: torch.Tensor, dtype: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: - if self.position_embedding_type == "rope": - cos, sin = self.rope(key_length, dtype=dtype) - cos = cos[position_ids] - sin = sin[position_ids] - return cos, sin + cos, sin = self.rope(key_length, dtype=dtype) + cos = cos[position_ids] + sin = sin[position_ids] + return cos, sin def _prepare_causal_attention_mask( self, @@ -323,19 +328,6 @@ def _prepare_causal_attention_mask( return causal_mask - def _get_initial_hidden_state(self, input_ids: torch.Tensor, position_ids: torch.Tensor | None) -> torch.Tensor: - hidden_state = self.wte(input_ids) - - if self.position_embedding_type == "learned_absolute": - hidden_state = hidden_state + self.wpe(position_ids) - - hidden_state = self.embedding_dropout(hidden_state) - - if self.m_emb is not None: - hidden_state = hidden_state * self.m_emb - - return hidden_state - def _prepare_a_bunch_of_stuff( self, input_ids: torch.Tensor | None = None, @@ -375,14 +367,25 @@ def _prepare_a_bunch_of_stuff( query_length = input_shape[-1] key_length = past_length + query_length - if position_ids is None: - position_ids = self._get_position_ids( - attention_mask, past_length, query_length, key_length, input_ids.device - ) + hidden_states = self.wte(input_ids) + + if self.use_rope or self.use_learned_absolute: + if position_ids is None: + position_ids = self._get_position_ids( + attention_mask, past_length, query_length, key_length, input_ids.device + ) - hidden_states = self._get_initial_hidden_state(input_ids, position_ids) + if self.use_learned_absolute: + hidden_states = hidden_states + self.wpe(position_ids) - rope_cos_sin = self._get_rope_cos_sin(key_length, position_ids, dtype=hidden_states.dtype) + hidden_states = self.embedding_dropout(hidden_states) + + if self.m_emb is not None: + hidden_states = hidden_states * self.m_emb + + rope_cos_sin = ( + self._get_rope_cos_sin(key_length, position_ids, dtype=hidden_states.dtype) if self.use_rope else None + ) attention_mask = self._get_maybe_causal_mask( attention_mask, batch_size, query_length, key_length, hidden_states.dtype, input_ids.device From 8d92cc9eb61e67b349474dd7e4e847e9b07e6c75 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 26 Apr 2026 02:00:59 -0700 Subject: [PATCH 17/28] Revert "init" This reverts commit 0b7b3df7a9c6e071d69619933ee5303a67b0a260. --- lm_engine/hf_models/modeling_utils/activations/glu.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lm_engine/hf_models/modeling_utils/activations/glu.py b/lm_engine/hf_models/modeling_utils/activations/glu.py index f34cba9a4..cc0105a96 100644 --- a/lm_engine/hf_models/modeling_utils/activations/glu.py +++ b/lm_engine/hf_models/modeling_utils/activations/glu.py @@ -6,9 +6,10 @@ import torch import torch.nn as nn +import torch.nn.functional as F from ....kernels import Kernel, is_kernel_allowed, wait_for_ACT -from ....utils import is_xma_available +from ....utils import Accelerator, is_xma_available from .base import get_base_activation From b76b5f1d6d7799fa62c1fa25651fa211d211ac24 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 26 Apr 2026 02:01:04 -0700 Subject: [PATCH 18/28] Revert "init" This reverts commit 0afe9ab4f2182ec8334b2e84e670f062e618cbf1. --- .../modeling_utils/activations/glu.py | 5 ++- lm_engine/hf_models/modeling_utils/chunk.py | 39 +++++++++++++++++++ .../sequence_mixer_blocks/attention.py | 7 +++- 3 files changed, 48 insertions(+), 3 deletions(-) create mode 100644 lm_engine/hf_models/modeling_utils/chunk.py diff --git a/lm_engine/hf_models/modeling_utils/activations/glu.py b/lm_engine/hf_models/modeling_utils/activations/glu.py index cc0105a96..95c05f59c 100644 --- a/lm_engine/hf_models/modeling_utils/activations/glu.py +++ b/lm_engine/hf_models/modeling_utils/activations/glu.py @@ -10,6 +10,7 @@ from ....kernels import Kernel, is_kernel_allowed, wait_for_ACT from ....utils import Accelerator, is_xma_available +from ..chunk import contiguous_chunk from .base import get_base_activation @@ -52,7 +53,9 @@ def forward(self, x: torch.Tensor, is_interleaved: bool) -> torch.Tensor: u = x[..., 1::2] g = x[..., ::2] else: - u, g = torch.chunk(x, 2, dim=-1) + u, g = (contiguous_chunk if Accelerator.get_accelerator() == Accelerator.trainium else torch.chunk)( + x, 2, dim=-1 + ) x = u * self.base_activation(g) diff --git a/lm_engine/hf_models/modeling_utils/chunk.py b/lm_engine/hf_models/modeling_utils/chunk.py new file mode 100644 index 000000000..42e2641ac --- /dev/null +++ b/lm_engine/hf_models/modeling_utils/chunk.py @@ -0,0 +1,39 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +import torch + + +class _ContiguousChunk(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, chunks: int, dim: int) -> tuple[torch.Tensor, ...]: + ctx.dim = dim + x = x.chunk(chunks, dim=dim) + return tuple(i.contiguous() for i in x) + + @staticmethod + def backward(ctx, *dy: tuple[torch.Tensor]) -> tuple[torch.Tensor, None, None]: + dy = tuple(i.contiguous() for i in dy) + return torch.cat(dy, dim=ctx.dim), None, None + + +def contiguous_chunk(x: torch.Tensor, chunks: int, dim: int = 0) -> tuple[torch.Tensor]: + return _ContiguousChunk.apply(x, chunks, dim) + + +class _ContiguousSplit(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, split_size: int | tuple[int, ...], dim: int) -> tuple[torch.Tensor, ...]: + ctx.dim = dim + x = x.split(split_size, dim=dim) + return tuple(i.contiguous() for i in x) + + @staticmethod + def backward(ctx, *dy: tuple[torch.Tensor]) -> tuple[torch.Tensor, None, None]: + dy = tuple(i.contiguous() for i in dy) + return torch.cat(dy, dim=ctx.dim), None, None + + +def contiguous_split(x: torch.Tensor, split_size: tuple[int, ...], dim: int = 0) -> tuple[torch.Tensor]: + return _ContiguousSplit.apply(x, split_size, dim) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index a88d0fbd4..80020f0af 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -16,6 +16,7 @@ from ...config.sequence_mixer import ATTENTION_MULTIPLIER_INVERSE_METHOD, ATTENTION_MULTIPLIER_INVERSE_SQRT_METHOD from ...parameter import mark_parameter_as_mup_learning_rate from ..activations import sigmoid +from ..chunk import contiguous_split from ..dropout import Dropout from ..dtensor_module import DTensorModule from ..init_utils import _get_std_for_linear @@ -232,7 +233,7 @@ def forward( x = x.view(*input_shape) if self.attention_gate: - q, k, v, g = torch.split( + q, k, v, g = (contiguous_split if Accelerator.get_accelerator() == Accelerator.trainium else torch.split)( x, (self.num_groups * self.head_dim, self.head_dim, self.head_dim, self.num_groups * self.head_dim), dim=-1, @@ -240,7 +241,9 @@ def forward( g = g.reshape(*output_shape) else: - q, k, v = torch.split(x, (self.num_groups * self.head_dim, self.head_dim, self.head_dim), dim=-1) + q, k, v = (contiguous_split if Accelerator.get_accelerator() == Accelerator.trainium else torch.split)( + x, (self.num_groups * self.head_dim, self.head_dim, self.head_dim), dim=-1 + ) q = q.reshape(*output_shape) From 3a83e9704a5a7012a4e433db91588d02a247599a Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 26 Apr 2026 02:20:38 -0700 Subject: [PATCH 19/28] init Signed-off-by: Mayank Mishra --- lm_engine/data/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lm_engine/data/utils.py b/lm_engine/data/utils.py index 8ba4f4925..1fa1b0421 100644 --- a/lm_engine/data/utils.py +++ b/lm_engine/data/utils.py @@ -141,4 +141,7 @@ def get_next_batch(x: Iterable | None) -> dict: if x is None: return None + if Accelerator.get_accelerator() == Accelerator.trainium: + x = {k: v.to(torch.int32) if v.dtype in [torch.int32, torch.int64] else v for k, v in x.items()} + return next(x) From 35a73fa1d18e2bbb2fee0e7c38e2ae6771254b69 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 26 Apr 2026 02:21:57 -0700 Subject: [PATCH 20/28] init Signed-off-by: Mayank Mishra --- lm_engine/data/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lm_engine/data/utils.py b/lm_engine/data/utils.py index 1fa1b0421..b2ecce8c0 100644 --- a/lm_engine/data/utils.py +++ b/lm_engine/data/utils.py @@ -141,7 +141,9 @@ def get_next_batch(x: Iterable | None) -> dict: if x is None: return None + batch = next(x) + if Accelerator.get_accelerator() == Accelerator.trainium: - x = {k: v.to(torch.int32) if v.dtype in [torch.int32, torch.int64] else v for k, v in x.items()} + batch = {k: v.to(torch.int32) if v.dtype in [torch.int32, torch.int64] else v for k, v in batch.items()} - return next(x) + return batch From ed1511934b1be88234873762d7325f524b63b1fd Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 26 Apr 2026 12:01:00 -0700 Subject: [PATCH 21/28] init Signed-off-by: Mayank Mishra --- scripts/aws-trainium/explorer-local.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/aws-trainium/explorer-local.sh b/scripts/aws-trainium/explorer-local.sh index 2d72c0239..a527d5992 100644 --- a/scripts/aws-trainium/explorer-local.sh +++ b/scripts/aws-trainium/explorer-local.sh @@ -1 +1 @@ -ssh -i ~/Desktop/mayank-melbourne.pem -L 8001:localhost:3001 -L 8002:localhost:3002 ubuntu@ec2-16-50-57-175.ap-southeast-4.compute.amazonaws.com -fN +ssh -i ~/Desktop/mayank-melbourne.pem -L 8001:localhost:3001 -L 8002:localhost:3002 trainium-melbourne -fN From 61d21c81df8f31d213fbf057a0f76682ef7405e9 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 28 Apr 2026 23:03:58 -0700 Subject: [PATCH 22/28] fix Signed-off-by: Mayank Mishra --- lm_engine/utils/profiler.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/lm_engine/utils/profiler.py b/lm_engine/utils/profiler.py index 73cf3cfc7..f9fd73c7f 100644 --- a/lm_engine/utils/profiler.py +++ b/lm_engine/utils/profiler.py @@ -17,7 +17,7 @@ if is_torch_neuronx_available(): - from torch_neuronx.profiling import NeuronConfig, ProfileMode + from torch_neuronx.profiling import NeuronConfig, NeuronProfiler, ProfileMode class TorchProfiler: @@ -38,9 +38,13 @@ def __init__(self, path: str | None, wait: int = 5, active: int = 1, warmup: int if self.accelerator == Accelerator.trainium: experimental_config = NeuronConfig( modes=[ProfileMode.DEVICE, ProfileMode.RUNTIME], + max_events_per_nc=100000, profile_output_dir=path, + capture_enabled_for_nc="0", ) + exporter = NeuronProfiler(experimental_config) + self._profiler = None if self.accelerator != Accelerator.tpu: self._profiler = torch.profiler.profile( @@ -52,7 +56,11 @@ def __init__(self, path: str | None, wait: int = 5, active: int = 1, warmup: int repeat=1, ), experimental_config=experimental_config, - on_trace_ready=torch.profiler.tensorboard_trace_handler(path), + on_trace_ready=( + exporter.export_trace + if self.accelerator == Accelerator.trainium + else torch.profiler.tensorboard_trace_handler(path) + ), record_shapes=True, profile_memory=True, ) From 7dc9259860d3654106880d8db4f44c7ba8390dd2 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 29 Apr 2026 16:39:01 -0700 Subject: [PATCH 23/28] add Signed-off-by: Mayank Mishra --- lm_engine/distributed.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/lm_engine/distributed.py b/lm_engine/distributed.py index 78a32709a..29485e525 100644 --- a/lm_engine/distributed.py +++ b/lm_engine/distributed.py @@ -155,7 +155,16 @@ def wrap_model_container_for_distributed_training( if fsdp_algorithm is None: for i, model in enumerate(model_container): - model = model.to(Accelerator.get_current_device()) + if efficient_initialization: + model = model.to_empty(Accelerator.get_current_device()) + + for module in model.modules(): + if hasattr(module, "reset_parameters"): + with torch.device(device): + module.reset_parameters() + else: + model = model.to(Accelerator.get_current_device()) + if torch_compile: model = torch.compile(model, backend=Accelerator.get_torch_compile_backend()) From 46ef7b5e4ed8b870b7e097c02c07cc563423cecd Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 29 Apr 2026 16:41:10 -0700 Subject: [PATCH 24/28] add Signed-off-by: Mayank Mishra --- accelerated-model-architectures | 2 +- lm_engine/distributed.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/accelerated-model-architectures b/accelerated-model-architectures index e99632015..7f09c3228 160000 --- a/accelerated-model-architectures +++ b/accelerated-model-architectures @@ -1 +1 @@ -Subproject commit e996320150a371c2afcb4d1ad405ce0b0eb811fa +Subproject commit 7f09c32284b5e2365975716e72e6d0a2ca60222e diff --git a/lm_engine/distributed.py b/lm_engine/distributed.py index 29485e525..1b8f37f0a 100644 --- a/lm_engine/distributed.py +++ b/lm_engine/distributed.py @@ -153,17 +153,19 @@ def wrap_model_container_for_distributed_training( if torch_compile: log_rank_0(logging.INFO, "using torch compile") + device = Accelerator.get_current_device() + if fsdp_algorithm is None: for i, model in enumerate(model_container): if efficient_initialization: - model = model.to_empty(Accelerator.get_current_device()) + model = model.to_empty(device=device) for module in model.modules(): if hasattr(module, "reset_parameters"): with torch.device(device): module.reset_parameters() else: - model = model.to(Accelerator.get_current_device()) + model = model.to(device) if torch_compile: model = torch.compile(model, backend=Accelerator.get_torch_compile_backend()) @@ -312,8 +314,6 @@ def _sharding_function(parameter: nn.Parameter) -> Shard: ) if efficient_initialization: - device = Accelerator.get_current_device() - # contributed by Yu Chin Fabian Lim # original reference https://github.com/fabianlim/accelerate/pull/1 if model_name is None: @@ -368,7 +368,7 @@ def _sharding_function(parameter: nn.Parameter) -> Shard: cpu_offload=CPUOffload(offload_params=True) if cpu_offload else None, mixed_precision=mixed_precision_policy, auto_wrap_policy=partial(transformer_auto_wrap_policy, transformer_layer_cls=block_classes), - device_id=Accelerator.get_current_device(), + device_id=device, limit_all_gathers=True, use_orig_params=True, # https://github.com/meta-llama/llama-recipes/blob/492455dc080f6c25f356e283e443be0cce86aaeb/src/llama_recipes/finetuning.py#L191 @@ -434,7 +434,7 @@ def _sharding_function(parameter: nn.Parameter) -> Shard: model, stage_index=model.pipeline_stage_id, num_stages=num_pipeline_stages, - device=Accelerator.get_current_device(), + device=device, input_args=dummy_input_tensor, output_args=dummy_output_tensor, group=ProcessGroupManager.get_pipeline_parallel_group(), From ebc3eb8a70f481b12ab6f77b69d7663ba2b19dae Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 29 Apr 2026 16:46:25 -0700 Subject: [PATCH 25/28] add Signed-off-by: Mayank Mishra --- lm_engine/pretrain.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/lm_engine/pretrain.py b/lm_engine/pretrain.py index 65f499b18..f30ed877f 100644 --- a/lm_engine/pretrain.py +++ b/lm_engine/pretrain.py @@ -395,7 +395,16 @@ def train( / ProcessGroupManager.get_world_size() ) - forward_context = nullcontext + forward_context = ( + torch.autocast(Accelerator.get_device_type(), dtype=torch.bfloat16) + if args.distributed_args.fsdp_algorithm is None + else nullcontext + ) + + backward_context = [loss_parallel if ProcessGroupManager.is_tensor_parallel_enabled() else nullcontext] + if args.distributed_args.fsdp_algorithm is None: + backward_context.append(torch.autocast(Accelerator.get_device_type(), dtype=torch.bfloat16)) + backward_context = loss_parallel if ProcessGroupManager.is_tensor_parallel_enabled() else nullcontext torch_profiler = TorchProfiler(args.logging_args.torch_profiler_trace_path) From 00277ee3a50a984b6f5558bf64d120b6203f2dd9 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 29 Apr 2026 16:46:52 -0700 Subject: [PATCH 26/28] add Signed-off-by: Mayank Mishra --- lm_engine/pretrain.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/lm_engine/pretrain.py b/lm_engine/pretrain.py index f30ed877f..22afb3f54 100644 --- a/lm_engine/pretrain.py +++ b/lm_engine/pretrain.py @@ -405,8 +405,6 @@ def train( if args.distributed_args.fsdp_algorithm is None: backward_context.append(torch.autocast(Accelerator.get_device_type(), dtype=torch.bfloat16)) - backward_context = loss_parallel if ProcessGroupManager.is_tensor_parallel_enabled() else nullcontext - torch_profiler = TorchProfiler(args.logging_args.torch_profiler_trace_path) torch_profiler.__enter__() From a2220253fe91ea04f7a67641c91147cd514a5fd3 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 29 Apr 2026 16:49:18 -0700 Subject: [PATCH 27/28] add Signed-off-by: Mayank Mishra --- lm_engine/pretrain.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/lm_engine/pretrain.py b/lm_engine/pretrain.py index 22afb3f54..4d775f0d7 100644 --- a/lm_engine/pretrain.py +++ b/lm_engine/pretrain.py @@ -395,15 +395,8 @@ def train( / ProcessGroupManager.get_world_size() ) - forward_context = ( - torch.autocast(Accelerator.get_device_type(), dtype=torch.bfloat16) - if args.distributed_args.fsdp_algorithm is None - else nullcontext - ) - - backward_context = [loss_parallel if ProcessGroupManager.is_tensor_parallel_enabled() else nullcontext] - if args.distributed_args.fsdp_algorithm is None: - backward_context.append(torch.autocast(Accelerator.get_device_type(), dtype=torch.bfloat16)) + forward_context = nullcontext + backward_context = loss_parallel if ProcessGroupManager.is_tensor_parallel_enabled() else nullcontext torch_profiler = TorchProfiler(args.logging_args.torch_profiler_trace_path) torch_profiler.__enter__() @@ -699,7 +692,15 @@ def main(args_class: type[DistillationArgs | TrainingArgs] = TrainingArgs) -> No experiments_tracker.log_args(args, **model_container[0].calculate_num_parameters(return_dict=True)) # main training loop - with disable_generation_cache(), enable_kernels(args.kernel_args.kernels): + with ( + disable_generation_cache(), + enable_kernels(args.kernel_args.kernels), + ( + torch.autocast(device_type=Accelerator.get_device_type(), dtype=torch.bfloat16) + if args.distributed_args.fsdp_algorithm is None + else nullcontext + ), + ): train( args, model_container=model_container, From 9bfaa06241ee0ee7fd3a0d0a02a20c93dadaeb2d Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 29 Apr 2026 16:51:53 -0700 Subject: [PATCH 28/28] add Signed-off-by: Mayank Mishra --- lm_engine/pretrain.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lm_engine/pretrain.py b/lm_engine/pretrain.py index 4d775f0d7..4136f5c56 100644 --- a/lm_engine/pretrain.py +++ b/lm_engine/pretrain.py @@ -40,6 +40,7 @@ log_rank_0, set_seed, setup_tf32, + string_to_torch_dtype, ) @@ -696,7 +697,9 @@ def main(args_class: type[DistillationArgs | TrainingArgs] = TrainingArgs) -> No disable_generation_cache(), enable_kernels(args.kernel_args.kernels), ( - torch.autocast(device_type=Accelerator.get_device_type(), dtype=torch.bfloat16) + torch.autocast( + device_type=Accelerator.get_device_type(), dtype=string_to_torch_dtype(args.mixed_precision_args.dtype) + ) if args.distributed_args.fsdp_algorithm is None else nullcontext ),