Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion lm_engine/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,9 @@ def get_next_batch(x: Iterable | None) -> dict:
if x is None:
return None

return next(x)
batch = next(x)

if Accelerator.get_accelerator() == Accelerator.trainium:
batch = {k: v.to(torch.int32) if v.dtype in [torch.int32, torch.int64] else v for k, v in batch.items()}

return batch
23 changes: 16 additions & 7 deletions lm_engine/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,22 @@ def wrap_model_container_for_distributed_training(
if torch_compile:
log_rank_0(logging.INFO, "using torch compile")

device = Accelerator.get_current_device()

if fsdp_algorithm is None:
for i, model in enumerate(model_container):
model = model.to(Accelerator.get_current_device())
if efficient_initialization:
model = model.to_empty(device=device)

for module in model.modules():
if hasattr(module, "reset_parameters"):
with torch.device(device):
module.reset_parameters()
else:
model = model.to(device)

if torch_compile:
model = torch.compile(model)
model = torch.compile(model, backend=Accelerator.get_torch_compile_backend())

model_container[i] = model

Expand Down Expand Up @@ -303,8 +314,6 @@ def _sharding_function(parameter: nn.Parameter) -> Shard:
)

if efficient_initialization:
device = Accelerator.get_current_device()

# contributed by Yu Chin Fabian Lim
# original reference https://github.com/fabianlim/accelerate/pull/1
if model_name is None:
Expand Down Expand Up @@ -359,7 +368,7 @@ def _sharding_function(parameter: nn.Parameter) -> Shard:
cpu_offload=CPUOffload(offload_params=True) if cpu_offload else None,
mixed_precision=mixed_precision_policy,
auto_wrap_policy=partial(transformer_auto_wrap_policy, transformer_layer_cls=block_classes),
device_id=Accelerator.get_current_device(),
device_id=device,
limit_all_gathers=True,
use_orig_params=True,
# https://github.com/meta-llama/llama-recipes/blob/492455dc080f6c25f356e283e443be0cce86aaeb/src/llama_recipes/finetuning.py#L191
Expand All @@ -372,7 +381,7 @@ def _sharding_function(parameter: nn.Parameter) -> Shard:

if torch_compile:
for i, model in enumerate(model_container):
model_container[i] = torch.compile(model)
model_container[i] = torch.compile(model, backend=Accelerator.get_torch_compile_backend())

