diff --git a/accelerated-model-architectures b/accelerated-model-architectures index e99632015..7f09c3228 160000 --- a/accelerated-model-architectures +++ b/accelerated-model-architectures @@ -1 +1 @@ -Subproject commit e996320150a371c2afcb4d1ad405ce0b0eb811fa +Subproject commit 7f09c32284b5e2365975716e72e6d0a2ca60222e diff --git a/lm_engine/data/utils.py b/lm_engine/data/utils.py index 8ba4f4925..b2ecce8c0 100644 --- a/lm_engine/data/utils.py +++ b/lm_engine/data/utils.py @@ -141,4 +141,9 @@ def get_next_batch(x: Iterable | None) -> dict: if x is None: return None - return next(x) + batch = next(x) + + if Accelerator.get_accelerator() == Accelerator.trainium: + batch = {k: v.to(torch.int32) if v.dtype in [torch.int32, torch.int64] else v for k, v in batch.items()} + + return batch diff --git a/lm_engine/distributed.py b/lm_engine/distributed.py index d753840ff..1b8f37f0a 100644 --- a/lm_engine/distributed.py +++ b/lm_engine/distributed.py @@ -153,11 +153,22 @@ def wrap_model_container_for_distributed_training( if torch_compile: log_rank_0(logging.INFO, "using torch compile") + device = Accelerator.get_current_device() + if fsdp_algorithm is None: for i, model in enumerate(model_container): - model = model.to(Accelerator.get_current_device()) + if efficient_initialization: + model = model.to_empty(device=device) + + for module in model.modules(): + if hasattr(module, "reset_parameters"): + with torch.device(device): + module.reset_parameters() + else: + model = model.to(device) + if torch_compile: - model = torch.compile(model) + model = torch.compile(model, backend=Accelerator.get_torch_compile_backend()) model_container[i] = model @@ -303,8 +314,6 @@ def _sharding_function(parameter: nn.Parameter) -> Shard: ) if efficient_initialization: - device = Accelerator.get_current_device() - # contributed by Yu Chin Fabian Lim # original reference https://github.com/fabianlim/accelerate/pull/1 if model_name is None: @@ -359,7 +368,7 @@ def _sharding_function(parameter: 
nn.Parameter) -> Shard: cpu_offload=CPUOffload(offload_params=True) if cpu_offload else None, mixed_precision=mixed_precision_policy, auto_wrap_policy=partial(transformer_auto_wrap_policy, transformer_layer_cls=block_classes), - device_id=Accelerator.get_current_device(), + device_id=device, limit_all_gathers=True, use_orig_params=True, # https://github.com/meta-llama/llama-recipes/blob/492455dc080f6c25f356e283e443be0cce86aaeb/src/llama_recipes/finetuning.py#L191 @@ -372,7 +381,7 @@ def _sharding_function(parameter: nn.Parameter) -> Shard: if torch_compile: for i, model in enumerate(model_container): - model_container[i] = torch.compile(model) + model_container[i] = torch.compile(model, backend=Accelerator.get_torch_compile_backend()) set_parameter_marker_maps( model_container, @@ -425,7 +434,7 @@ def _sharding_function(parameter: nn.Parameter) -> Shard: model, stage_index=model.pipeline_stage_id, num_stages=num_pipeline_stages, - device=Accelerator.get_current_device(), + device=device, input_args=dummy_input_tensor, output_args=dummy_output_tensor, group=ProcessGroupManager.get_pipeline_parallel_group(), diff --git a/lm_engine/finetune.py b/lm_engine/finetune.py index a266fd66b..9a9827526 100644 --- a/lm_engine/finetune.py +++ b/lm_engine/finetune.py @@ -157,7 +157,12 @@ def evaluate( else: num_steps = 0 - num_steps = torch.tensor(num_steps, device=Accelerator.get_current_device(), dtype=torch.long) + num_steps = torch.tensor( + num_steps, + device=Accelerator.get_current_device(), + dtype=torch.int32 if Accelerator.get_accelerator() == Accelerator.trainium else torch.long, + ) + torch.distributed.all_reduce(num_steps, group=ProcessGroupManager.get_tensor_parallel_group()) num_steps = num_steps.item() else: diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index 6ca436d19..17606face 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -158,6 +158,9 @@ def 
_init_model(self, config: CommonConfig, **kwargs) -> None: ) self.position_embedding_type = config.position_embedding_type + self.use_rope = self.position_embedding_type == "rope" + self.use_learned_absolute = self.position_embedding_type == "learned_absolute" + self._setup_positional_encoding() def forward( @@ -206,10 +209,18 @@ def forward( ) query_length = key_length - past_length - position_ids = torch.arange(past_length, key_length, dtype=torch.long, device=hidden_states.device) + position_ids = torch.arange( + past_length, + key_length, + dtype=torch.int32 if Accelerator.get_accelerator() == Accelerator.trainium else torch.long, + device=hidden_states.device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, query_length) - rope_cos_sin = self._get_rope_cos_sin(key_length, position_ids, dtype=hidden_states.dtype) + rope_cos_sin = ( + self._get_rope_cos_sin(key_length, position_ids, dtype=hidden_states.dtype) if self.use_rope else None + ) if is_generation_cache_enabled() and use_cache and cache_params is None: cache_params = GenerationCache() @@ -246,12 +257,24 @@ def _get_position_ids( ) -> torch.Tensor: if attention_mask is not None and len(attention_mask.shape) == 2: # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids = ( + attention_mask.to( + torch.int32 if Accelerator.get_accelerator() == Accelerator.trainium else torch.int64 + ).cumsum(-1) + - 1 + ) + position_ids.masked_fill_(attention_mask == 0, 0) if past_length > 0: position_ids = position_ids[:, past_length:key_length:] else: - position_ids = torch.arange(past_length, key_length, dtype=torch.long, device=device) + position_ids = torch.arange( + past_length, + key_length, + dtype=torch.int32 if Accelerator.get_accelerator() == Accelerator.trainium else torch.int64, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, query_length) return position_ids @@ -259,11 +282,10 @@ def _get_position_ids( def 
_get_rope_cos_sin( self, key_length: int, position_ids: torch.Tensor, dtype: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: - if self.position_embedding_type == "rope": - cos, sin = self.rope(key_length, dtype=dtype) - cos = cos[position_ids] - sin = sin[position_ids] - return cos, sin + cos, sin = self.rope(key_length, dtype=dtype) + cos = cos[position_ids] + sin = sin[position_ids] + return cos, sin def _prepare_causal_attention_mask( self, @@ -306,19 +328,6 @@ def _prepare_causal_attention_mask( return causal_mask - def _get_initial_hidden_state(self, input_ids: torch.Tensor, position_ids: torch.Tensor | None) -> torch.Tensor: - hidden_state = self.wte(input_ids) - - if self.position_embedding_type == "learned_absolute": - hidden_state = hidden_state + self.wpe(position_ids) - - hidden_state = self.embedding_dropout(hidden_state) - - if self.m_emb is not None: - hidden_state = hidden_state * self.m_emb - - return hidden_state - def _prepare_a_bunch_of_stuff( self, input_ids: torch.Tensor | None = None, @@ -358,14 +367,25 @@ def _prepare_a_bunch_of_stuff( query_length = input_shape[-1] key_length = past_length + query_length - if position_ids is None: - position_ids = self._get_position_ids( - attention_mask, past_length, query_length, key_length, input_ids.device - ) + hidden_states = self.wte(input_ids) - hidden_states = self._get_initial_hidden_state(input_ids, position_ids) + if self.use_rope or self.use_learned_absolute: + if position_ids is None: + position_ids = self._get_position_ids( + attention_mask, past_length, query_length, key_length, input_ids.device + ) - rope_cos_sin = self._get_rope_cos_sin(key_length, position_ids, dtype=hidden_states.dtype) + if self.use_learned_absolute: + hidden_states = hidden_states + self.wpe(position_ids) + + hidden_states = self.embedding_dropout(hidden_states) + + if self.m_emb is not None: + hidden_states = hidden_states * self.m_emb + + rope_cos_sin = ( + self._get_rope_cos_sin(key_length, position_ids, 
dtype=hidden_states.dtype) if self.use_rope else None + ) attention_mask = self._get_maybe_causal_mask( attention_mask, batch_size, query_length, key_length, hidden_states.dtype, input_ids.device diff --git a/lm_engine/pretrain.py b/lm_engine/pretrain.py index 65f499b18..4136f5c56 100644 --- a/lm_engine/pretrain.py +++ b/lm_engine/pretrain.py @@ -40,6 +40,7 @@ log_rank_0, set_seed, setup_tf32, + string_to_torch_dtype, ) @@ -692,7 +693,17 @@ def main(args_class: type[DistillationArgs | TrainingArgs] = TrainingArgs) -> No experiments_tracker.log_args(args, **model_container[0].calculate_num_parameters(return_dict=True)) # main training loop - with disable_generation_cache(), enable_kernels(args.kernel_args.kernels): + with ( + disable_generation_cache(), + enable_kernels(args.kernel_args.kernels), + ( + torch.autocast( + device_type=Accelerator.get_device_type(), dtype=string_to_torch_dtype(args.mixed_precision_args.dtype) + ) + if args.distributed_args.fsdp_algorithm is None + else nullcontext() + ), + ): train( args, model_container=model_container, diff --git a/lm_engine/utils/accelerator.py b/lm_engine/utils/accelerator.py index f0030ae1e..c7830954a 100644 --- a/lm_engine/utils/accelerator.py +++ b/lm_engine/utils/accelerator.py @@ -9,6 +9,7 @@ from typing import Any import torch +from torch.profiler import ProfilerActivity from .packages import is_torch_neuronx_available, is_torch_xla_available @@ -127,3 +128,21 @@ def set_rng_state(state: Any) -> Any: raise ValueError(f"unexpected device ({accelerator})") return state + + @staticmethod + def get_profiler_activity() -> ProfilerActivity: + accelerator = Accelerator.get_accelerator() + + if accelerator == Accelerator.trainium: + return ProfilerActivity.PrivateUse1 + + return ProfilerActivity.CUDA + + @staticmethod + def get_torch_compile_backend() -> str: + accelerator = Accelerator.get_accelerator() + + if accelerator == Accelerator.trainium: + return "neuron" + + return "inductor" diff --git 
a/lm_engine/utils/profiler.py b/lm_engine/utils/profiler.py index daf20200c..f9fd73c7f 100644 --- a/lm_engine/utils/profiler.py +++ b/lm_engine/utils/profiler.py @@ -7,7 +7,7 @@ import torch from .accelerator import Accelerator -from .packages import is_torch_xla_available +from .packages import is_torch_neuronx_available, is_torch_xla_available from .parallel import ProcessGroupManager @@ -16,6 +16,10 @@ from torch_xla.debug.profiler import stop_trace as xla_stop_trace +if is_torch_neuronx_available(): + from torch_neuronx.profiling import NeuronConfig, NeuronProfiler, ProfileMode + + class TorchProfiler: def __init__(self, path: str | None, wait: int = 5, active: int = 1, warmup: int = 5) -> TorchProfiler: self.path = path @@ -30,17 +34,33 @@ def __init__(self, path: str | None, wait: int = 5, active: int = 1, warmup: int self.accelerator = Accelerator.get_accelerator() self._step = 0 + experimental_config = None + if self.accelerator == Accelerator.trainium: + experimental_config = NeuronConfig( + modes=[ProfileMode.DEVICE, ProfileMode.RUNTIME], + max_events_per_nc=100000, + profile_output_dir=path, + capture_enabled_for_nc="0", + ) + + exporter = NeuronProfiler(experimental_config) + self._profiler = None if self.accelerator != Accelerator.tpu: self._profiler = torch.profiler.profile( - activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], + activities=[torch.profiler.ProfilerActivity.CPU, Accelerator.get_profiler_activity()], schedule=torch.profiler.schedule( wait=wait if ProcessGroupManager.get_global_rank() == 0 else 150000, warmup=warmup, active=active, repeat=1, ), - on_trace_ready=torch.profiler.tensorboard_trace_handler(path), + experimental_config=experimental_config, + on_trace_ready=( + exporter.export_trace + if self.accelerator == Accelerator.trainium + else torch.profiler.tensorboard_trace_handler(path) + ), record_shapes=True, profile_memory=True, ) diff --git a/scripts/aws-trainium/explorer-local.sh 
b/scripts/aws-trainium/explorer-local.sh new file mode 100644 index 000000000..a527d5992 --- /dev/null +++ b/scripts/aws-trainium/explorer-local.sh @@ -0,0 +1 @@ +ssh -i ~/Desktop/mayank-melbourne.pem -L 8001:localhost:3001 -L 8002:localhost:3002 trainium-melbourne -fN diff --git a/scripts/aws-trainium/explorer-remote.sh b/scripts/aws-trainium/explorer-remote.sh new file mode 100644 index 000000000..3ce4c37c1 --- /dev/null +++ b/scripts/aws-trainium/explorer-remote.sh @@ -0,0 +1 @@ +neuron-explorer view -v 2 --data-path ./parquet_files diff --git a/scripts/aws-trainium/pretrain.sh b/scripts/aws-trainium/pretrain.sh index 6a90f4259..0f81fad6c 100644 --- a/scripts/aws-trainium/pretrain.sh +++ b/scripts/aws-trainium/pretrain.sh @@ -1,7 +1,7 @@ TOKENIZERS_PARALLELISM=false \ torchrun --nnodes=1 \ --node_rank=0 \ - --nproc_per_node=2 \ + --nproc_per_node=4 \ --rdzv_id=101 \ -m lm_engine.pretrain \ --config ${1}