set_parameter_marker_maps(
model_container,
Expand Down Expand Up @@ -425,7 +434,7 @@ def _sharding_function(parameter: nn.Parameter) -> Shard:
model,
stage_index=model.pipeline_stage_id,
num_stages=num_pipeline_stages,
device=Accelerator.get_current_device(),
device=device,
input_args=dummy_input_tensor,
output_args=dummy_output_tensor,
group=ProcessGroupManager.get_pipeline_parallel_group(),
Expand Down
7 changes: 6 additions & 1 deletion lm_engine/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,12 @@ def evaluate(
else:
num_steps = 0

num_steps = torch.tensor(num_steps, device=Accelerator.get_current_device(), dtype=torch.long)
num_steps = torch.tensor(
num_steps,
device=Accelerator.get_current_device(),
dtype=torch.int32 if Accelerator.get_accelerator() == Accelerator.trainium else torch.long,
)

torch.distributed.all_reduce(num_steps, group=ProcessGroupManager.get_tensor_parallel_group())
num_steps = num_steps.item()
else:
Expand Down
76 changes: 48 additions & 28 deletions lm_engine/hf_models/mixins/dense/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None:
)

self.position_embedding_type = config.position_embedding_type
self.use_rope = self.position_embedding_type == "rope"
self.use_learned_absolute = self.position_embedding_type == "learned_absolute"

self._setup_positional_encoding()

def forward(
Expand Down Expand Up @@ -206,10 +209,18 @@ def forward(
)
query_length = key_length - past_length

position_ids = torch.arange(past_length, key_length, dtype=torch.long, device=hidden_states.device)
position_ids = torch.arange(
past_length,
key_length,
dtype=torch.int32 if Accelerator.get_accelerator() == Accelerator.trainium else torch.long,
device=hidden_states.device,
)

position_ids = position_ids.unsqueeze(0).view(-1, query_length)

rope_cos_sin = self._get_rope_cos_sin(key_length, position_ids, dtype=hidden_states.dtype)
rope_cos_sin = (
self._get_rope_cos_sin(key_length, position_ids, dtype=hidden_states.dtype) if self.use_rope else None
)

if is_generation_cache_enabled() and use_cache and cache_params is None:
cache_params = GenerationCache()
Expand Down Expand Up @@ -246,24 +257,35 @@ def _get_position_ids(
) -> torch.Tensor:
if attention_mask is not None and len(attention_mask.shape) == 2:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids = (
attention_mask.to(
torch.int32 if Accelerator.get_accelerator() == Accelerator.trainium else torch.int64
).cumsum(-1)
- 1
)

position_ids.masked_fill_(attention_mask == 0, 0)
if past_length > 0:
position_ids = position_ids[:, past_length:key_length:]
else:
position_ids = torch.arange(past_length, key_length, dtype=torch.long, device=device)
position_ids = torch.arange(
past_length,
key_length,
dtype=torch.int32 if Accelerator.get_accelerator() == Accelerator.trainium else torch.int64,
device=device,
)

position_ids = position_ids.unsqueeze(0).view(-1, query_length)

return position_ids

def _get_rope_cos_sin(
self, key_length: int, position_ids: torch.Tensor, dtype: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
if self.position_embedding_type == "rope":
cos, sin = self.rope(key_length, dtype=dtype)
cos = cos[position_ids]
sin = sin[position_ids]
return cos, sin
cos, sin = self.rope(key_length, dtype=dtype)
cos = cos[position_ids]
sin = sin[position_ids]
return cos, sin

def _prepare_causal_attention_mask(
self,
Expand Down Expand Up @@ -306,19 +328,6 @@ def _prepare_causal_attention_mask(

return causal_mask

def _get_initial_hidden_state(self, input_ids: torch.Tensor, position_ids: torch.Tensor | None) -> torch.Tensor:
hidden_state = self.wte(input_ids)

if self.position_embedding_type == "learned_absolute":
hidden_state = hidden_state + self.wpe(position_ids)

hidden_state = self.embedding_dropout(hidden_state)

if self.m_emb is not None:
hidden_state = hidden_state * self.m_emb

return hidden_state

def _prepare_a_bunch_of_stuff(
self,
input_ids: torch.Tensor | None = None,
Expand Down Expand Up @@ -358,14 +367,25 @@ def _prepare_a_bunch_of_stuff(
query_length = input_shape[-1]
key_length = past_length + query_length

if position_ids is None:
position_ids = self._get_position_ids(
attention_mask, past_length, query_length, key_length, input_ids.device
)
hidden_states = self.wte(input_ids)

hidden_states = self._get_initial_hidden_state(input_ids, position_ids)
if self.use_rope or self.use_learned_absolute:
if position_ids is None:
position_ids = self._get_position_ids(
attention_mask, past_length, query_length, key_length, input_ids.device
)

rope_cos_sin = self._get_rope_cos_sin(key_length, position_ids, dtype=hidden_states.dtype)
if self.use_learned_absolute:
hidden_states = hidden_states + self.wpe(position_ids)

hidden_states = self.embedding_dropout(hidden_states)

if self.m_emb is not None:
hidden_states = hidden_states * self.m_emb

rope_cos_sin = (
self._get_rope_cos_sin(key_length, position_ids, dtype=hidden_states.dtype) if self.use_rope else None
)

attention_mask = self._get_maybe_causal_mask(
attention_mask, batch_size, query_length, key_length, hidden_states.dtype, input_ids.device
Expand Down
13 changes: 12 additions & 1 deletion lm_engine/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
log_rank_0,
set_seed,
setup_tf32,
string_to_torch_dtype,
)


Expand Down Expand Up @@ -692,7 +693,17 @@ def main(args_class: type[DistillationArgs | TrainingArgs] = TrainingArgs) -> No
experiments_tracker.log_args(args, **model_container[0].calculate_num_parameters(return_dict=True))

# main training loop
with disable_generation_cache(), enable_kernels(args.kernel_args.kernels):
with (
disable_generation_cache(),
enable_kernels(args.kernel_args.kernels),
(
torch.autocast(
device_type=Accelerator.get_device_type(), dtype=string_to_torch_dtype(args.mixed_precision_args.dtype)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

torch.autocast does not natively support device_type='xla' or 'neuron' in standard PyTorch. This will cause a RuntimeError when running on TPU or Trainium if the environment does not have a specifically patched version of PyTorch. For these accelerators, it is generally recommended to use device_type='cpu' (which is how Neuron AMP is typically triggered) or the accelerator-specific autocast context (e.g., torch_xla.amp.autocast).

)
if args.distributed_args.fsdp_algorithm is None
else nullcontext
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

nullcontext is a class and must be instantiated (i.e., nullcontext()) to be used as a context manager. Using the class itself in a with statement will raise a TypeError because the class does not implement the context manager protocol (__enter__/__exit__) as class methods.

Suggested change
else nullcontext
else nullcontext()

),
):
train(
args,
model_container=model_container,
Expand Down
19 changes: 19 additions & 0 deletions lm_engine/utils/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any

import torch
from torch.profiler import ProfilerActivity

from .packages import is_torch_neuronx_available, is_torch_xla_available

Expand Down Expand Up @@ -127,3 +128,21 @@ def set_rng_state(state: Any) -> Any:
raise ValueError(f"unexpected device ({accelerator})")

return state

@staticmethod
def get_profiler_activity() -> ProfilerActivity:
    """Return the torch profiler activity for the active accelerator.

    Trainium is exposed to torch as a PrivateUse1 backend, so its kernels are
    profiled under that activity; every other accelerator is profiled as CUDA.
    """
    is_trainium = Accelerator.get_accelerator() == Accelerator.trainium
    return ProfilerActivity.PrivateUse1 if is_trainium else ProfilerActivity.CUDA

@staticmethod
def get_torch_compile_backend() -> str:
    """Return the ``torch.compile`` backend name for the active accelerator.

    Trainium uses the ``neuron`` backend; all other accelerators use the
    default ``inductor`` backend.
    """
    backend = "inductor"
    if Accelerator.get_accelerator() == Accelerator.trainium:
        backend = "neuron"
    return backend
26 changes: 23 additions & 3 deletions lm_engine/utils/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch

from .accelerator import Accelerator
from .packages import is_torch_xla_available
from .packages import is_torch_neuronx_available, is_torch_xla_available
from .parallel import ProcessGroupManager


Expand All @@ -16,6 +16,10 @@
from torch_xla.debug.profiler import stop_trace as xla_stop_trace


if is_torch_neuronx_available():
from torch_neuronx.profiling import NeuronConfig, NeuronProfiler, ProfileMode


class TorchProfiler:
def __init__(self, path: str | None, wait: int = 5, active: int = 1, warmup: int = 5) -> TorchProfiler:
self.path = path
Expand All @@ -30,17 +34,33 @@ def __init__(self, path: str | None, wait: int = 5, active: int = 1, warmup: int
self.accelerator = Accelerator.get_accelerator()
self._step = 0

experimental_config = None
if self.accelerator == Accelerator.trainium:
experimental_config = NeuronConfig(
modes=[ProfileMode.DEVICE, ProfileMode.RUNTIME],
max_events_per_nc=100000,
profile_output_dir=path,
capture_enabled_for_nc="0",
)

exporter = NeuronProfiler(experimental_config)

self._profiler = None
if self.accelerator != Accelerator.tpu:
self._profiler = torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
activities=[torch.profiler.ProfilerActivity.CPU, Accelerator.get_profiler_activity()],
schedule=torch.profiler.schedule(
wait=wait if ProcessGroupManager.get_global_rank() == 0 else 150000,
warmup=warmup,
active=active,
repeat=1,
),
on_trace_ready=torch.profiler.tensorboard_trace_handler(path),
experimental_config=experimental_config,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The experimental_config parameter is not part of the standard torch.profiler.profile signature in PyTorch. While it is supported by torch-neuronx, passing it (even as None) will cause a TypeError on standard PyTorch installations (e.g., when running on CUDA or CPU). To maintain cross-platform compatibility, consider using a conditional approach or dictionary unpacking to call profile without this argument on non-Trainium devices.

on_trace_ready=(
exporter.export_trace
if self.accelerator == Accelerator.trainium
else torch.profiler.tensorboard_trace_handler(path)
),
record_shapes=True,
profile_memory=True,
)
Expand Down
1 change: 1 addition & 0 deletions scripts/aws-trainium/explorer-local.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Forward local ports 8001/8002 to the remote neuron-explorer services (3001/3002)
# and background the tunnel (-fN: no remote command, go to background).
# PEM_FILE and REMOTE_HOST are overridable so the script is not tied to one
# person's key path or SSH host alias.
ssh -i "${PEM_FILE:-$HOME/Desktop/mayank-melbourne.pem}" -L 8001:localhost:3001 -L 8002:localhost:3002 "${REMOTE_HOST:-trainium-melbourne}" -fN
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The script contains a hardcoded absolute path to a personal PEM file (~/Desktop/mayank-melbourne.pem). This makes the script non-portable and potentially exposes details about your local file system. It is recommended to use an environment variable or a generic placeholder.

Suggested change
ssh -i ~/Desktop/mayank-melbourne.pem -L 8001:localhost:3001 -L 8002:localhost:3002 trainium-melbourne -fN
ssh -i ${PEM_FILE:-/path/to/your/key.pem} -L 8001:localhost:3001 -L 8002:localhost:3002 ${REMOTE_HOST:-trainium-melbourne} -fN

1 change: 1 addition & 0 deletions scripts/aws-trainium/explorer-remote.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Launch the neuron-explorer profile viewer (verbosity 2) over the parquet
# profile dumps in ./parquet_files; pair with the local SSH tunnel script to
# browse the UI from a workstation.
neuron-explorer view -v 2 --data-path ./parquet_files
2 changes: 1 addition & 1 deletion scripts/aws-trainium/pretrain.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Single-node pretraining launcher.
# Usage: pretrain.sh <config-path>
# NOTE: the raw diff residue listed --nproc_per_node twice (old value 2, new
# value 4); the post-change script keeps only the new value of 4 workers.
TOKENIZERS_PARALLELISM=false \
torchrun --nnodes=1 \
    --node_rank=0 \
    --nproc_per_node=4 \
    --rdzv_id=101 \
    -m lm_engine.pretrain \
    --config ${1}