From 1e80bbb5d33a18bf0665ee6d540d3b4097f5d34b Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Thu, 7 Aug 2025 00:16:27 +0200 Subject: [PATCH 01/60] feat: implemented stage FQN generation for pipeline parallelism --- src/modalities/models/parallelism/__init__.py | 0 .../parallelism/pipeline_parallelism.py | 88 +++++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 src/modalities/models/parallelism/__init__.py create mode 100644 src/modalities/models/parallelism/pipeline_parallelism.py diff --git a/src/modalities/models/parallelism/__init__.py b/src/modalities/models/parallelism/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/modalities/models/parallelism/pipeline_parallelism.py b/src/modalities/models/parallelism/pipeline_parallelism.py new file mode 100644 index 000000000..e1d2233ba --- /dev/null +++ b/src/modalities/models/parallelism/pipeline_parallelism.py @@ -0,0 +1,88 @@ +# Some portions of this implementation are inspired and/or adapted +# from Meta's open-source project TorchTitan, +# licensed under the BSD 3-Clause License. + +import math +from abc import ABC, abstractmethod +from typing import Optional + +import torch +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.pipelining.schedules import PipelineScheduleSingle, get_schedule_class + +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees + + +class FQNsPerStageGenerator(ABC): + @abstractmethod + def generate_fqns_per_stage( + self, num_stages: int, num_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1 + ) -> list[list[str]]: + """ + Generate a list of fully qualified names (FQNs) for each pipeline stage. + + Args: + num_stages (int): Number of stages in the pipeline. + num_layers (int): Total number of layers in the model. + input_layer_equivalence (int): Determines to how many transformer layers + the input layer corresponds. Default is 1. + output_layer_equivalence (int): Determines to how many transformer layers + the output layer corresponds. Default is 1. + + Returns: + list[list[str]]: A list containing an FQN list for each stage. + """ + raise NotImplementedError("This method should be implemented by subclasses.") + + +class PipelineFactory: + """Pipeline factory class to create pipelined models.""" + + @staticmethod + def create_pipeline_model( + num_layers: int, + fqns_per_stage_generator: FQNsPerStageGenerator, + device_mesh: DeviceMesh, + pp_schedule_name: str, + num_layers_per_stage: int, + input_layer_equivalence: Optional[int] = 1, + output_layer_equivalence: Optional[int] = 1, + ) -> torch.nn.Module: + device_mesh[ParallelismDegrees.PP.value] + pp_dims = device_mesh.size(ParallelismDegrees.PP.value) + schedule_class = get_schedule_class(pp_schedule_name) + is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle) + if not is_single_stage_schedule: + raise ValueError( + f"Unsupported pipeline schedule: {pp_schedule_name}. We only support single-stage schedules." + ) + + # calculate the number of stages + num_virtual_stages = math.ceil( + (num_layers + input_layer_equivalence + output_layer_equivalence) / num_layers_per_stage + ) + if num_virtual_stages % pp_dims != 0: + raise ValueError( + f"Number of virtual stages {num_virtual_stages} is not divisible by parallel dimensions {pp_dims}. " + f"For reference: {num_layers=} {input_layer_equivalence=} " + f"{output_layer_equivalence=} {num_layers_per_stage=}" + ) + + stages_per_rank = num_virtual_stages // pp_dims + if stages_per_rank != 1: + raise ValueError( + f"Stages per rank {stages_per_rank} must be 1 for single-stage schedules. " + f"Please adjust {num_layers_per_stage=} to ensure each PP rank has exactly one stage." + ) + + fqns_per_stage_generator.generate_fqns_per_stage( + num_stages=num_virtual_stages, + num_layers=num_layers, + input_layer_equivalence=input_layer_equivalence, + output_layer_equivalence=output_layer_equivalence, + ) + + @staticmethod + def create_gpt2_model_splitter(): + """Create a GPT-2 model splitter for pipeline parallelism.""" + pass From ed93d284d108a1008b0eb7dae102f3b4a98f4c46 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Thu, 7 Aug 2025 15:49:13 +0200 Subject: [PATCH 02/60] feat: added FQNs per stage calculation --- .../parallelism/pipeline_parallelism.py | 91 +++++++++++++++++-- 1 file changed, 83 insertions(+), 8 deletions(-) diff --git a/src/modalities/models/parallelism/pipeline_parallelism.py b/src/modalities/models/parallelism/pipeline_parallelism.py index e1d2233ba..ac6be437c 100644 --- a/src/modalities/models/parallelism/pipeline_parallelism.py +++ b/src/modalities/models/parallelism/pipeline_parallelism.py @@ -1,4 +1,4 @@ -# Some portions of this implementation are inspired and/or adapted +# Some portions of this implementation are inspired, adapted, or refactored # from Meta's open-source project TorchTitan, # licensed under the BSD 3-Clause License. @@ -14,27 +14,102 @@ class FQNsPerStageGenerator(ABC): - @abstractmethod def generate_fqns_per_stage( self, num_stages: int, num_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1 ) -> list[list[str]]: """ - Generate a list of fully qualified names (FQNs) for each pipeline stage. + Generate FQNs for each stage in a GPT-2 model. Args: num_stages (int): Number of stages in the pipeline. num_layers (int): Total number of layers in the model. - input_layer_equivalence (int): Determines to how many transformer layers - the input layer corresponds. Default is 1. - output_layer_equivalence (int): Determines to how many transformer layers - the output layer corresponds. Default is 1. + input_layer_equivalence (int): Number of layers corresponding to the input layer. + output_layer_equivalence (int): Number of layers corresponding to the output layer. Returns: - list[list[str]]: A list containing an FQN list for each stage. + list[list[str]]: A list containing FQNs for each stage. + """ + + # Potential split points for GPT-2 model with each potential split point + # listing the FQNs of the modules in that stage and the computational weight. + # The computational weight of the input and output modules are estimated + # based on the number of layers they correspond to. + potential_split_points = self._get_potential_split_points( + num_layers=num_layers, + input_layer_equivalence=input_layer_equivalence, + output_layer_equivalence=output_layer_equivalence, + ) + # Calculate the weight per stage based on the total weight and number of stages + weight_per_stage = math.ceil(sum(weight for _, weight in potential_split_points) / num_stages) + # pack the stages with the layers + next_split_point = 0 + module_names_per_stage: list[list[str]] = [] + for _ in range(num_stages): + stage_fqns = [] + stage_weight = 0 + while next_split_point < len(potential_split_points): + fqns, weight = potential_split_points[next_split_point] + if weight > weight_per_stage: + raise ValueError( + f"Weight of {weight} for {fqns} exceeds weight per stage {weight_per_stage}. " + "Please adjust the number of stages or the weight distribution." + ) + if stage_weight + weight > weight_per_stage: + break + stage_fqns.extend(fqns) + stage_weight += weight + next_split_point += 1 + module_names_per_stage.append(stage_fqns) + + return module_names_per_stage + + @abstractmethod + def _get_potential_split_points( + self, num_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1 + ) -> list[tuple[list[str], int]]: + """ + Returns a list of potential split points for the GPT-2 model. + + Args: + num_layers (int): Total number of layers in the model. + input_layer_equivalence (int): Number of layers corresponding to the input layer. + output_layer_equivalence (int): Number of layers corresponding to the output layer. + + Returns: + list[tuple[list[str], int]]: A list containing tuples of FQNs and their computational weights. """ raise NotImplementedError("This method should be implemented by subclasses.") +class GPT2LLMFQNsPerStageGenerator(FQNsPerStageGenerator): + def _get_potential_split_points( + self, num_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1 + ) -> list[tuple[list[str], int]]: + """ + Returns a list of potential split points for the GPT-2 model. + + Args: + num_layers (int): Total number of layers in the model. + input_layer_equivalence (int): Number of layers corresponding to the input layer. + output_layer_equivalence (int): Number of layers corresponding to the output layer. + + Returns: + list[tuple[list[str], int]]: A list containing tuples of FQNs and their computational weights. + """ + + # Potential split points for GPT-2 model with each potential split point + # listing the FQNs of the modules in that stage and the computational weight. + # The computational weight of the input and output modules are estimated + # based on the number of layers they correspond to. + potential_split_points = [ + (["transformer.wte", "transformer.wpe"], input_layer_equivalence), + *[([f"transformer.h.{i}"], 1) for i in range(num_layers)], + (["transformer.lm_head_norm", "transformer.lm_head"], output_layer_equivalence), + ] + + return potential_split_points + + class PipelineFactory: """Pipeline factory class to create pipelined models.""" From 6241ea8dbfef66ac7b89e07a0ee350ddb37306a5 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Fri, 15 Aug 2025 15:31:34 +0200 Subject: [PATCH 03/60] feat: generic FQN-based PP staging --- .../parallelism/pipeline_parallelism.py | 279 +++++++++--------- 1 file changed, 145 insertions(+), 134 deletions(-) diff --git a/src/modalities/models/parallelism/pipeline_parallelism.py b/src/modalities/models/parallelism/pipeline_parallelism.py index ac6be437c..0af1551af 100644 --- a/src/modalities/models/parallelism/pipeline_parallelism.py +++ b/src/modalities/models/parallelism/pipeline_parallelism.py @@ -2,162 +2,173 @@ # from Meta's open-source project TorchTitan, # licensed under the BSD 3-Clause License. -import math -from abc import ABC, abstractmethod -from typing import Optional +import copy +from typing import Any, Optional, Type import torch +import torch.nn as nn from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.pipelining import PipelineStage from torch.distributed.pipelining.schedules import PipelineScheduleSingle, get_schedule_class +from modalities.models.parallelism.stages_generator import StagesGenerator from modalities.running_env.fsdp.device_mesh import ParallelismDegrees -class FQNsPerStageGenerator(ABC): - def generate_fqns_per_stage( - self, num_stages: int, num_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1 - ) -> list[list[str]]: - """ - Generate FQNs for each stage in a GPT-2 model. - - Args: - num_stages (int): Number of stages in the pipeline. - num_layers (int): Total number of layers in the model. - input_layer_equivalence (int): Number of layers corresponding to the input layer. - output_layer_equivalence (int): Number of layers corresponding to the output layer. - - Returns: - list[list[str]]: A list containing FQNs for each stage. - """ - - # Potential split points for GPT-2 model with each potential split point - # listing the FQNs of the modules in that stage and the computational weight. - # The computational weight of the input and output modules are estimated - # based on the number of layers they correspond to. - potential_split_points = self._get_potential_split_points( - num_layers=num_layers, - input_layer_equivalence=input_layer_equivalence, - output_layer_equivalence=output_layer_equivalence, - ) - # Calculate the weight per stage based on the total weight and number of stages - weight_per_stage = math.ceil(sum(weight for _, weight in potential_split_points) / num_stages) - # pack the stages with the layers - next_split_point = 0 - module_names_per_stage: list[list[str]] = [] - for _ in range(num_stages): - stage_fqns = [] - stage_weight = 0 - while next_split_point < len(potential_split_points): - fqns, weight = potential_split_points[next_split_point] - if weight > weight_per_stage: - raise ValueError( - f"Weight of {weight} for {fqns} exceeds weight per stage {weight_per_stage}. " - "Please adjust the number of stages or the weight distribution." - ) - if stage_weight + weight > weight_per_stage: - break - stage_fqns.extend(fqns) - stage_weight += weight - next_split_point += 1 - module_names_per_stage.append(stage_fqns) - - return module_names_per_stage - - @abstractmethod - def _get_potential_split_points( - self, num_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1 - ) -> list[tuple[list[str], int]]: - """ - Returns a list of potential split points for the GPT-2 model. - - Args: - num_layers (int): Total number of layers in the model. - input_layer_equivalence (int): Number of layers corresponding to the input layer. - output_layer_equivalence (int): Number of layers corresponding to the output layer. - - Returns: - list[tuple[list[str], int]]: A list containing tuples of FQNs and their computational weights. - """ - raise NotImplementedError("This method should be implemented by subclasses.") - - -class GPT2LLMFQNsPerStageGenerator(FQNsPerStageGenerator): - def _get_potential_split_points( - self, num_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1 - ) -> list[tuple[list[str], int]]: - """ - Returns a list of potential split points for the GPT-2 model. - - Args: - num_layers (int): Total number of layers in the model. - input_layer_equivalence (int): Number of layers corresponding to the input layer. - output_layer_equivalence (int): Number of layers corresponding to the output layer. - - Returns: - list[tuple[list[str], int]]: A list containing tuples of FQNs and their computational weights. - """ - - # Potential split points for GPT-2 model with each potential split point - # listing the FQNs of the modules in that stage and the computational weight. - # The computational weight of the input and output modules are estimated - # based on the number of layers they correspond to. - potential_split_points = [ - (["transformer.wte", "transformer.wpe"], input_layer_equivalence), - *[([f"transformer.h.{i}"], 1) for i in range(num_layers)], - (["transformer.lm_head_norm", "transformer.lm_head"], output_layer_equivalence), - ] - - return potential_split_points - - class PipelineFactory: """Pipeline factory class to create pipelined models.""" @staticmethod - def create_pipeline_model( - num_layers: int, - fqns_per_stage_generator: FQNsPerStageGenerator, + def get_pipelined_model( + whole_model: nn.Module, + stages_generator: StagesGenerator, device_mesh: DeviceMesh, + local_rank: int, pp_schedule_name: str, num_layers_per_stage: int, - input_layer_equivalence: Optional[int] = 1, - output_layer_equivalence: Optional[int] = 1, ) -> torch.nn.Module: - device_mesh[ParallelismDegrees.PP.value] - pp_dims = device_mesh.size(ParallelismDegrees.PP.value) + device = torch.device("cuda", local_rank) + pp_dims = device_mesh[ParallelismDegrees.PP.value].size() + + fqns_per_stage = stages_generator.get_stages( + num_layers_per_stage=num_layers_per_stage, + pp_dims=pp_dims, + ) + + pp_mesh = device_mesh[ParallelismDegrees.PP.value] schedule_class = get_schedule_class(pp_schedule_name) is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle) if not is_single_stage_schedule: raise ValueError( f"Unsupported pipeline schedule: {pp_schedule_name}. We only support single-stage schedules." ) - - # calculate the number of stages - num_virtual_stages = math.ceil( - (num_layers + input_layer_equivalence + output_layer_equivalence) / num_layers_per_stage - ) - if num_virtual_stages % pp_dims != 0: - raise ValueError( - f"Number of virtual stages {num_virtual_stages} is not divisible by parallel dimensions {pp_dims}. " - f"For reference: {num_layers=} {input_layer_equivalence=} " - f"{output_layer_equivalence=} {num_layers_per_stage=}" - ) - - stages_per_rank = num_virtual_stages // pp_dims - if stages_per_rank != 1: - raise ValueError( - f"Stages per rank {stages_per_rank} must be 1 for single-stage schedules. " - f"Please adjust {num_layers_per_stage=} to ensure each PP rank has exactly one stage." - ) - - fqns_per_stage_generator.generate_fqns_per_stage( - num_stages=num_virtual_stages, - num_layers=num_layers, - input_layer_equivalence=input_layer_equivalence, - output_layer_equivalence=output_layer_equivalence, + stage, model = PipelineFactory._get_split_model( + whole_model=whole_model, + schedule_class=schedule_class, + pp_mesh=pp_mesh, + device=device, + fqns_per_stage=fqns_per_stage, ) + return whole_model # TODO return pipelined model @staticmethod - def create_gpt2_model_splitter(): - """Create a GPT-2 model splitter for pipeline parallelism.""" - pass + def _get_split_model( + whole_model: nn.Module, + schedule_class: Type[PipelineScheduleSingle], + pp_mesh: DeviceMesh, + device: torch.device, + fqns_per_stage: list[list[str]], + ) -> tuple[PipelineStage, nn.Module]: + def get_stage_id_of_pp_rank(pp_mesh: DeviceMesh): + # NOTE: torch titan a more complicated way to get the stage id of pp rank + # since they also allow for multi-stage schedules + pp_rank = pp_mesh.get_local_rank() + return pp_rank + + @staticmethod + def _get_fqn_tree(fqns: list[str]) -> dict[str, Any]: + fqn_tree = {} + fqns = set(fqns) # Ensure unique FQNs + for fqn in fqns: + parts = fqn.split(".") + current_level = fqn_tree + for part in parts[:-1]: + if part not in current_level: + current_level[part] = {} + elif len(current_level) == 0: + raise ValueError(f"Part {part} of {fqn} already exists " "in the tree as a leaf node.") + current_level = current_level[part] + if parts[-1] in current_level: + raise ValueError( + f" Leaf of {fqn} has already been defined in the tree as an intermediadate node or leaf! " + "Cannot replace the existing node as a leaf." + ) + current_level[parts[-1]] = {} + + return fqn_tree + + def _build_stage_from_modules( + fqn_tree: dict[str, Any], module: nn.Module, module_name: Optional[str] = None + ) -> tuple[PipelineStage, nn.Module]: + if isinstance(module, nn.ModuleDict): + if module_name not in fqn_tree: + dict_modules = nn.ModuleDict({}) + else: + if len(fqn_tree) == 0: + # If the module is a leaf node, we can directly use it + dict_modules = module + else: + # If the module is not a leaf node, we need to build a staged module + # recursively from the FQN tree + dict_modules = {} + dict_module_names = [name for name in module.keys() if name in fqn_tree[module_name]] + for dict_module_name in dict_module_names: + dict_modules[dict_module_name] = _build_stage_from_modules( + fqn_tree=fqn_tree[module_name], + module=module[dict_module_name], + module_name=dict_module_name, + ) + dict_modules = nn.ModuleDict(dict_modules) + # setattr(module, module_name, dict_modules) + return dict_modules + + elif isinstance(module, nn.ModuleList): + if module_name not in fqn_tree: + list_modules = nn.ModuleList([]) + else: + if len(fqn_tree[module_name]) == 0: + # If the module is a leaf node, we can directly use it + list_modules = module + else: + # If the module is not a leaf node, we need to build a staged module + # recursively from the FQN tree + list_modules = [] + list_indices = [i for i in range(len(module)) if str(i) in fqn_tree[module_name]] + for idx in list_indices: + list_modules.append( + _build_stage_from_modules( + fqn_tree=fqn_tree[module_name], module=module[idx], module_name=str(idx) + ) + ) + list_modules = nn.ModuleList(list_modules) + # setattr(module, module_name, list_modules) + return list_modules + + else: # normal nn.Module + if module_name is not None and module_name not in fqn_tree: + # If the module is not in the FQN tree, set it to None + return None + elif module_name is not None and len(fqn_tree[module_name]) == 0: + # If the module is a leaf node, we can directly use it + return module + else: + # If the module is in the FQN tree, we need to build a staged module + # recursively from the FQN tree + for module_name, module_value in module.named_children(): + # If the module is not a leaf node, we need to build a staged module + # recursively from the FQN tree + staged_module = _build_stage_from_modules( + fqn_tree=fqn_tree, module=module_value, module_name=module_name + ) + setattr(module, module_name, staged_module) + + return module + + if not issubclass(schedule_class, PipelineScheduleSingle): + raise NotImplementedError("Only single-stage schedules are supported for pipeline parallelism.") + + # NOTE: For multi-stage schedule, e.g., Interleaved 1F1B, we have multiple stages per pp rank. + # This would need to be adapted accordingly in this case. + stage_idx = get_stage_id_of_pp_rank(pp_mesh) + module_names = fqns_per_stage[stage_idx] + whole_model = copy.deepcopy(whole_model) + fqn_tree = _get_fqn_tree(module_names) + stage_modules = _build_stage_from_modules(fqn_tree, whole_model) + stage = PipelineStage( + submodule=stage_modules, + stage_index=stage_idx, + num_stages=len(fqns_per_stage), + device=device, + group=pp_mesh.get_group("pp"), + ) + return stage, whole_model From 0ba8fbc4e0e8d20623d32844c90725a6beea3d09 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Fri, 15 Aug 2025 15:32:08 +0200 Subject: [PATCH 04/60] feat: added PP configs --- .../pipeline_parallelism_configs.py | 22 +++++++++++++++++++ .../parallelism/stages_generator_configs.py | 13 +++++++++++ 2 files changed, 35 insertions(+) create mode 100644 src/modalities/models/parallelism/pipeline_parallelism_configs.py create mode 100644 src/modalities/models/parallelism/stages_generator_configs.py diff --git a/src/modalities/models/parallelism/pipeline_parallelism_configs.py b/src/modalities/models/parallelism/pipeline_parallelism_configs.py new file mode 100644 index 000000000..61b8b5ba4 --- /dev/null +++ b/src/modalities/models/parallelism/pipeline_parallelism_configs.py @@ -0,0 +1,22 @@ +from typing import Annotated + +from pydantic import BaseModel, Field + +from modalities.config.pydantic_if_types import ( + PydanticDeviceMeshIFType, + PydanticPytorchModuleType, + PydanticStagesGeneratorType, +) + + +class FQNsPerStageGeneratorConfig(BaseModel): + pass + + +class PipelinedModelConfig(BaseModel): + whole_model: PydanticPytorchModuleType + stages_generator: PydanticStagesGeneratorType + device_mesh: PydanticDeviceMeshIFType + local_rank: Annotated[int, Field(strict=True, ge=0)] + pp_schedule_name: str + num_layers_per_stage: Annotated[int, Field(strict=True, ge=1)] diff --git a/src/modalities/models/parallelism/stages_generator_configs.py b/src/modalities/models/parallelism/stages_generator_configs.py new file mode 100644 index 000000000..610be7fdd --- /dev/null +++ b/src/modalities/models/parallelism/stages_generator_configs.py @@ -0,0 +1,13 @@ +from typing import Annotated + +from pydantic import BaseModel, Field + + +class FQNsPerStageGeneratorConfig(BaseModel): + pass + + +class GPT2LLMStagesGeneratorConfig(BaseModel): + num_model_layers: Annotated[int, Field(strict=True, ge=1)] + input_layer_equivalence: Annotated[int, Field(strict=True, ge=1)] = 1 + output_layer_equivalence: Annotated[int, Field(strict=True, ge=1)] = 1 From 4a41b6c3648a56a52af86af1a30de34c2215644c Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Fri, 15 Aug 2025 15:32:55 +0200 Subject: [PATCH 05/60] feat: wired up PP within dependency graph --- src/modalities/config/pydantic_if_types.py | 2 ++ src/modalities/registry/components.py | 8 +++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/modalities/config/pydantic_if_types.py b/src/modalities/config/pydantic_if_types.py index aa12a444d..eb7d0bce1 100644 --- a/src/modalities/config/pydantic_if_types.py +++ b/src/modalities/config/pydantic_if_types.py @@ -21,6 +21,7 @@ from modalities.inference.text.inference_component import TextInferenceComponent from modalities.logging_broker.subscriber import MessageSubscriberIF from modalities.loss_functions import Loss +from modalities.models.parallelism.pipeline_parallelism import StagesGenerator from modalities.nn.model_initialization.initialization_if import ModelInitializationIF from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF @@ -83,3 +84,4 @@ def __get_pydantic_core_schema__( PydanticDatasetBatchGeneratorIFType = Annotated[ DatasetBatchGeneratorIF, PydanticThirdPartyTypeIF(DatasetBatchGeneratorIF) ] +PydanticStagesGeneratorType = Annotated[StagesGenerator, PydanticThirdPartyTypeIF(StagesGenerator)] diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 28afab4bb..e6da12819 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -86,6 +86,10 @@ from modalities.models.gpt2.gpt2_model import GPT2LLMConfig from modalities.models.huggingface.huggingface_model import HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig from modalities.models.model_factory import GPT2ModelFactory, ModelFactory +from modalities.models.parallelism.pipeline_parallelism import PipelineFactory +from modalities.models.parallelism.pipeline_parallelism_configs import PipelinedModelConfig +from modalities.models.parallelism.stages_generator import GPT2LLMStagesGenerator +from modalities.models.parallelism.stages_generator_configs import GPT2LLMStagesGeneratorConfig from modalities.nn.model_initialization.composed_initialization import ( ComposedInitializationRoutines, ComposedModelInitializationConfig, @@ -174,6 +178,9 @@ class ComponentEntity: ComponentEntity( "model", "debugging_enriched", ModelFactory.get_debugging_enriched_model, DebuggingEnrichedModelConfig ), + ComponentEntity("model", "pipelined", PipelineFactory.get_pipelined_model, PipelinedModelConfig), + # Pipeline Stages Generators + ComponentEntity("stages_generator", "gpt2_stages_generator", GPT2LLMStagesGenerator, GPT2LLMStagesGeneratorConfig), # Device mesh ComponentEntity("device_mesh", "default", get_device_mesh, DeviceMeshConfig), # weight initializers @@ -209,7 +216,6 @@ class ComponentEntity: # tokenizers ComponentEntity("tokenizer", "pretrained_hf_tokenizer", PreTrainedHFTokenizer, PreTrainedHFTokenizerConfig), ComponentEntity("tokenizer", "pretrained_sp_tokenizer", PreTrainedSPTokenizer, PreTrainedSPTokenizerConfig), - # ComponentEntity("tokenizer", "llama_tokenizer_fast", GPT2TokenizerFast, None), # TODO # datasets ComponentEntity("dataset", "mem_map_dataset", DatasetFactory.get_mem_map_dataset, MemMapDatasetConfig), ComponentEntity( From ee529b746da46c0eacba9b771c9a39db70152432 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Fri, 15 Aug 2025 15:33:34 +0200 Subject: [PATCH 06/60] feat: added FQN stages generator --- .../models/parallelism/stages_generator.py | 120 ++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 src/modalities/models/parallelism/stages_generator.py diff --git a/src/modalities/models/parallelism/stages_generator.py b/src/modalities/models/parallelism/stages_generator.py new file mode 100644 index 000000000..0a212672a --- /dev/null +++ b/src/modalities/models/parallelism/stages_generator.py @@ -0,0 +1,120 @@ +# Some portions of this implementation are inspired, adapted, or refactored +# from Meta's open-source project TorchTitan, +# licensed under the BSD 3-Clause License. + +import math +from abc import ABC, abstractmethod + + +class StagesGenerator(ABC): + def __init__(self, num_model_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1): + self._num_model_layers = num_model_layers + self._input_layer_equivalence = input_layer_equivalence + self._output_layer_equivalence = output_layer_equivalence + + def get_stages(self, num_layers_per_stage: int, pp_dims: int) -> list[list[str]]: + """ + Generate FQNs for each stage in a GPT-2 model. + + Args: + num_layers_per_stage (int): Number of layers per stage. + pp_dims (int): Number of pipeline parallel dimensions. + + Returns: + list[list[str]]: A list containing FQNs for each stage. + """ + + # calculate the number of stages + num_virtual_stages = math.ceil( + (self._num_model_layers + self._input_layer_equivalence + self._output_layer_equivalence) + / num_layers_per_stage + ) + if num_virtual_stages % pp_dims != 0: + raise ValueError( + f"Number of virtual stages {num_virtual_stages} is not divisible by parallel dimensions {pp_dims}. " + f"For reference: {self._num_model_layers=} {self._input_layer_equivalence=} " + f"{self._output_layer_equivalence=} {num_layers_per_stage=}" + ) + + stages_per_rank = num_virtual_stages // pp_dims + if stages_per_rank != 1: + raise ValueError( + f"Stages per rank {stages_per_rank} must be 1 for single-stage schedules. " + f"Please adjust {num_layers_per_stage=} to ensure each PP rank has exactly one stage." + ) + + # Potential split points for GPT-2 model with each potential split point + # listing the FQNs of the modules in that stage and the computational weight. + # The computational weight of the input and output modules are estimated + # based on the number of layers they correspond to. + potential_split_points = self._get_potential_split_points() + # Calculate the weight per stage based on the total weight and number of stages + weight_per_stage = math.ceil(sum(weight for _, weight in potential_split_points) / num_virtual_stages) + # pack the stages with the layers + next_split_point = 0 + module_names_per_stage: list[list[str]] = [] + for _ in range(num_virtual_stages): + stage_fqns = [] + stage_weight = 0 + while next_split_point < len(potential_split_points): + fqns, weight = potential_split_points[next_split_point] + if weight > weight_per_stage: + raise ValueError( + f"Weight of {weight} for {fqns} exceeds weight per stage {weight_per_stage}. " + "Please adjust the number of stages or the weight distribution." + ) + if stage_weight + weight > weight_per_stage: + break + stage_fqns.extend(fqns) + stage_weight += weight + next_split_point += 1 + module_names_per_stage.append(stage_fqns) + + return module_names_per_stage + + @abstractmethod + def _get_potential_split_points(self) -> list[tuple[list[str], int]]: + """ + Returns a list of potential split points for the GPT-2 model. + + Args: + num_model_layers (int): Total number of layers in the model. + input_layer_equivalence (int): Number of layers corresponding to the input layer. + output_layer_equivalence (int): Number of layers corresponding to the output layer. + + Returns: + list[tuple[list[str], int]]: A list containing tuples of FQNs and their computational weights. + """ + raise NotImplementedError("This method should be implemented by subclasses.") + + +class GPT2LLMStagesGenerator(StagesGenerator): + def __init__(self, num_model_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1): + super().__init__(num_model_layers, input_layer_equivalence, output_layer_equivalence) + + def _get_potential_split_points( + self, + ) -> list[tuple[list[str], int]]: + """ + Returns a list of potential split points for the GPT-2 model. + + Args: + num_model_layers (int): Total number of layers in the model. + input_layer_equivalence (int): Number of layers corresponding to the input layer. + output_layer_equivalence (int): Number of layers corresponding to the output layer. + + Returns: + list[tuple[list[str], int]]: A list containing tuples of FQNs and their computational weights. + """ + + # Potential split points for GPT-2 model with each potential split point + # listing the FQNs of the modules in that stage and the computational weight. + # The computational weight of the input and output modules are estimated + # based on the number of layers they correspond to. + potential_split_points = [ + (["transformer.wte", "transformer.wpe", "transformer.drop"], self._input_layer_equivalence), + *[([f"transformer.h.{i}"], 1) for i in range(self._num_model_layers)], + (["transformer.lm_head_norm", "transformer.lm_head"], self._output_layer_equivalence), + ] + + return potential_split_points From 625de592c02572db7626168d8504118909b768e1 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Mon, 18 Aug 2025 23:47:00 +0200 Subject: [PATCH 07/60] feat: implemented scheduled pipeline --- .../parallelism/pipeline_parallelism.py | 84 ++++++++++++++++++- 1 file changed, 80 insertions(+), 4 deletions(-) diff --git a/src/modalities/models/parallelism/pipeline_parallelism.py b/src/modalities/models/parallelism/pipeline_parallelism.py index 0af1551af..e9ac0c755 100644 --- a/src/modalities/models/parallelism/pipeline_parallelism.py +++ b/src/modalities/models/parallelism/pipeline_parallelism.py @@ -3,6 +3,7 @@ # licensed under the BSD 3-Clause License. import copy +from enum import Enum from typing import Any, Optional, Type import torch @@ -11,22 +12,72 @@ from torch.distributed.pipelining import PipelineStage from torch.distributed.pipelining.schedules import PipelineScheduleSingle, get_schedule_class +from modalities.loss_functions import Loss from modalities.models.parallelism.stages_generator import StagesGenerator from modalities.running_env.fsdp.device_mesh import ParallelismDegrees +from modalities.utils.logger_utils import get_logger + +logger = get_logger(__name__) + + +class Pipeline: + def __init__( + self, + stage: PipelineStage, + model: nn.Module, + schedule: Optional[PipelineScheduleSingle] = None, + ): + self._stage = stage + self._model = model + self._schedule = schedule + + @property + def is_first_stage(self) -> bool: + return self._stage.is_first + + @property + def is_last_stage(self) -> bool: + return self._stage.is_last + + @property.setter + def schedule(self, schedule: PipelineScheduleSingle): + self._schedule = schedule + + +class PipelineSelectionTypes(Enum): + """Enum for pipeline selection types.""" + + STAGE = "stage" + MODEL = "model" + SCHEDULE = "schedule" + + +class ComponentSelectorFromPipeline: + @staticmethod + def select(pipeline: Pipeline, selection_type: PipelineSelectionTypes) -> Any: + """Selects a component from the pipeline based on the selection type.""" + if selection_type == PipelineSelectionTypes.STAGE: + return pipeline._stage + elif selection_type == PipelineSelectionTypes.MODEL: + return pipeline._model + elif selection_type == PipelineSelectionTypes.SCHEDULE: + return pipeline._schedule + else: + raise ValueError(f"Unsupported selection type: {selection_type}") class PipelineFactory: """Pipeline factory class to create pipelined models.""" @staticmethod - def get_pipelined_model( + def get_staged_pipeline( whole_model: nn.Module, stages_generator: StagesGenerator, device_mesh: DeviceMesh, local_rank: int, pp_schedule_name: str, num_layers_per_stage: int, - ) -> torch.nn.Module: + ) -> Pipeline: device = torch.device("cuda", local_rank) pp_dims = device_mesh[ParallelismDegrees.PP.value].size() @@ -42,6 +93,10 @@ def get_pipelined_model( raise ValueError( f"Unsupported pipeline schedule: {pp_schedule_name}. We only support single-stage schedules." ) + # torchtitan returns tuple of stages and models as depending on the schedule + # we might have multiple stages and model parts per rank. + # So far we don't support multi-stage schedules, which is why instead of tuples + # we work directly with the stage and model. stage, model = PipelineFactory._get_split_model( whole_model=whole_model, schedule_class=schedule_class, @@ -49,7 +104,9 @@ def get_pipelined_model( device=device, fqns_per_stage=fqns_per_stage, ) - return whole_model # TODO return pipelined model + + pipeline = Pipeline(stage=stage, model=model) + return pipeline @staticmethod def _get_split_model( @@ -171,4 +228,23 @@ def _build_stage_from_modules( device=device, group=pp_mesh.get_group("pp"), ) - return stage, whole_model + return stage, stage_modules + + @staticmethod + def get_scheduled_pipeline( + loss_fn: Loss, pp_schedule_name: str, batch_size: int, microbatch_size: int, pp_degree: int, pipeline: Pipeline + ) -> Pipeline: + # TODO: Addd validation in config that batch_size is divisible by microbatch_size + n_microbatches = batch_size // microbatch_size + num_total_stages = pp_degree + schedule_class = get_schedule_class(pp_schedule_name) + schedule = schedule_class( + stage=pipeline.stage, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + ) + logger.info( + f"Using pipeline schedule {schedule} with {n_microbatches} microbatches and {num_total_stages} stages." + ) + pipeline.schedule = schedule + return pipeline From 9677bd6f09cb5372220c7cdac305f8f99769375a Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Mon, 18 Aug 2025 23:47:42 +0200 Subject: [PATCH 08/60] feat: wired up scheduled and staged pipelines. --- src/modalities/config/pydantic_if_types.py | 3 ++- .../pipeline_parallelism_configs.py | 18 +++++++++++++++++- src/modalities/registry/components.py | 12 +++++++++--- 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/src/modalities/config/pydantic_if_types.py b/src/modalities/config/pydantic_if_types.py index eb7d0bce1..c91ad4549 100644 --- a/src/modalities/config/pydantic_if_types.py +++ b/src/modalities/config/pydantic_if_types.py @@ -21,7 +21,7 @@ from modalities.inference.text.inference_component import TextInferenceComponent from modalities.logging_broker.subscriber import MessageSubscriberIF from modalities.loss_functions import Loss -from modalities.models.parallelism.pipeline_parallelism import StagesGenerator +from modalities.models.parallelism.pipeline_parallelism import Pipeline, StagesGenerator from modalities.nn.model_initialization.initialization_if import ModelInitializationIF from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF @@ -85,3 +85,4 @@ def __get_pydantic_core_schema__( DatasetBatchGeneratorIF, PydanticThirdPartyTypeIF(DatasetBatchGeneratorIF) ] PydanticStagesGeneratorType = Annotated[StagesGenerator, PydanticThirdPartyTypeIF(StagesGenerator)] +PydanticPipelineType = Annotated[Pipeline, PydanticThirdPartyTypeIF(Pipeline)] diff --git a/src/modalities/models/parallelism/pipeline_parallelism_configs.py b/src/modalities/models/parallelism/pipeline_parallelism_configs.py index 61b8b5ba4..e86cc46be 100644 --- a/src/modalities/models/parallelism/pipeline_parallelism_configs.py +++ b/src/modalities/models/parallelism/pipeline_parallelism_configs.py @@ -4,19 +4,35 @@ from modalities.config.pydantic_if_types import ( PydanticDeviceMeshIFType, + PydanticPipelineType, PydanticPytorchModuleType, PydanticStagesGeneratorType, ) +from modalities.models.parallelism.pipeline_parallelism import PipelineSelectionTypes class FQNsPerStageGeneratorConfig(BaseModel): pass -class PipelinedModelConfig(BaseModel): +class StagedPipelineConfig(BaseModel): whole_model: PydanticPytorchModuleType stages_generator: PydanticStagesGeneratorType device_mesh: PydanticDeviceMeshIFType local_rank: Annotated[int, Field(strict=True, ge=0)] pp_schedule_name: str num_layers_per_stage: Annotated[int, Field(strict=True, ge=1)] + + +class ScheduledPipelineConfig(BaseModel): + loss_fn: PydanticPytorchModuleType + pp_schedule_name: str + batch_size: Annotated[int, Field(strict=True, ge=1)] + microbatch_size: Annotated[int, Field(strict=True, ge=1)] + pp_degree: Annotated[int, Field(strict=True, ge=2)] + pipeline: PydanticPipelineType + + +class ComponentSelectorFromPipelineConfig(BaseModel): + pipeline: PydanticPipelineType + selection_type: PipelineSelectionTypes diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index e6da12819..44d9820c4 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -86,8 +86,12 @@ from modalities.models.gpt2.gpt2_model import GPT2LLMConfig from modalities.models.huggingface.huggingface_model import HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig from modalities.models.model_factory import GPT2ModelFactory, ModelFactory -from modalities.models.parallelism.pipeline_parallelism import PipelineFactory -from modalities.models.parallelism.pipeline_parallelism_configs import PipelinedModelConfig +from modalities.models.parallelism.pipeline_parallelism import ComponentSelectorFromPipeline, PipelineFactory +from modalities.models.parallelism.pipeline_parallelism_configs import ( + ComponentSelectorFromPipelineConfig, + ScheduledPipelineConfig, + StagedPipelineConfig, +) from modalities.models.parallelism.stages_generator import GPT2LLMStagesGenerator from modalities.models.parallelism.stages_generator_configs import GPT2LLMStagesGeneratorConfig from modalities.nn.model_initialization.composed_initialization import ( @@ -178,7 +182,9 @@ class ComponentEntity: ComponentEntity( "model", "debugging_enriched", ModelFactory.get_debugging_enriched_model, DebuggingEnrichedModelConfig ), - ComponentEntity("model", "pipelined", PipelineFactory.get_pipelined_model, PipelinedModelConfig), + ComponentEntity("pipeline", "staged", PipelineFactory.get_staged_pipeline, StagedPipelineConfig), + ComponentEntity("pipeline", "scheduled", PipelineFactory.get_scheduled_pipeline, ScheduledPipelineConfig), + ComponentEntity("pipeline", "selector", ComponentSelectorFromPipeline.select, ComponentSelectorFromPipelineConfig), # Pipeline Stages Generators ComponentEntity("stages_generator", "gpt2_stages_generator", GPT2LLMStagesGenerator, GPT2LLMStagesGeneratorConfig), # Device mesh From 7ac9edfd2578c3ab6c63ea29aed9057dfa22628b Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Mon, 18 Aug 2025 23:48:30 +0200 Subject: [PATCH 09/60] feat: added PP test config --- .../config_lorem_ipsum_long_fsdp2_pp.yaml | 395 ++++++++++++++++++ 1 file changed, 395 insertions(+) create mode 100644 config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml diff --git a/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml b/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml new file mode 100644 index 000000000..e5a3b61ce --- /dev/null +++ b/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml @@ -0,0 +1,395 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpoint_saving_path: data/checkpoints + train_dataset_path: ./data/lorem_ipsum_long.pbin + test_dataset_path: ./data/lorem_ipsum.pbin + intervals: + training_log_interval_in_steps: 1 + checkpointing_interval_in_steps: 32 + evaluation_interval_in_steps: 32 + consistency_enforcement: + enforce_tokens_per_step_consistency: true + enforce_last_step_logged: false + enforce_last_step_evaluated: false + enforce_last_step_checkpointed: false + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 1 + sequence_length: 256 + training_target: + num_target_tokens: + component_key: number_conversion + variant_key: num_tokens_from_packed_mem_map_dataset_continuous + config: + dataset_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + num_ranks: ${settings.cuda_env.world_size} + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + num_target_steps: # for the batch progress subscriber + component_key: number_conversion + variant_key: num_steps_from_num_tokens + config: + num_ranks: ${settings.cuda_env.world_size} + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + global_num_tokens: ${settings.training_target.num_target_tokens} + sequence_length: ${settings.step_profile.sequence_length} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + training_progress: + global_num_seen_tokens: 0 + num_seen_steps: 0 + num_seen_samples: 0 + last_step: -1 + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: train + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: resumable_distributed_sampler + config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + seed: 42 + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +test_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.test_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +test_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: test + dataset: + instance_key: test_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + drop_last: true + dataset: + instance_key: test_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: test_dataloader + pass_type: BY_REFERENCE + +checkpoint_saving: + component_key: checkpoint_saving + variant_key: default + config: + checkpoint_saving_strategy: + component_key: checkpoint_saving_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: dcp + config: + checkpoint_path: ${settings.paths.checkpoint_saving_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + pipeline_parallel_degree: 2 + data_parallel_shard_degree: -1 + world_size: ${settings.cuda_env.world_size} + +app_state: + component_key: app_state + variant_key: raw + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + lr_scheduler: + instance_key: lr_scheduler + pass_type: BY_REFERENCE + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: fsdp_model + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + +scheduled_pipeline: + component_key: pipeline + variant_key: scheduled + config: + loss_fn: + instance_key: loss_fn + pass_type: BY_REFERENCE + pp_schedule_name: gpipe + batch_size: ${settings.step_profile.local_train_micro_batch_size} + microbatch_size: 1 + pp_degree: ${device_mesh.config.pipeline_parallel_degree} + # maybe better to use the fsdp model and the schedule here + # instead of passing in the staged pipeline? + # If fsdp_model creates a copy then this is not in the scope of + # the staged pipeline. + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + + + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: model_part + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +model_part: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL + + +staged_pipeline: + component_key: model + variant_key: staged + config: + whole_model: + instance_key: model_raw + pass_type: BY_REFERENCE + stages_generator: + component_key: stages_generator + variant_key: gpt2_stages_generator + config: + num_model_layers: ${model_raw.config.n_layer} + input_layer_equivalence: 1 + output_layer_equivalence: 1 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + local_rank: ${settings.cuda_env.local_rank} + pp_schedule_name: gpipe + num_layers_per_stage: 2 + +model_raw: + component_key: model + variant_key: gpt2 + config: + use_meta_device: true + use_weight_tying: false + sample_key: ${settings.referencing_keys.sample_key} + poe_type: NOPE + sequence_length: ${settings.step_profile.sequence_length} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 2 + n_head_q: 8 + n_head_kv: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} #it has to be head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: manual + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + +lr_scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: ${settings.training_target.num_target_steps} + pct_start: 0.01 + anneal_strategy: cos + last_epoch: ${settings.training_progress.last_step} + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + weight_decay_groups_excluded: [embedding, layernorm] + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + +gradient_clipper: + component_key: gradient_clipper + variant_key: fsdp2 + config: + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + max_norm: 1.0 + +progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + global_rank: ${settings.cuda_env.global_rank} + num_seen_steps: ${settings.training_progress.num_seen_steps} + num_target_steps: ${settings.training_target.num_target_steps} + train_dataloader_tag: ${train_dataloader.config.dataloader_tag} + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + global_rank: ${settings.cuda_env.global_rank} + project: modalities_dcp_tests + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: wandb_storage + config_file_path: ${settings.config_file_path} + +# mfu_calculator: +# component_key: mfu_calculator +# variant_key: gpt2 +# config: +# n_layer: ${model_raw.config.n_layer} +# sequence_length: ${settings.step_profile.sequence_length} +# n_embd: ${model_raw.config.n_embd} +# world_size: ${settings.cuda_env.world_size} +# raw_model: +# instance_key: model_raw +# pass_type: BY_REFERENCE +# wrapped_model: +# instance_key: initialized_model +# pass_type: BY_REFERENCE \ No newline at end of file From d9f63c11d4f0f53a823fc60dc6f016169a185100 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Tue, 19 Aug 2025 14:39:09 +0200 Subject: [PATCH 10/60] refactor: staging is now fully instantiable --- .../config_lorem_ipsum_long_fsdp2_pp.yaml | 34 ++++++++++++++----- src/modalities/config/pydantic_if_types.py | 2 ++ .../parallelism/pipeline_parallelism.py | 27 ++++++++++++--- .../pipeline_parallelism_configs.py | 12 +++++-- .../parallelism/stages_generator_configs.py | 2 +- src/modalities/registry/components.py | 2 ++ 6 files changed, 63 insertions(+), 16 deletions(-) diff --git a/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml b/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml index e5a3b61ce..fa2343b93 100644 --- a/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml +++ b/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml @@ -24,7 +24,7 @@ settings: enforce_last_step_checkpointed: false step_profile: gradient_accumulation_steps: 1 - local_train_micro_batch_size: 1 + local_train_micro_batch_size: 2 sequence_length: 256 training_target: num_target_tokens: @@ -190,13 +190,19 @@ app_state: instance_key: lr_scheduler pass_type: BY_REFERENCE + initialized_model: component_key: model variant_key: model_initialized config: model: - instance_key: fsdp_model - pass_type: BY_REFERENCE + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: scheduled_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL model_initializer: component_key: model_initialization variant_key: composed @@ -223,11 +229,21 @@ scheduled_pipeline: # If fsdp_model creates a copy then this is not in the scope of # the staged pipeline. pipeline: - instance_key: staged_pipeline - pass_type: BY_REFERENCE - - - + component_key: pipeline + variant_key: builder + config: + stage: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: STAGE + model: + instance_key: fsdp_model + pass_type: BY_REFERENCE + fsdp_model: component_key: model variant_key: fsdp2_wrapped @@ -254,7 +270,7 @@ model_part: staged_pipeline: - component_key: model + component_key: pipeline variant_key: staged config: whole_model: diff --git a/src/modalities/config/pydantic_if_types.py b/src/modalities/config/pydantic_if_types.py index c91ad4549..2aeceb53c 100644 --- a/src/modalities/config/pydantic_if_types.py +++ b/src/modalities/config/pydantic_if_types.py @@ -7,6 +7,7 @@ from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import FSDPModule as FSDP2 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP1 +from torch.distributed.pipelining import PipelineStage from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler from torch.utils.data import Sampler @@ -86,3 +87,4 @@ def __get_pydantic_core_schema__( ] PydanticStagesGeneratorType = Annotated[StagesGenerator, PydanticThirdPartyTypeIF(StagesGenerator)] PydanticPipelineType = Annotated[Pipeline, PydanticThirdPartyTypeIF(Pipeline)] +PydanticPipelineStageType = Annotated[PipelineStage, PydanticThirdPartyTypeIF(PipelineStage)] diff --git a/src/modalities/models/parallelism/pipeline_parallelism.py b/src/modalities/models/parallelism/pipeline_parallelism.py index e9ac0c755..b842fd75c 100644 --- a/src/modalities/models/parallelism/pipeline_parallelism.py +++ b/src/modalities/models/parallelism/pipeline_parallelism.py @@ -39,7 +39,19 @@ def is_first_stage(self) -> bool: def is_last_stage(self) -> bool: return self._stage.is_last - @property.setter + @property + def stage(self) -> PipelineStage: + return self._stage + + @property + def model(self) -> nn.Module: + return self._model + + @property + def schedule(self) -> Optional[PipelineScheduleSingle]: + return self._schedule + + @schedule.setter def schedule(self, schedule: PipelineScheduleSingle): self._schedule = schedule @@ -47,9 +59,9 @@ def schedule(self, schedule: PipelineScheduleSingle): class PipelineSelectionTypes(Enum): """Enum for pipeline selection types.""" - STAGE = "stage" - MODEL = "model" - SCHEDULE = "schedule" + STAGE = "STAGE" + MODEL = "MODEL" + SCHEDULE = "SCHEDULE" class ComponentSelectorFromPipeline: @@ -69,6 +81,12 @@ def select(pipeline: Pipeline, selection_type: PipelineSelectionTypes) -> Any: class PipelineFactory: """Pipeline factory class to create pipelined models.""" + @staticmethod + def get_pipeline( + stage: PipelineStage, model: nn.Module, schedule: Optional[PipelineScheduleSingle] = None + ) -> Pipeline: + return Pipeline(stage=stage, model=model, schedule=schedule) + @staticmethod def get_staged_pipeline( whole_model: nn.Module, @@ -235,6 +253,7 @@ def get_scheduled_pipeline( loss_fn: Loss, pp_schedule_name: str, batch_size: int, microbatch_size: int, pp_degree: int, pipeline: Pipeline ) -> Pipeline: # TODO: Addd validation in config that batch_size is divisible by microbatch_size + # and n_microbatches must be >= pp_degree n_microbatches = batch_size // microbatch_size num_total_stages = pp_degree schedule_class = get_schedule_class(pp_schedule_name) diff --git a/src/modalities/models/parallelism/pipeline_parallelism_configs.py b/src/modalities/models/parallelism/pipeline_parallelism_configs.py index e86cc46be..c1aa23d48 100644 --- a/src/modalities/models/parallelism/pipeline_parallelism_configs.py +++ b/src/modalities/models/parallelism/pipeline_parallelism_configs.py @@ -4,6 +4,8 @@ from modalities.config.pydantic_if_types import ( PydanticDeviceMeshIFType, + PydanticLossIFType, + PydanticPipelineStageType, PydanticPipelineType, PydanticPytorchModuleType, PydanticStagesGeneratorType, @@ -11,7 +13,7 @@ from modalities.models.parallelism.pipeline_parallelism import PipelineSelectionTypes -class FQNsPerStageGeneratorConfig(BaseModel): +class FQNsPerStageGeneratorConfig(BaseModel): # TODO duplicate pass @@ -25,7 +27,7 @@ class StagedPipelineConfig(BaseModel): class ScheduledPipelineConfig(BaseModel): - loss_fn: PydanticPytorchModuleType + loss_fn: PydanticLossIFType pp_schedule_name: str batch_size: Annotated[int, Field(strict=True, ge=1)] microbatch_size: Annotated[int, Field(strict=True, ge=1)] @@ -36,3 +38,9 @@ class ScheduledPipelineConfig(BaseModel): class ComponentSelectorFromPipelineConfig(BaseModel): pipeline: PydanticPipelineType selection_type: PipelineSelectionTypes + + +class PipelineConfig(BaseModel): + stage: PydanticPipelineStageType + model: PydanticPytorchModuleType + schedule: PydanticPipelineType | None = None diff --git a/src/modalities/models/parallelism/stages_generator_configs.py b/src/modalities/models/parallelism/stages_generator_configs.py index 610be7fdd..5d53f091d 100644 --- a/src/modalities/models/parallelism/stages_generator_configs.py +++ b/src/modalities/models/parallelism/stages_generator_configs.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, Field -class FQNsPerStageGeneratorConfig(BaseModel): +class FQNsPerStageGeneratorConfig(BaseModel): # TODO duplicate pass diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 44d9820c4..167a29894 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -89,6 +89,7 @@ from modalities.models.parallelism.pipeline_parallelism import ComponentSelectorFromPipeline, PipelineFactory from modalities.models.parallelism.pipeline_parallelism_configs import ( ComponentSelectorFromPipelineConfig, + PipelineConfig, ScheduledPipelineConfig, StagedPipelineConfig, ) @@ -185,6 +186,7 @@ class ComponentEntity: ComponentEntity("pipeline", "staged", PipelineFactory.get_staged_pipeline, StagedPipelineConfig), ComponentEntity("pipeline", "scheduled", PipelineFactory.get_scheduled_pipeline, ScheduledPipelineConfig), ComponentEntity("pipeline", "selector", ComponentSelectorFromPipeline.select, ComponentSelectorFromPipelineConfig), + ComponentEntity("pipeline", "builder", PipelineFactory.get_pipeline, PipelineConfig), # Pipeline Stages Generators ComponentEntity("stages_generator", "gpt2_stages_generator", GPT2LLMStagesGenerator, GPT2LLMStagesGeneratorConfig), # Device mesh From 83c87b9d6d6fbbb228bab31dccf1870b12679775 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Tue, 19 Aug 2025 14:39:58 +0200 Subject: [PATCH 11/60] feat: drafted pp e2e test for fwd/bwd pass --- .../pipeline_parallelism/__init__.py | 0 ...orem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml | 171 ++++++++++++++++++ .../test_pp_fwd_bwd_pass.py | 104 +++++++++++ 3 files changed, 275 insertions(+) create mode 100644 tests/fsdp2_parallelization/pipeline_parallelism/__init__.py create mode 100644 tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml create mode 100644 tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/__init__.py b/tests/fsdp2_parallelization/pipeline_parallelism/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml new file mode 100644 index 000000000..88182d266 --- /dev/null +++ b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml @@ -0,0 +1,171 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 2 + sequence_length: 256 + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + pipeline_parallel_degree: 2 + data_parallel_shard_degree: -1 + world_size: ${settings.cuda_env.world_size} + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: scheduled_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + +scheduled_pipeline: + component_key: pipeline + variant_key: scheduled + config: + loss_fn: + instance_key: loss_fn + pass_type: BY_REFERENCE + pp_schedule_name: gpipe + batch_size: ${settings.step_profile.local_train_micro_batch_size} + microbatch_size: 1 + pp_degree: ${device_mesh.config.pipeline_parallel_degree} + pipeline: + component_key: pipeline + variant_key: builder + config: + stage: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: STAGE + model: + instance_key: fsdp_model + pass_type: BY_REFERENCE + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: model_part + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +model_part: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL + +staged_pipeline: + component_key: pipeline + variant_key: staged + config: + whole_model: + instance_key: model_raw + pass_type: BY_REFERENCE + stages_generator: + component_key: stages_generator + variant_key: gpt2_stages_generator + config: + num_model_layers: ${model_raw.config.n_layer} + input_layer_equivalence: 1 + output_layer_equivalence: 1 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + local_rank: ${settings.cuda_env.local_rank} + pp_schedule_name: gpipe + num_layers_per_stage: 2 + +model_raw: + component_key: model + variant_key: gpt2 + config: + use_meta_device: true + use_weight_tying: false + sample_key: ${settings.referencing_keys.sample_key} + poe_type: NOPE + sequence_length: ${settings.step_profile.sequence_length} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 2 + n_head_q: 8 + n_head_kv: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} #it has to be head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: manual + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py new file mode 100644 index 000000000..fc24223e9 --- /dev/null +++ b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py @@ -0,0 +1,104 @@ +import os +import tempfile +from pathlib import Path + +import pytest +import torch +import torch.multiprocessing as mp +import yaml +from pydantic import BaseModel + +from modalities.__main__ import Main +from modalities.config.config import ProcessGroupBackendType +from modalities.config.pydantic_if_types import PydanticFSDP2ModuleType, PydanticPipelineType +from tests.end2end_tests.custom_components import MultiProcessingCudaEnv + + +@pytest.fixture +def temp_file_path() -> Path: + # Create a NamedTemporaryFile that persists after closing (delete=False) + with tempfile.NamedTemporaryFile(delete=False) as tf: + file_path = tf.name + try: + yield Path(file_path) + finally: + # Clean up the file after the test + if os.path.exists(file_path): + os.remove(file_path) + + +class ComponentsInstantiationModel(BaseModel): + initialized_model: PydanticFSDP2ModuleType + scheduled_pipeline: PydanticPipelineType + + +@pytest.mark.skipif( + torch.cuda.device_count() < 8, + reason="This test requires 8 GPUs", +) +class TestPipelineParallelism: + def _get_tmp_sharding_config_path( + self, sharding_degree: int, tp_degree: int, pp_degree: int, temp_file_path: Path + ) -> Path: + working_dir = Path(os.path.dirname(__file__)) + config_file_path = working_dir / "configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml" + + with open(config_file_path, "r") as file: + config_string = file.read() + config_dict = yaml.safe_load(config_string) + config_dict["device_mesh"]["config"]["data_parallel_shard_degree"] = sharding_degree + config_dict["device_mesh"]["config"]["tensor_parallel_degree"] = tp_degree + config_dict["device_mesh"]["config"]["pipeline_parallel_degree"] = pp_degree + + # save to temporary file + with open(temp_file_path, "w") as file: + yaml.dump(config_dict, file) + + return temp_file_path + + def _get_components(self, config_file_path: Path) -> ComponentsInstantiationModel: + main_obj = Main(config_file_path) + components: ComponentsInstantiationModel = main_obj.build_components( + components_model_type=ComponentsInstantiationModel + ) + return components + + @pytest.mark.parametrize( + "sharding_degree, tp_degree, pp_degree, world_size", + [ + (2, 1, 2, 4), + # (2, 1, 4, 8), + # (2, 2, 2, 8), # TODO need to support this case + ], + ) + def test_pp(self, sharding_degree: int, tp_degree: int, pp_degree: int, world_size: int, temp_file_path: Path): + tmp_sharding_config_path = self._get_tmp_sharding_config_path( + sharding_degree=sharding_degree, + tp_degree=tp_degree, + pp_degree=pp_degree, + temp_file_path=temp_file_path, + ) + mp.spawn( + self._test_pp_impl, + args=(world_size, sharding_degree, tmp_sharding_config_path), + nprocs=world_size, + join=True, + ) + + def _test_pp_impl( + self, + process_id: int, + world_size: int, + sharding_degree: int, + gpt2_model_config_path: Path, + ): + # wraps the actual test function to be able to run it in a distributed multiprocessing setup + with MultiProcessingCudaEnv( + process_group_backend=ProcessGroupBackendType.nccl, + global_rank=process_id, + local_rank=process_id, + world_size=world_size, + rdvz_port=22356, + ): + self._get_components(gpt2_model_config_path) + pass From 95f24701fc9940e565893668e6d07cd6dc93b3ca Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Fri, 29 Aug 2025 09:55:12 +0200 Subject: [PATCH 12/60] refactor: renamings in the context of PP --- .../parallelism/pipeline_parallelism.py | 74 +++++++++---------- .../pipeline_parallelism_configs.py | 6 +- 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/src/modalities/models/parallelism/pipeline_parallelism.py b/src/modalities/models/parallelism/pipeline_parallelism.py index b842fd75c..006d97a55 100644 --- a/src/modalities/models/parallelism/pipeline_parallelism.py +++ b/src/modalities/models/parallelism/pipeline_parallelism.py @@ -23,57 +23,57 @@ class Pipeline: def __init__( self, - stage: PipelineStage, - model: nn.Module, - schedule: Optional[PipelineScheduleSingle] = None, + pp_stage: PipelineStage, + model_part: nn.Module, + pp_schedule: Optional[PipelineScheduleSingle] = None, ): - self._stage = stage - self._model = model - self._schedule = schedule + self._pp_stage = pp_stage + self._model_part = model_part + self._pp_schedule = pp_schedule @property - def is_first_stage(self) -> bool: - return self._stage.is_first + def is_first_pp_stage(self) -> bool: + return self._pp_stage.is_first @property - def is_last_stage(self) -> bool: - return self._stage.is_last + def is_last_pp_stage(self) -> bool: + return self._pp_stage.is_last @property - def stage(self) -> PipelineStage: - return self._stage + def pp_stage(self) -> PipelineStage: + return self._pp_stage @property - def model(self) -> nn.Module: - return self._model + def model_part(self) -> nn.Module: + return self._model_part @property - def schedule(self) -> Optional[PipelineScheduleSingle]: - return self._schedule + def pp_schedule(self) -> Optional[PipelineScheduleSingle]: + return self._pp_schedule - @schedule.setter - def schedule(self, schedule: PipelineScheduleSingle): - self._schedule = schedule + @pp_schedule.setter + def pp_schedule(self, schedule: PipelineScheduleSingle): + self._pp_schedule = schedule class PipelineSelectionTypes(Enum): """Enum for pipeline selection types.""" - STAGE = "STAGE" - MODEL = "MODEL" - SCHEDULE = "SCHEDULE" + PP_STAGE = "PP_STAGE" + MODEL_PART = "MODEL_PART" + PP_SCHEDULE = "PP_SCHEDULE" class ComponentSelectorFromPipeline: @staticmethod def select(pipeline: Pipeline, selection_type: PipelineSelectionTypes) -> Any: """Selects a component from the pipeline based on the selection type.""" - if selection_type == PipelineSelectionTypes.STAGE: - return pipeline._stage - elif selection_type == PipelineSelectionTypes.MODEL: - return pipeline._model - elif selection_type == PipelineSelectionTypes.SCHEDULE: - return pipeline._schedule + if selection_type == PipelineSelectionTypes.PP_STAGE: + return pipeline.pp_stage + elif selection_type == PipelineSelectionTypes.MODEL_PART: + return pipeline.model_part + elif selection_type == PipelineSelectionTypes.PP_SCHEDULE: + return pipeline.pp_schedule else: raise ValueError(f"Unsupported selection type: {selection_type}") @@ -83,9 +83,9 @@ class PipelineFactory: @staticmethod def get_pipeline( - stage: PipelineStage, model: nn.Module, schedule: Optional[PipelineScheduleSingle] = None + pp_stage: PipelineStage, model_part: nn.Module, pp_schedule: Optional[PipelineScheduleSingle] = None ) -> Pipeline: - return Pipeline(stage=stage, model=model, schedule=schedule) + return Pipeline(pp_stage=pp_stage, model_part=model_part, pp_schedule=pp_schedule) @staticmethod def get_staged_pipeline( @@ -115,7 +115,7 @@ def get_staged_pipeline( # we might have multiple stages and model parts per rank. # So far we don't support multi-stage schedules, which is why instead of tuples # we work directly with the stage and model. - stage, model = PipelineFactory._get_split_model( + pp_stage, model_part = PipelineFactory._get_split_model( whole_model=whole_model, schedule_class=schedule_class, pp_mesh=pp_mesh, @@ -123,7 +123,7 @@ def get_staged_pipeline( fqns_per_stage=fqns_per_stage, ) - pipeline = Pipeline(stage=stage, model=model) + pipeline = Pipeline(pp_stage=pp_stage, model_part=model_part) return pipeline @staticmethod @@ -256,14 +256,14 @@ def get_scheduled_pipeline( # and n_microbatches must be >= pp_degree n_microbatches = batch_size // microbatch_size num_total_stages = pp_degree - schedule_class = get_schedule_class(pp_schedule_name) - schedule = schedule_class( - stage=pipeline.stage, + pp_schedule_class = get_schedule_class(pp_schedule_name) + pp_schedule = pp_schedule_class( + stage=pipeline.pp_stage, n_microbatches=n_microbatches, loss_fn=loss_fn, ) logger.info( - f"Using pipeline schedule {schedule} with {n_microbatches} microbatches and {num_total_stages} stages." + f"Using pipeline schedule {pp_schedule} with {n_microbatches} microbatches and {num_total_stages} stages." ) - pipeline.schedule = schedule + pipeline.pp_schedule = pp_schedule return pipeline diff --git a/src/modalities/models/parallelism/pipeline_parallelism_configs.py b/src/modalities/models/parallelism/pipeline_parallelism_configs.py index c1aa23d48..831a6e15e 100644 --- a/src/modalities/models/parallelism/pipeline_parallelism_configs.py +++ b/src/modalities/models/parallelism/pipeline_parallelism_configs.py @@ -41,6 +41,6 @@ class ComponentSelectorFromPipelineConfig(BaseModel): class PipelineConfig(BaseModel): - stage: PydanticPipelineStageType - model: PydanticPytorchModuleType - schedule: PydanticPipelineType | None = None + pp_stage: PydanticPipelineStageType + model_part: PydanticPytorchModuleType + pp_schedule: PydanticPipelineType | None = None From 521e5867559c984c71ab98b12d58a349c66d69cd Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Fri, 29 Aug 2025 09:56:39 +0200 Subject: [PATCH 13/60] chore: drafted the first PP test. --- ...orem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml | 18 +++++----- .../test_pp_fwd_bwd_pass.py | 34 ++++++++++++++++--- 2 files changed, 39 insertions(+), 13 deletions(-) diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml index 88182d266..0ceb02a53 100644 --- a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml +++ b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml @@ -11,7 +11,7 @@ settings: world_size: ${cuda_env:WORLD_SIZE} step_profile: gradient_accumulation_steps: 1 - local_train_micro_batch_size: 2 + local_train_micro_batch_size: 4 sequence_length: 256 loss_fn: @@ -42,7 +42,7 @@ initialized_model: pipeline: instance_key: scheduled_pipeline pass_type: BY_REFERENCE - selection_type: MODEL + selection_type: MODEL_PART model_initializer: component_key: model_initialization variant_key: composed @@ -62,21 +62,21 @@ scheduled_pipeline: pass_type: BY_REFERENCE pp_schedule_name: gpipe batch_size: ${settings.step_profile.local_train_micro_batch_size} - microbatch_size: 1 + microbatch_size: 2 pp_degree: ${device_mesh.config.pipeline_parallel_degree} pipeline: component_key: pipeline variant_key: builder config: - stage: + pp_stage: component_key: pipeline variant_key: selector config: pipeline: instance_key: staged_pipeline pass_type: BY_REFERENCE - selection_type: STAGE - model: + selection_type: PP_STAGE + model_part: instance_key: fsdp_model pass_type: BY_REFERENCE @@ -102,7 +102,7 @@ model_part: pipeline: instance_key: staged_pipeline pass_type: BY_REFERENCE - selection_type: MODEL + selection_type: MODEL_PART staged_pipeline: component_key: pipeline @@ -123,7 +123,7 @@ staged_pipeline: pass_type: BY_REFERENCE local_rank: ${settings.cuda_env.local_rank} pp_schedule_name: gpipe - num_layers_per_stage: 2 + num_layers_per_stage: 4 model_raw: component_key: model @@ -136,7 +136,7 @@ model_raw: sequence_length: ${settings.step_profile.sequence_length} prediction_key: ${loss_fn.config.prediction_key} vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency - n_layer: 2 + n_layer: 6 n_head_q: 8 n_head_kv: 4 ffn_hidden: 128 diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py index fc24223e9..6f861c1ea 100644 --- a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py +++ b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py @@ -11,6 +11,7 @@ from modalities.__main__ import Main from modalities.config.config import ProcessGroupBackendType from modalities.config.pydantic_if_types import PydanticFSDP2ModuleType, PydanticPipelineType +from modalities.models.parallelism.pipeline_parallelism import Pipeline from tests.end2end_tests.custom_components import MultiProcessingCudaEnv @@ -80,7 +81,7 @@ def test_pp(self, sharding_degree: int, tp_degree: int, pp_degree: int, world_si ) mp.spawn( self._test_pp_impl, - args=(world_size, sharding_degree, tmp_sharding_config_path), + args=(world_size, tmp_sharding_config_path), nprocs=world_size, join=True, ) @@ -89,7 +90,6 @@ def _test_pp_impl( self, process_id: int, world_size: int, - sharding_degree: int, gpt2_model_config_path: Path, ): # wraps the actual test function to be able to run it in a distributed multiprocessing setup @@ -100,5 +100,31 @@ def _test_pp_impl( world_size=world_size, rdvz_port=22356, ): - self._get_components(gpt2_model_config_path) - pass + components = self._get_components(gpt2_model_config_path) + scheduled_pipeline = components.scheduled_pipeline + vocab_size = 50304 + sequence_length = 256 + batch_size = 4 + sequences = torch.randint(0, vocab_size, (batch_size, sequence_length)) + targets = sequences[:, 1:].contiguous() + inputs = sequences[:, :-1].contiguous() + self._forward_step(scheduled_pipeline, inputs, targets) + + def _forward_step(self, scheduled_pipeline: Pipeline, inputs: torch.Tensor, targets: torch.Tensor): + """Runs a forward step on the model.""" + pp_schedule = scheduled_pipeline.pp_schedule + targets, losses = (targets, []) if scheduled_pipeline.is_last_pp_stage else (None, None) + if scheduled_pipeline.is_first_pp_stage: # first stage + pp_schedule.step(inputs, target=targets, losses=losses, input_batch=inputs) + else: # non-first stage + pp_schedule.step(target=targets, losses=losses, input_batch=inputs) + + # accumulate losses across pipeline microbatches + # TODO: PP+FSDP unexpectedly puts the loss back to the CPU + ( + torch.mean(torch.stack(losses)).to(self.device) + if self.pp_has_last_stage + else torch.tensor([-1.0], device=self.device) + ) + + # return output From 002b0ae557411351dc274be97f7e0e6c59c0afd8 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Sat, 30 Aug 2025 00:46:18 +0200 Subject: [PATCH 14/60] chore: pp config fixes --- .../training/config_lorem_ipsum_long_fsdp2_pp.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml b/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml index fa2343b93..381550a20 100644 --- a/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml +++ b/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml @@ -202,7 +202,7 @@ initialized_model: pipeline: instance_key: scheduled_pipeline pass_type: BY_REFERENCE - selection_type: MODEL + selection_type: MODEL_PART model_initializer: component_key: model_initialization variant_key: composed @@ -232,15 +232,15 @@ scheduled_pipeline: component_key: pipeline variant_key: builder config: - stage: + pp_stage: component_key: pipeline variant_key: selector config: pipeline: instance_key: staged_pipeline pass_type: BY_REFERENCE - selection_type: STAGE - model: + selection_type: PP_STAGE + model_part: instance_key: fsdp_model pass_type: BY_REFERENCE @@ -266,7 +266,7 @@ model_part: pipeline: instance_key: staged_pipeline pass_type: BY_REFERENCE - selection_type: MODEL + selection_type: MODEL_PART staged_pipeline: From 1d4943f5c065af94164059739dca15f0c2f72049 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 5 Sep 2025 18:14:47 +0200 Subject: [PATCH 15/60] feat: Make test for pipeline parallelism work --- src/modalities/loss_functions.py | 8 ++ src/modalities/models/gpt2/gpt2_model.py | 55 +++++++++ src/modalities/models/model_factory.py | 7 +- src/modalities/registry/components.py | 3 +- ...g_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml | 108 ++++++++++++++++++ ...orem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml | 3 +- .../test_pp_fwd_bwd_pass.py | 58 +++++++--- 7 files changed, 221 insertions(+), 21 deletions(-) create mode 100644 tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml diff --git a/src/modalities/loss_functions.py b/src/modalities/loss_functions.py index 54d8de36b..f46fb0398 100644 --- a/src/modalities/loss_functions.py +++ b/src/modalities/loss_functions.py @@ -44,6 +44,14 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: return loss +class CLMCrossEntropyLossPP(CLMCrossEntropyLoss): + def __call__(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + forward_batch = InferenceResultBatch( + predictions={self.prediction_key: outputs}, targets={self.target_key: targets} + ) + return super().__call__(forward_batch) + + def nce_loss( embedding1: torch.Tensor, embedding2: torch.Tensor, device: torch.device, is_asymmetric: bool, temperature: float ) -> torch.Tensor: diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index a2022d716..76b0399ae 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -344,6 +344,7 @@ class GPT2LLMConfig(BaseModel): ffn_norm_config: LayerNormWrapperConfig lm_head_norm_config: LayerNormWrapperConfig use_weight_tying: bool + use_pp: Optional[bool] = False @model_validator(mode="after") def check_divisibility(self) -> "GPT2LLMConfig": @@ -930,6 +931,60 @@ def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: return self.forward_impl(inputs) +class GPT2LLMPP(GPT2LLM): + """GPT2LLM class.""" + + def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor: + """ + Forward pass implementation of the GPT2LLM module. + + Args: + inputs (dict[str, torch.Tensor]): A dictionary containing input tensors. + - sample_key (str): Key for the input tensor containing token ids. + + Returns: + dict[str, torch.Tensor]: A dictionary containing output tensors. + - prediction_key (str): Key for the output tensor containing logits. + """ + device = inputs.device + t = inputs.size(1) # batch size, sequence length + assert t <= self.sequence_length, f"Cannot forward sequence of length {t}, the model's maximum " + f"input sequence length is only {self.sequence_length}" + + # forward the GPT model itself + h = ( + self.transformer.wte(inputs) if hasattr(self.transformer, "wte") else inputs + ) # token embeddings of shape (b, t, n_embd) + + if self.poe_type is PositionTypes.ABSOLUTE and hasattr(self.transformer, "wpe"): + pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t) + pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) + h = h + pos_emb + + # TODO: use drop out also without absolute position embedding? + h = self.transformer.drop(h) if hasattr(self.transformer, "drop") else h + + for block in self.transformer.h: + h = block(h) + h = self.transformer.lm_head_norm(h) if hasattr(self.transformer, "lm_head_norm") else h + h = self.transformer.lm_head(h) if hasattr(self.transformer, "lm_head") else h + return h + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the GPT2LLM module. + + Args: + inputs (dict[str, torch.Tensor]): A dictionary containing input tensors. + - sample_key (str): Key for the input tensor containing token ids. + + Returns: + dict[str, torch.Tensor]: A dictionary containing output tensors. + - prediction_key (str): Key for the output tensor containing logits. + """ + return self.forward_impl(inputs) + + def manual_scaled_dot_product_attention( query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None ) -> torch.Tensor: diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py index 877c9cbdc..dc3d084a8 100644 --- a/src/modalities/models/model_factory.py +++ b/src/modalities/models/model_factory.py @@ -28,6 +28,7 @@ from modalities.exceptions import ModelStateError from modalities.models.gpt2.gpt2_model import ( GPT2LLM, + GPT2LLMPP, AttentionConfig, AttentionImplementation, LayerNormWrapperConfig, @@ -568,6 +569,7 @@ def get_gpt2_model( use_weight_tying: bool, use_meta_device: Optional[bool] = False, seed: int = None, + use_pp: Optional[bool] = False, ) -> GPT2LLM: config = dict( sample_key=sample_key, @@ -597,11 +599,12 @@ def get_gpt2_model( "Please set at least use_meta_device=False or use_weight_tying=False." "https://github.com/Modalities/modalities/issues/357" ) + gpt2_model_class = GPT2LLMPP if use_pp else GPT2LLM if use_meta_device: with torch.device("meta"): - model = GPT2LLM(**config) + model = gpt2_model_class(**config) else: - model = GPT2LLM(**config) + model = gpt2_model_class(**config) return model @staticmethod diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 167a29894..b3a0a7618 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -78,7 +78,7 @@ ProgressSubscriberFactory, ResultsSubscriberFactory, ) -from modalities.loss_functions import CLMCrossEntropyLoss +from modalities.loss_functions import CLMCrossEntropyLoss, CLMCrossEntropyLossPP from modalities.models.coca.coca_model import CoCa, CoCaConfig from modalities.models.coca.collator import CoCaCollateFnConfig, CoCaCollatorFn from modalities.models.components.layer_norms import LayerNormConfig, RMSLayerNorm, RMSLayerNormConfig @@ -200,6 +200,7 @@ class ComponentEntity: ), # losses ComponentEntity("loss", "clm_cross_entropy_loss", CLMCrossEntropyLoss, CLMCrossEntropyLossConfig), + ComponentEntity("loss", "clm_cross_entropy_loss_pp", CLMCrossEntropyLossPP, CLMCrossEntropyLossConfig), # optmizers ComponentEntity("optimizer", "adam", OptimizerFactory.get_adam, AdamOptimizerConfig), ComponentEntity("optimizer", "adam_w", OptimizerFactory.get_adam_w, AdamWOptimizerConfig), diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml new file mode 100644 index 000000000..6603b1850 --- /dev/null +++ b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml @@ -0,0 +1,108 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 4 + sequence_length: 256 + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: -1 + world_size: ${settings.cuda_env.world_size} + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: fsdp_model + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: model_raw + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +model_raw: + component_key: model + variant_key: gpt2 + config: + use_meta_device: true + use_weight_tying: false + sample_key: ${settings.referencing_keys.sample_key} + poe_type: NOPE + sequence_length: ${settings.step_profile.sequence_length} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 6 + n_head_q: 8 + n_head_kv: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} #it has to be head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: manual + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml index 0ceb02a53..5ef5f148a 100644 --- a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml +++ b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml @@ -16,7 +16,7 @@ settings: loss_fn: component_key: loss - variant_key: clm_cross_entropy_loss + variant_key: clm_cross_entropy_loss_pp config: target_key: ${settings.referencing_keys.target_key} prediction_key: ${settings.referencing_keys.prediction_key} @@ -129,6 +129,7 @@ model_raw: component_key: model variant_key: gpt2 config: + use_pp: true use_meta_device: true use_weight_tying: false sample_key: ${settings.referencing_keys.sample_key} diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py index 6f861c1ea..7384b8338 100644 --- a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py +++ b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py @@ -9,8 +9,9 @@ from pydantic import BaseModel from modalities.__main__ import Main +from modalities.batch import InferenceResultBatch from modalities.config.config import ProcessGroupBackendType -from modalities.config.pydantic_if_types import PydanticFSDP2ModuleType, PydanticPipelineType +from modalities.config.pydantic_if_types import PydanticFSDP2ModuleType, PydanticLossIFType, PydanticPipelineType from modalities.models.parallelism.pipeline_parallelism import Pipeline from tests.end2end_tests.custom_components import MultiProcessingCudaEnv @@ -28,11 +29,16 @@ def temp_file_path() -> Path: os.remove(file_path) -class ComponentsInstantiationModel(BaseModel): +class ComponentsInstantiationPPModel(BaseModel): initialized_model: PydanticFSDP2ModuleType scheduled_pipeline: PydanticPipelineType +class ComponentsInstantiationModel(BaseModel): + initialized_model: PydanticFSDP2ModuleType + loss_fn: PydanticLossIFType + + @pytest.mark.skipif( torch.cuda.device_count() < 8, reason="This test requires 8 GPUs", @@ -57,11 +63,14 @@ def _get_tmp_sharding_config_path( return temp_file_path - def _get_components(self, config_file_path: Path) -> ComponentsInstantiationModel: + def _get_components(self, config_file_path: Path, use_pp: bool) -> ComponentsInstantiationPPModel: + torch.manual_seed(42) main_obj = Main(config_file_path) - components: ComponentsInstantiationModel = main_obj.build_components( - components_model_type=ComponentsInstantiationModel - ) + if use_pp: + components_model_type = ComponentsInstantiationPPModel + else: + components_model_type = ComponentsInstantiationModel + components: components_model_type = main_obj.build_components(components_model_type=components_model_type) return components @pytest.mark.parametrize( @@ -90,7 +99,7 @@ def _test_pp_impl( self, process_id: int, world_size: int, - gpt2_model_config_path: Path, + pp_model_config_path: Path, ): # wraps the actual test function to be able to run it in a distributed multiprocessing setup with MultiProcessingCudaEnv( @@ -100,7 +109,7 @@ def _test_pp_impl( world_size=world_size, rdvz_port=22356, ): - components = self._get_components(gpt2_model_config_path) + components = self._get_components(pp_model_config_path, use_pp=True) scheduled_pipeline = components.scheduled_pipeline vocab_size = 50304 sequence_length = 256 @@ -108,23 +117,38 @@ def _test_pp_impl( sequences = torch.randint(0, vocab_size, (batch_size, sequence_length)) targets = sequences[:, 1:].contiguous() inputs = sequences[:, :-1].contiguous() - self._forward_step(scheduled_pipeline, inputs, targets) + loss_pp = self._forward_step(scheduled_pipeline, inputs, targets) + + # if scheduled_pipeline.is_last_pp_stage: + working_dir = Path(os.path.dirname(__file__)) + fsdp2_model_config_path = working_dir / "configs/config_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml" + fsdp2_components = self._get_components(fsdp2_model_config_path, use_pp=False) + fsdp2_model = fsdp2_components.initialized_model + fsdp2_loss_fn = fsdp2_components.loss_fn + + input_dict = {"input_ids": inputs} + fsdp2_out = fsdp2_model(input_dict) + forward_batch = InferenceResultBatch(predictions=fsdp2_out, targets={fsdp2_loss_fn.target_key: targets}) + fsdp2_loss = fsdp2_loss_fn(forward_batch) + if scheduled_pipeline.is_last_pp_stage: + assert torch.allclose(fsdp2_loss, loss_pp, atol=1e-6, rtol=1e-5), "Outputs do not match" def _forward_step(self, scheduled_pipeline: Pipeline, inputs: torch.Tensor, targets: torch.Tensor): """Runs a forward step on the model.""" pp_schedule = scheduled_pipeline.pp_schedule targets, losses = (targets, []) if scheduled_pipeline.is_last_pp_stage else (None, None) if scheduled_pipeline.is_first_pp_stage: # first stage - pp_schedule.step(inputs, target=targets, losses=losses, input_batch=inputs) + # pp_schedule.step(inputs, target=targets, losses=losses, input_batch=inputs) + pp_schedule.step(inputs, target=targets, losses=losses) else: # non-first stage - pp_schedule.step(target=targets, losses=losses, input_batch=inputs) + # pp_schedule.step(target=targets, losses=losses, input_batch=inputs) + # pp_schedule.step(inputs, target=targets, losses=losses, input_batch=inputs) + pp_schedule.step(target=targets, losses=losses) # accumulate losses across pipeline microbatches # TODO: PP+FSDP unexpectedly puts the loss back to the CPU - ( - torch.mean(torch.stack(losses)).to(self.device) - if self.pp_has_last_stage - else torch.tensor([-1.0], device=self.device) + return ( + torch.mean(torch.stack(losses)).to(losses[0].device) + if scheduled_pipeline.is_last_pp_stage + else torch.tensor([-1.0], device=inputs.device) ) - - # return output From 5b53ff97b3df6b780fbf747dfa6293359d9c7f43 Mon Sep 17 00:00:00 2001 From: Timm Ruland Date: Mon, 8 Sep 2025 12:29:55 +0200 Subject: [PATCH 16/60] refactor(parallelism): Removed necessity of additional model and loss classes for pipeline parallelism. --- src/modalities/loss_functions.py | 53 ++++++++++-- src/modalities/models/gpt2/gpt2_model.py | 82 ++++++------------- src/modalities/models/model_factory.py | 7 +- src/modalities/registry/components.py | 3 +- ...orem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml | 3 +- 5 files changed, 71 insertions(+), 77 deletions(-) diff --git a/src/modalities/loss_functions.py b/src/modalities/loss_functions.py index f46fb0398..e3be6100d 100644 --- a/src/modalities/loss_functions.py +++ b/src/modalities/loss_functions.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from typing import overload import torch from torch.nn import CrossEntropyLoss @@ -31,9 +32,16 @@ def __init__(self, target_key: str, prediction_key: str, tag: str = "CLMCrossEnt # Mean over the tokens in the local-batch (batch per rank) self.loss_fun = CrossEntropyLoss(reduction="mean") + @overload def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: - labels = forward_batch.get_targets(self.target_key) - lm_logits = forward_batch.get_predictions(self.prediction_key) + ... + + @overload + def __call__(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + ... + + def __call__(self, *args, **kwargs) -> torch.Tensor: + labels, lm_logits = self._parse_arguments(args, kwargs) # move labels to correct device to enable model parallelism labels = labels.to(lm_logits.device) @@ -43,13 +51,40 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: loss = self.loss_fun(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) return loss - -class CLMCrossEntropyLossPP(CLMCrossEntropyLoss): - def __call__(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: - forward_batch = InferenceResultBatch( - predictions={self.prediction_key: outputs}, targets={self.target_key: targets} - ) - return super().__call__(forward_batch) + def _parse_arguments( + self, + args: list[torch.Tensor] | list[InferenceResultBatch], + kwargs: dict[str, torch.Tensor] | dict[str, InferenceResultBatch], + ) -> tuple[torch.Tensor, torch.Tensor]: + if len(args) == 1 and isinstance(args[0], InferenceResultBatch): + forward_batch = args[0] + labels = forward_batch.get_targets(self.target_key) + lm_logits = forward_batch.get_predictions(self.prediction_key) + elif "forward_batch" in kwargs and isinstance(kwargs["forward_batch"], InferenceResultBatch): + forward_batch = kwargs["forward_batch"] + labels = forward_batch.get_targets(self.target_key) + lm_logits = forward_batch.get_predictions(self.prediction_key) + elif len(args) == 2 and all(isinstance(arg, torch.Tensor) for arg in args): + lm_logits, labels = args + elif ( + "outputs" in kwargs + and "targets" in kwargs + and isinstance(kwargs["outputs"], torch.Tensor) + and isinstance(kwargs["targets"], torch.Tensor) + ): + lm_logits = kwargs["outputs"] + labels = kwargs["targets"] + elif ( + len(args) == 1 + and "targets" in kwargs + and isinstance(args[0], torch.Tensor) + and isinstance(kwargs["targets"], torch.Tensor) + ): + lm_logits = args[0] + labels = kwargs["targets"] + else: + raise TypeError("Invalid arguments for CLMCrossEntropyLoss.__call__") + return labels, lm_logits def nce_loss( diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index 76b0399ae..bdbd49913 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -2,7 +2,7 @@ import math from abc import abstractmethod from enum import Enum -from typing import Annotated, Optional +from typing import Annotated, Optional, overload import torch import torch.nn as nn @@ -344,7 +344,6 @@ class GPT2LLMConfig(BaseModel): ffn_norm_config: LayerNormWrapperConfig lm_head_norm_config: LayerNormWrapperConfig use_weight_tying: bool - use_pp: Optional[bool] = False @model_validator(mode="after") def check_divisibility(self) -> "GPT2LLMConfig": @@ -881,9 +880,10 @@ def __init__( self.transformer.lm_head.weight ) # https://paperswithcode.com/method/weight-tying - def forward_impl(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + @overload + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ - Forward pass implementation of the GPT2LLM module. + Forward pass of the GPT2LLM module. Args: inputs (dict[str, torch.Tensor]): A dictionary containing input tensors. @@ -893,72 +893,50 @@ def forward_impl(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tenso dict[str, torch.Tensor]: A dictionary containing output tensors. - prediction_key (str): Key for the output tensor containing logits. """ - input_ids = inputs[self.sample_key] - device = input_ids.device - _, t = input_ids.size() # batch size, sequence length - assert t <= self.sequence_length, f"Cannot forward sequence of length {t}, the model's maximum " - f"input sequence length is only {self.sequence_length}" - - # forward the GPT model itself - tok_emb = self.transformer.wte(input_ids) # token embeddings of shape (b, t, n_embd) - - if self.poe_type is PositionTypes.ABSOLUTE: - pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t) - pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) - tok_emb = tok_emb + pos_emb - - # TODO: use drop out also without absolute position embedding? - x = self.transformer.drop(tok_emb) - - for block in self.transformer.h: - x = block(x) - x = self.transformer.lm_head_norm(x) - logits = self.transformer.lm_head(x) - return {self.prediction_key: logits} + ... - def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + @overload + def forward(self, inputs: torch.Tensor) -> torch.Tensor: """ Forward pass of the GPT2LLM module. Args: - inputs (dict[str, torch.Tensor]): A dictionary containing input tensors. - - sample_key (str): Key for the input tensor containing token ids. + inputs (torch.Tensor): A tensor containing input token ids. Returns: - dict[str, torch.Tensor]: A dictionary containing output tensors. - - prediction_key (str): Key for the output tensor containing logits. + torch.Tensor: A tensor containing output logits. """ - return self.forward_impl(inputs) + ... - -class GPT2LLMPP(GPT2LLM): - """GPT2LLM class.""" + def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, torch.Tensor] | torch.Tensor: + if isinstance(inputs, dict): + return {self.prediction_key: self.forward_impl(inputs[self.sample_key])} + else: + return self.forward_impl(inputs) def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor: """ Forward pass implementation of the GPT2LLM module. Args: - inputs (dict[str, torch.Tensor]): A dictionary containing input tensors. - - sample_key (str): Key for the input tensor containing token ids. + inputs (torch.Tensor): A tensor containing input token ids. Returns: - dict[str, torch.Tensor]: A dictionary containing output tensors. - - prediction_key (str): Key for the output tensor containing logits. + torch.Tensor: A tensor containing output logits. """ device = inputs.device - t = inputs.size(1) # batch size, sequence length - assert t <= self.sequence_length, f"Cannot forward sequence of length {t}, the model's maximum " - f"input sequence length is only {self.sequence_length}" + seq_len = inputs.size(1) + assert seq_len <= self.sequence_length, f"Cannot forward sequence of length {seq_len}, the model's maximum " + f"input sequence length is only {self.sequence_length}." # forward the GPT model itself h = ( self.transformer.wte(inputs) if hasattr(self.transformer, "wte") else inputs - ) # token embeddings of shape (b, t, n_embd) + ) # token embeddings of shape (b, seq_len, n_embd) if self.poe_type is PositionTypes.ABSOLUTE and hasattr(self.transformer, "wpe"): - pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t) - pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) + pos = torch.arange(0, seq_len, dtype=torch.long, device=device) # shape (seq_len) + pos_emb = self.transformer.wpe(pos) # position embeddings of shape (seq_len, n_embd) h = h + pos_emb # TODO: use drop out also without absolute position embedding? @@ -970,20 +948,6 @@ def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor: h = self.transformer.lm_head(h) if hasattr(self.transformer, "lm_head") else h return h - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - """ - Forward pass of the GPT2LLM module. - - Args: - inputs (dict[str, torch.Tensor]): A dictionary containing input tensors. - - sample_key (str): Key for the input tensor containing token ids. - - Returns: - dict[str, torch.Tensor]: A dictionary containing output tensors. - - prediction_key (str): Key for the output tensor containing logits. - """ - return self.forward_impl(inputs) - def manual_scaled_dot_product_attention( query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py index dc3d084a8..877c9cbdc 100644 --- a/src/modalities/models/model_factory.py +++ b/src/modalities/models/model_factory.py @@ -28,7 +28,6 @@ from modalities.exceptions import ModelStateError from modalities.models.gpt2.gpt2_model import ( GPT2LLM, - GPT2LLMPP, AttentionConfig, AttentionImplementation, LayerNormWrapperConfig, @@ -569,7 +568,6 @@ def get_gpt2_model( use_weight_tying: bool, use_meta_device: Optional[bool] = False, seed: int = None, - use_pp: Optional[bool] = False, ) -> GPT2LLM: config = dict( sample_key=sample_key, @@ -599,12 +597,11 @@ def get_gpt2_model( "Please set at least use_meta_device=False or use_weight_tying=False." "https://github.com/Modalities/modalities/issues/357" ) - gpt2_model_class = GPT2LLMPP if use_pp else GPT2LLM if use_meta_device: with torch.device("meta"): - model = gpt2_model_class(**config) + model = GPT2LLM(**config) else: - model = gpt2_model_class(**config) + model = GPT2LLM(**config) return model @staticmethod diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index b3a0a7618..167a29894 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -78,7 +78,7 @@ ProgressSubscriberFactory, ResultsSubscriberFactory, ) -from modalities.loss_functions import CLMCrossEntropyLoss, CLMCrossEntropyLossPP +from modalities.loss_functions import CLMCrossEntropyLoss from modalities.models.coca.coca_model import CoCa, CoCaConfig from modalities.models.coca.collator import CoCaCollateFnConfig, CoCaCollatorFn from modalities.models.components.layer_norms import LayerNormConfig, RMSLayerNorm, RMSLayerNormConfig @@ -200,7 +200,6 @@ class ComponentEntity: ), # losses ComponentEntity("loss", "clm_cross_entropy_loss", CLMCrossEntropyLoss, CLMCrossEntropyLossConfig), - ComponentEntity("loss", "clm_cross_entropy_loss_pp", CLMCrossEntropyLossPP, CLMCrossEntropyLossConfig), # optmizers ComponentEntity("optimizer", "adam", OptimizerFactory.get_adam, AdamOptimizerConfig), ComponentEntity("optimizer", "adam_w", OptimizerFactory.get_adam_w, AdamWOptimizerConfig), diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml index 5ef5f148a..0ceb02a53 100644 --- a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml +++ b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml @@ -16,7 +16,7 @@ settings: loss_fn: component_key: loss - variant_key: clm_cross_entropy_loss_pp + variant_key: clm_cross_entropy_loss config: target_key: ${settings.referencing_keys.target_key} prediction_key: ${settings.referencing_keys.prediction_key} @@ -129,7 +129,6 @@ model_raw: component_key: model variant_key: gpt2 config: - use_pp: true use_meta_device: true use_weight_tying: false sample_key: ${settings.referencing_keys.sample_key} From 5147a7ac98288f9c4163e7773b427f058b200ac3 Mon Sep 17 00:00:00 2001 From: Timm Ruland Date: Mon, 8 Sep 2025 12:31:43 +0200 Subject: [PATCH 17/60] refactor(parallelism): Clean up for pp test. --- .../test_pp_fwd_bwd_pass.py | 51 +++++++++++-------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py index 7384b8338..f933f4289 100644 --- a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py +++ b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py @@ -109,40 +109,34 @@ def _test_pp_impl( world_size=world_size, rdvz_port=22356, ): - components = self._get_components(pp_model_config_path, use_pp=True) - scheduled_pipeline = components.scheduled_pipeline vocab_size = 50304 sequence_length = 256 batch_size = 4 sequences = torch.randint(0, vocab_size, (batch_size, sequence_length)) targets = sequences[:, 1:].contiguous() inputs = sequences[:, :-1].contiguous() - loss_pp = self._forward_step(scheduled_pipeline, inputs, targets) - - # if scheduled_pipeline.is_last_pp_stage: - working_dir = Path(os.path.dirname(__file__)) - fsdp2_model_config_path = working_dir / "configs/config_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml" - fsdp2_components = self._get_components(fsdp2_model_config_path, use_pp=False) - fsdp2_model = fsdp2_components.initialized_model - fsdp2_loss_fn = fsdp2_components.loss_fn - - input_dict = {"input_ids": inputs} - fsdp2_out = fsdp2_model(input_dict) - forward_batch = InferenceResultBatch(predictions=fsdp2_out, targets={fsdp2_loss_fn.target_key: targets}) - fsdp2_loss = fsdp2_loss_fn(forward_batch) - if scheduled_pipeline.is_last_pp_stage: - assert torch.allclose(fsdp2_loss, loss_pp, atol=1e-6, rtol=1e-5), "Outputs do not match" + + is_last_pp_stage, loss_pp = self._forward_step_with_pp(pp_model_config_path, inputs, targets) + fsdp2_loss = self._forward_step_without_pp(inputs, targets) + + if is_last_pp_stage: + assert torch.allclose(loss_pp, fsdp2_loss, atol=1e-6, rtol=1e-5), "Losses do not match" + + def _forward_step_with_pp( + self, pp_model_config_path: Path, inputs: torch.Tensor, targets: torch.Tensor + ) -> tuple[bool, torch.Tensor]: + components = self._get_components(pp_model_config_path, use_pp=True) + scheduled_pipeline = components.scheduled_pipeline + loss_pp = self._forward_step(scheduled_pipeline, inputs, targets) + return scheduled_pipeline.is_last_pp_stage, loss_pp def _forward_step(self, scheduled_pipeline: Pipeline, inputs: torch.Tensor, targets: torch.Tensor): """Runs a forward step on the model.""" pp_schedule = scheduled_pipeline.pp_schedule targets, losses = (targets, []) if scheduled_pipeline.is_last_pp_stage else (None, None) - if scheduled_pipeline.is_first_pp_stage: # first stage - # pp_schedule.step(inputs, target=targets, losses=losses, input_batch=inputs) + if scheduled_pipeline.is_first_pp_stage: pp_schedule.step(inputs, target=targets, losses=losses) - else: # non-first stage - # pp_schedule.step(target=targets, losses=losses, input_batch=inputs) - # pp_schedule.step(inputs, target=targets, losses=losses, input_batch=inputs) + else: pp_schedule.step(target=targets, losses=losses) # accumulate losses across pipeline microbatches @@ -152,3 +146,16 @@ def _forward_step(self, scheduled_pipeline: Pipeline, inputs: torch.Tensor, targ if scheduled_pipeline.is_last_pp_stage else torch.tensor([-1.0], device=inputs.device) ) + + def _forward_step_without_pp(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + working_dir = Path(os.path.dirname(__file__)) + fsdp2_model_config_path = working_dir / "configs/config_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml" + fsdp2_components = self._get_components(fsdp2_model_config_path, use_pp=False) + fsdp2_model = fsdp2_components.initialized_model + fsdp2_loss_fn = fsdp2_components.loss_fn + + input_dict = {"input_ids": inputs} + fsdp2_out = fsdp2_model(input_dict) + forward_batch = InferenceResultBatch(predictions=fsdp2_out, targets={fsdp2_loss_fn.target_key: targets}) + fsdp2_loss = fsdp2_loss_fn(forward_batch) + return fsdp2_loss From 1cb977954daaea91b72dd878b4693a9d8d7dbd64 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Mon, 8 Sep 2025 13:28:34 +0200 Subject: [PATCH 18/60] test: Print losses to debug tests --- .../pipeline_parallelism/test_pp_fwd_bwd_pass.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py index f933f4289..b9bb462fd 100644 --- a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py +++ b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py @@ -120,6 +120,7 @@ def _test_pp_impl( fsdp2_loss = self._forward_step_without_pp(inputs, targets) if is_last_pp_stage: + print(f"Loss with PP: {loss_pp.item()}, Loss without PP: {fsdp2_loss.item()}") assert torch.allclose(loss_pp, fsdp2_loss, atol=1e-6, rtol=1e-5), "Losses do not match" def _forward_step_with_pp( From 27ad56dffd5fa5b6d1a065d937493ad1412796f6 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 9 Sep 2025 11:12:00 +0200 Subject: [PATCH 19/60] feat: Use scheduled_pipeline for forwad backward pass --- src/modalities/config/instantiation_models.py | 2 ++ src/modalities/main.py | 1 + 2 files changed, 3 insertions(+) diff --git a/src/modalities/config/instantiation_models.py b/src/modalities/config/instantiation_models.py index 4c57c133d..578641a64 100644 --- a/src/modalities/config/instantiation_models.py +++ b/src/modalities/config/instantiation_models.py @@ -13,6 +13,7 @@ PydanticLossIFType, PydanticMessageSubscriberIFType, PydanticMFUCalculatorABCType, + PydanticPipelineType, PydanticPytorchDeviceType, PydanticPytorchModuleType, PydanticTextInferenceComponentType, @@ -178,6 +179,7 @@ def _check_last_step_checkpointed(self) -> "TrainingComponentsInstantiationModel checkpoint_saving: PydanticCheckpointSavingIFType gradient_clipper: PydanticGradientClipperIFType mfu_calculator: Optional[PydanticMFUCalculatorABCType] = None + scheduled_pipeline: Optional[PydanticPipelineType] = None model_raw: PydanticPytorchModuleType @model_validator(mode="after") diff --git a/src/modalities/main.py b/src/modalities/main.py index d995b9168..271a6759a 100644 --- a/src/modalities/main.py +++ b/src/modalities/main.py @@ -169,6 +169,7 @@ def run(self, components: TrainingComponentsInstantiationModel): checkpointing_interval_in_steps=components.settings.intervals.checkpointing_interval_in_steps, evaluation_interval_in_steps=components.settings.intervals.evaluation_interval_in_steps, training_log_interval_in_steps=components.settings.intervals.training_log_interval_in_steps, + scheduled_pipeline=components.scheduled_pipeline if components.scheduled_pipeline else None, ) def get_logging_publishers( From 41c4f36d3760e5213a77e2ec4a383203414323ac Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 9 Sep 2025 14:32:12 +0200 Subject: [PATCH 20/60] feat: Use scheduled_pipeline for training --- src/modalities/gym.py | 7 ++++++- src/modalities/trainer.py | 39 +++++++++++++++++++++++++++++++-------- 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/src/modalities/gym.py b/src/modalities/gym.py index 0394b7a28..65b29fab8 100644 --- a/src/modalities/gym.py +++ b/src/modalities/gym.py @@ -40,6 +40,7 @@ def run( train_data_loader: LLMDataLoader, evaluation_data_loaders: list[LLMDataLoader], checkpoint_saving: CheckpointSaving, + scheduled_pipeline=None, # TODO set type ): """Runs the model training, including evaluation and checkpointing. @@ -57,6 +58,7 @@ def run( model=app_state.model, evaluation_data_loaders=evaluation_data_loaders, evaluation_interval_in_steps=evaluation_interval_in_steps, + scheduled_pipeline=scheduled_pipeline, ) checkpointing_callback: Callable[[TrainingProgress], None] = partial( @@ -74,6 +76,7 @@ def run( evaluation_callback=evaluation_callback, checkpointing_callback=checkpointing_callback, training_log_interval_in_steps=training_log_interval_in_steps, + scheduled_pipeline=scheduled_pipeline, ) print_rank_0(f"Training done at {datetime.now()}.") @@ -101,11 +104,13 @@ def _run_evaluation( num_train_steps_done: int, evaluation_data_loaders: list[LLMDataLoader], evaluation_interval_in_steps: int, + scheduled_pipeline=None, # TODO set type ): - if num_train_steps_done % evaluation_interval_in_steps == 0: + if num_train_steps_done % evaluation_interval_in_steps == 0 and num_train_steps_done > 10: self.evaluator.evaluate( model=model, data_loaders=evaluation_data_loaders, loss_fun=self.loss_fun, num_train_steps_done=num_train_steps_done, + scheduled_pipeline=scheduled_pipeline, ) diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index b443c0ad3..55213cf9f 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -95,6 +95,7 @@ def _train_batch( scheduler: LRScheduler, loss_fun: Loss, micro_batch_id: int, + scheduled_pipeline=None, # TODO set type ) -> tuple[bool, int, torch.Tensor, Optional[torch.Tensor]]: """ Conducts a training step on batch of data. @@ -116,9 +117,27 @@ def _train_batch( - gradient_norm_score (Optional[torch.Tensor]): The gradient norm score, if a training step was performed otherwise return None. """ - result_batch = model_predict_batch(model=model, batch=batch) - loss = loss_fun(result_batch) - (loss / self.gradient_acc_steps).backward() + if scheduled_pipeline is not None: + pp_schedule = scheduled_pipeline.pp_schedule + # TODO: handle loss and backward in pp + # Pipeline Parallel forward / backward inside step() call + # with self.train_context(optional_context_parallel_ctx): + targets, losses = ( + (batch.targets[loss_fun.target_key].contiguous(), []) + if scheduled_pipeline.is_last_pp_stage + else (None, None) + ) + + if scheduled_pipeline.is_first_pp_stage: + pp_schedule.step(batch.samples[model.sample_key].contiguous(), target=targets, losses=losses) + else: + pp_schedule.step(target=targets, losses=losses) + loss = torch.mean(torch.stack(losses)).to(losses[0].device) if scheduled_pipeline.is_last_pp_stage else None + else: + # else continue with loss calculation + result_batch = model_predict_batch(model=model, batch=batch) + loss = loss_fun(result_batch) + (loss / self.gradient_acc_steps).backward() if (micro_batch_id + 1) % self.gradient_acc_steps == 0: gradient_norm_score = self.gradient_clipper.clip_gradients() @@ -143,6 +162,7 @@ def train( training_log_interval_in_steps: int, evaluation_callback: Callable[[TrainingProgress], None], checkpointing_callback: Callable[[TrainingProgress], None], + scheduled_pipeline=None, # TODO set type ): """ Trains the model. @@ -206,15 +226,17 @@ def train( scheduler=lr_scheduler, loss_fun=loss_fun, micro_batch_id=micro_batch_id, + scheduled_pipeline=scheduled_pipeline, ) forward_backward_time_recorder.stop() training_progress.num_seen_steps_current_run = num_train_steps_done training_progress.num_seen_tokens_current_run = self.global_num_tokens_per_train_step * num_train_steps_done - # Save the batch loss - cumulated_losses[0] += batch_loss.item() - # This works, because we always drop the last batch in case it has less samples than the batch size - cumulated_losses[-1] += 1 # number of local batches + if batch_loss is not None: + # Save the batch loss + cumulated_losses[0] += batch_loss.item() + # This works, because we always drop the last batch in case it has less samples than the batch size + cumulated_losses[-1] += 1 # number of local batches # gradient norm is already synced across all ranks if gradient_norm_score is not None: @@ -243,7 +265,8 @@ def train( synced_num_samples_per_second = synced_num_samples / synced_forward_backward_time # TODO: insert reducer from outside so Trainer is independent of FSDP # add the loss and gradient norm for the LAST batch - cumulated_losses[1] = batch_loss.item() + + cumulated_losses[1] = batch_loss.item() if batch_loss is not None else 0.0 reduced_losses = Reducer.reduce( tensor=cumulated_losses, From 6f3d5da3a11573c2017f5670218631303a48df26 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 9 Sep 2025 17:48:24 +0200 Subject: [PATCH 21/60] feat: Use scheduled_pipe in evaluation --- src/modalities/evaluator.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/src/modalities/evaluator.py b/src/modalities/evaluator.py index 456fcb47f..5d56bb90c 100644 --- a/src/modalities/evaluator.py +++ b/src/modalities/evaluator.py @@ -36,6 +36,7 @@ def evaluate_batch( batch: DatasetBatch, model: nn.Module, loss_fun: Callable[[InferenceResultBatch], torch.Tensor], + scheduled_pipeline=None, # TODO set type ) -> torch.Tensor: """Evaluate a single batch by forwarding it through the model and calculating the loss. @@ -48,8 +49,26 @@ def evaluate_batch( torch.Tensor: The loss of the batch """ with torch.no_grad(): - result_batch = model_predict_batch(model=model, batch=batch) - loss = loss_fun(result_batch) + if scheduled_pipeline is not None: + pp_schedule = scheduled_pipeline.pp_schedule + targets, losses = ( + (batch.targets[loss_fun.target_key].contiguous(), []) + if scheduled_pipeline.is_last_pp_stage + else (None, None) + ) + + if scheduled_pipeline.is_first_pp_stage: + pp_schedule.eval(batch.samples[model.sample_key].contiguous(), target=targets, losses=losses) + else: + pp_schedule.eval(target=targets, losses=losses) + loss = ( + torch.mean(torch.stack(losses)).to(losses[0].device) + if scheduled_pipeline.is_last_pp_stage + else None + ) + else: + result_batch = model_predict_batch(model=model, batch=batch) + loss = loss_fun(result_batch) return loss def evaluate( @@ -58,6 +77,7 @@ def evaluate( data_loaders: list[LLMDataLoader], loss_fun: Callable[[InferenceResultBatch], torch.Tensor], num_train_steps_done: int, + scheduled_pipeline=None, # TODO set type ) -> dict[str, EvaluationResultBatch]: """Evaluate the model on a set of datasets. @@ -90,10 +110,12 @@ def evaluate( batch=batch, model=model, loss_fun=loss_fun, + scheduled_pipeline=scheduled_pipeline, ) - cumulated_loss[0] += batch_loss.item() # sum up batch loss - cumulated_loss[1] += 1 + if batch_loss is not None: + cumulated_loss[0] += batch_loss.item() # sum up batch loss + cumulated_loss[1] += 1 batch_length_tensor = torch.tensor(len(batch)).to(device) thoughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor) From 9b853340f207b0c728d98af9be33249378726d58 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 9 Sep 2025 17:49:31 +0200 Subject: [PATCH 22/60] test: Print losses if test fails --- .../pipeline_parallelism/test_pp_fwd_bwd_pass.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py index b9bb462fd..47e1fe990 100644 --- a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py +++ b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py @@ -120,8 +120,9 @@ def _test_pp_impl( fsdp2_loss = self._forward_step_without_pp(inputs, targets) if is_last_pp_stage: - print(f"Loss with PP: {loss_pp.item()}, Loss without PP: {fsdp2_loss.item()}") - assert torch.allclose(loss_pp, fsdp2_loss, atol=1e-6, rtol=1e-5), "Losses do not match" + assert torch.allclose( + loss_pp, fsdp2_loss, atol=1e-6, rtol=1e-5 + ), f"Losses do not match.\nLoss with PP: {loss_pp.item()}, Loss without PP: {fsdp2_loss.item()}" def _forward_step_with_pp( self, pp_model_config_path: Path, inputs: torch.Tensor, targets: torch.Tensor From 84e2702627bc594fb1285ca8cc6f0d9a6af6a9f5 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 9 Sep 2025 17:50:03 +0200 Subject: [PATCH 23/60] chore: Run evaluation before training --- src/modalities/gym.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modalities/gym.py b/src/modalities/gym.py index 65b29fab8..1e6fadf5a 100644 --- a/src/modalities/gym.py +++ b/src/modalities/gym.py @@ -106,7 +106,7 @@ def _run_evaluation( evaluation_interval_in_steps: int, scheduled_pipeline=None, # TODO set type ): - if num_train_steps_done % evaluation_interval_in_steps == 0 and num_train_steps_done > 10: + if num_train_steps_done % evaluation_interval_in_steps == 0: self.evaluator.evaluate( model=model, data_loaders=evaluation_data_loaders, From 32fbe9499c6a27c7f6e639cfdb8687207df0e4df Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 9 Sep 2025 17:50:55 +0200 Subject: [PATCH 24/60] chore: Increase microbatch size --- config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml b/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml index 381550a20..5d5557e6f 100644 --- a/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml +++ b/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml @@ -24,7 +24,7 @@ settings: enforce_last_step_checkpointed: false step_profile: gradient_accumulation_steps: 1 - local_train_micro_batch_size: 2 + local_train_micro_batch_size: 4 sequence_length: 256 training_target: num_target_tokens: @@ -222,7 +222,7 @@ scheduled_pipeline: pass_type: BY_REFERENCE pp_schedule_name: gpipe batch_size: ${settings.step_profile.local_train_micro_batch_size} - microbatch_size: 1 + microbatch_size: 2 pp_degree: ${device_mesh.config.pipeline_parallel_degree} # maybe better to use the fsdp model and the schedule here # instead of passing in the staged pipeline? From 61ab3114b4305890a642497c12bc6afe82e678f2 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Wed, 10 Sep 2025 17:24:36 +0200 Subject: [PATCH 25/60] fix: Use dp size instead of world size for last batch aggregation --- src/modalities/config/instantiation_models.py | 2 ++ src/modalities/main.py | 3 +++ .../running_env/fsdp/device_mesh.py | 22 +++++++++++++++++++ src/modalities/trainer.py | 6 ++++- 4 files changed, 32 insertions(+), 1 deletion(-) diff --git a/src/modalities/config/instantiation_models.py b/src/modalities/config/instantiation_models.py index 578641a64..6e3b12d8a 100644 --- a/src/modalities/config/instantiation_models.py +++ b/src/modalities/config/instantiation_models.py @@ -8,6 +8,7 @@ PydanticAppStateType, PydanticCheckpointSavingIFType, PydanticDatasetIFType, + PydanticDeviceMeshIFType, PydanticGradientClipperIFType, PydanticLLMDataLoaderIFType, PydanticLossIFType, @@ -180,6 +181,7 @@ def _check_last_step_checkpointed(self) -> "TrainingComponentsInstantiationModel gradient_clipper: PydanticGradientClipperIFType mfu_calculator: Optional[PydanticMFUCalculatorABCType] = None scheduled_pipeline: Optional[PydanticPipelineType] = None + device_mesh: PydanticDeviceMeshIFType model_raw: PydanticPytorchModuleType @model_validator(mode="after") diff --git a/src/modalities/main.py b/src/modalities/main.py index 271a6759a..f64ea16bf 100644 --- a/src/modalities/main.py +++ b/src/modalities/main.py @@ -20,6 +20,7 @@ from modalities.logging_broker.subscriber import MessageSubscriberIF from modalities.registry.components import COMPONENTS from modalities.registry.registry import Registry +from modalities.running_env.fsdp.device_mesh import get_num_data_parallel_ranks from modalities.trainer import Trainer from modalities.util import get_synced_experiment_id_of_run, get_total_number_of_trainable_parameters, print_rank_0 @@ -116,6 +117,7 @@ def run(self, components: TrainingComponentsInstantiationModel): * components.settings.step_profile.gradient_accumulation_steps * components.settings.cuda_env.world_size ) + num_data_parallel_ranks = get_num_data_parallel_ranks(components.device_mesh) trainer = Trainer( global_rank=components.settings.cuda_env.global_rank, progress_publisher=progress_publisher, @@ -128,6 +130,7 @@ def run(self, components: TrainingComponentsInstantiationModel): gradient_clipper=components.gradient_clipper, global_num_tokens_per_train_step=global_num_tokens_per_train_step, mfu_calculator=components.mfu_calculator, + num_data_parallel_ranks=num_data_parallel_ranks, ) # Evaluator diff --git a/src/modalities/running_env/fsdp/device_mesh.py b/src/modalities/running_env/fsdp/device_mesh.py index 24e7d6e18..c74751362 100644 --- a/src/modalities/running_env/fsdp/device_mesh.py +++ b/src/modalities/running_env/fsdp/device_mesh.py @@ -127,3 +127,25 @@ def get_device_mesh( # TODO: Torch Titan had some more checks here. We need to check if we also need those: # https://github.com/pytorch/torchtitan/blob/b291ad662493b63d25b038a30a915082d3617baf/torchtitan/distributed/parallel_dims.py#L86-L104 return device_mesh + + +def get_num_data_parallel_ranks(device_mesh: DeviceMesh) -> int: + """Gets the number of data parallel ranks from the device mesh. + + Args: + device_mesh (DeviceMesh): The device mesh. + + Returns: + int: The number of data parallel ranks. + """ + world_size = device_mesh.size() + dp_size = world_size + for parallelism_degree in ( + ParallelismDegrees.TP.value, + ParallelismDegrees.PP.value, + ParallelismDegrees.CP.value, + ): + if parallelism_degree in device_mesh.mesh_dim_names: + dp_size //= device_mesh.size(device_mesh.mesh_dim_names.index(parallelism_degree)) + + return dp_size diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index 55213cf9f..9960920ba 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -30,6 +30,7 @@ class Trainer: def __init__( self, global_rank: int, + num_data_parallel_ranks: int, progress_publisher: MessagePublisher[ProgressUpdate], evaluation_result_publisher: MessagePublisher[EvaluationResultBatch], gradient_acc_steps: int, @@ -62,6 +63,7 @@ def __init__( None """ self.global_rank = global_rank + self.num_data_parallel_ranks = num_data_parallel_ranks self.progress_publisher = progress_publisher self.evaluation_result_publisher = evaluation_result_publisher self.gradient_acc_steps = gradient_acc_steps @@ -273,7 +275,9 @@ def train( operation=dist.ReduceOp.SUM, # 1.) summed batch loss / (num batches * world size) # 2.) last batch loss / world size - post_processing_fun=lambda t: torch.stack([t[0] / t[-1], t[1] / dist.get_world_size()]), + post_processing_fun=lambda t: torch.stack( + [t[0] / t[-1], t[1] / self.num_data_parallel_ranks, t[-1]] + ), ) train_loss_avg, train_loss_last_batch = ( From 6952bcc1156b85fb24e6b06fc6735584dce1be45 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Wed, 10 Sep 2025 17:25:09 +0200 Subject: [PATCH 26/60] docs: Add TODOs for later check --- src/modalities/util.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/modalities/util.py b/src/modalities/util.py index eee5ff108..42b1360ac 100644 --- a/src/modalities/util.py +++ b/src/modalities/util.py @@ -186,9 +186,11 @@ def get_total_number_of_trainable_parameters(model: FSDPX) -> Number: # However, users can also provide their own sharding process groups (currently not supported in Modalities) # which would require to adapt the code. if model.sharding_strategy.name == "NO_SHARD": - sharding_factor = dist.get_world_size() + sharding_factor = dist.get_world_size() # TODO Check if we should use number of data parallel ranks instead if model.sharding_strategy.name == "HYBRID_SHARD": - sharding_factor = dist.get_world_size() // torch.cuda.device_count() + sharding_factor = ( + dist.get_world_size() // torch.cuda.device_count() + ) # TODO Check if we should use number of data parallel ranks instead elif model.sharding_strategy.name == "FULL_SHARD": sharding_factor = 1 total_num_params = total_num_params // sharding_factor From 90dbe51c22df6769ad196cb014b1fed67732f355 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Wed, 10 Sep 2025 17:33:13 +0200 Subject: [PATCH 27/60] fix: Train before evaluation so that pp is initialized for backwards --- src/modalities/gym.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modalities/gym.py b/src/modalities/gym.py index 1e6fadf5a..30f7a820f 100644 --- a/src/modalities/gym.py +++ b/src/modalities/gym.py @@ -106,7 +106,7 @@ def _run_evaluation( evaluation_interval_in_steps: int, scheduled_pipeline=None, # TODO set type ): - if num_train_steps_done % evaluation_interval_in_steps == 0: + if num_train_steps_done > 0 and num_train_steps_done % evaluation_interval_in_steps == 0: self.evaluator.evaluate( model=model, data_loaders=evaluation_data_loaders, From 49df7d616ca43d36a4ff8198c03b23cb98556869 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 12 Sep 2025 14:28:40 +0200 Subject: [PATCH 28/60] fix: Add missing parameter seed to GPT2LLMConfig --- src/modalities/models/gpt2/gpt2_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index bdbd49913..1de776101 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -319,7 +319,7 @@ class GPT2LLMConfig(BaseModel): ffn_norm_config (LayerNormWrapperConfig): Config for normalization of the feed-forward network. lm_head_norm_config (LayerNormWrapperConfig): Config for normalization of the language model head. use_weight_tying (bool): Whether to use weight tying. - + seed (int, optional): The seed for random number generation. Defaults to None. """ sample_key: str @@ -344,6 +344,7 @@ class GPT2LLMConfig(BaseModel): ffn_norm_config: LayerNormWrapperConfig lm_head_norm_config: LayerNormWrapperConfig use_weight_tying: bool + seed: Optional[int] = None @model_validator(mode="after") def check_divisibility(self) -> "GPT2LLMConfig": From 7996a299609dc4b78d122be8581cef971119c742 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Mon, 15 Sep 2025 10:38:13 +0200 Subject: [PATCH 29/60] fix: Retrieve all PP ranks for gradient clipping --- .../fsdp_gradient_clipper.py | 54 ++++++++++++++++--- .../fsdp_gradient_clipper_config.py | 4 +- 2 files changed, 49 insertions(+), 9 deletions(-) diff --git a/src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py b/src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py index f1adddfb3..d4b280a32 100644 --- a/src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py +++ b/src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py @@ -1,11 +1,15 @@ +import math from typing import Iterable, Optional import torch +from torch import distributed as dist +from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import FSDPModule as FSDP2 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP1 from torch.distributed.tensor import DTensor from modalities.config.lookup_enum import LookupEnum +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF @@ -91,7 +95,13 @@ def clip_gradients(self) -> torch.Tensor: class FSDP2GradientClipper(GradientClipperIF): """The FSDP2GradientClipper class that is responsible for clipping the gradients of a model wrapped with FSDP.""" - def __init__(self, wrapped_model: FSDP2, max_norm: float, norm_type=GradientClippingMode) -> None: + def __init__( + self, + wrapped_model: FSDP2, + max_norm: float, + norm_type=GradientClippingMode, + device_mesh: Optional[DeviceMesh] = None, + ) -> None: """ Initialize the FSDP2GradientClipper object. @@ -106,6 +116,7 @@ def __init__(self, wrapped_model: FSDP2, max_norm: float, norm_type=GradientClip self.wrapped_model = wrapped_model self.max_norm = max_norm self.norm_type = norm_type + self.device_mesh = device_mesh @torch.no_grad() def clip_gradients(self) -> torch.Tensor: @@ -121,6 +132,7 @@ def clip_gradients(self) -> torch.Tensor: norm_type=self.norm_type.value, error_if_nonfinite=True, foreach=True, + device_mesh=self.device_mesh, ) return gradient_norm_score @@ -131,6 +143,7 @@ def clip_grad_norm_( norm_type: float = 2.0, error_if_nonfinite: bool = False, foreach: Optional[bool] = None, + device_mesh: Optional[DeviceMesh] = None, ) -> torch.Tensor: """ Clip the gradient norm of an iterable of parameters. @@ -138,10 +151,6 @@ def clip_grad_norm_( Gradient norm clipping requires computing the gradient norm over the entire model. `torch.nn.utils.clip_grad_norm_` only computes gradient norm along DP/FSDP/TP dimensions. - TODO: for pipeline parallelism, we need to implement it like here: - https://github.com/pytorch/torchtitan/blob/b291ad662493b63d25b038a30a915082d3617baf/torchtitan/distributed/utils.py#L245 - I removed all the code w.r.t. pipeline parallelism for now. - Args: parameters: an iterable of Tensors or a single Tensor that will have gradients normalized max_norm (float): max norm of the gradients @@ -154,6 +163,7 @@ def clip_grad_norm_( If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently fall back to the slow implementation for other device types. Default: ``None`` + device_mesh: device mesh Returns: Total norm of the parameter gradients (viewed as a single vector). @@ -172,11 +182,23 @@ def clip_grad_norm_( if isinstance(total_norm, DTensor): # Will reach here if any non-PP parallelism is used. # If only using PP, total_norm will be a local tensor. + total_norm = total_norm.full_tensor() - torch.nn.utils.clip_grads_with_norm_( - parameters=parameters, max_norm=max_norm, total_norm=total_norm, foreach=foreach + pp_mesh = ( + device_mesh[ParallelismDegrees.PP.value] + if device_mesh is not None and ParallelismDegrees.PP.value in device_mesh.mesh_dim_names + else None ) + if pp_mesh is not None: + if math.isinf(norm_type): + dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group()) + else: + total_norm **= norm_type + dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group()) + total_norm **= 1.0 / norm_type + + torch.nn.utils.clip_grads_with_norm_(parameters, max_norm, total_norm, foreach) return total_norm @@ -184,7 +206,9 @@ class FSDP2LoggingOnlyGradientClipper(GradientClipperIF): """The FSDP2LoggingOnlyGradientClipper class that is responsible for logging the gradient norms without actually clipping the gradients.""" - def __init__(self, wrapped_model: FSDP2, norm_type=GradientClippingMode) -> None: + def __init__( + self, wrapped_model: FSDP2, norm_type=GradientClippingMode, device_mesh: Optional[DeviceMesh] = None + ) -> None: """ Initialize the FSDP2LoggingOnlyGradientClipper. @@ -197,6 +221,7 @@ def __init__(self, wrapped_model: FSDP2, norm_type=GradientClippingMode) -> None """ self.wrapped_model = wrapped_model self.norm_type = norm_type + self.device_mesh = device_mesh @torch.no_grad() def clip_gradients(self) -> torch.Tensor: @@ -214,6 +239,19 @@ def clip_gradients(self) -> torch.Tensor: # Will reach here if any non-PP parallelism is used. # If only using PP, total_norm will be a local tensor. total_norm = total_norm.full_tensor() + + pp_mesh = ( + self.device_mesh[ParallelismDegrees.PP.value] + if self.device_mesh is not None and ParallelismDegrees.PP.value in self.device_mesh.mesh_dim_names + else None + ) + if pp_mesh is not None: + if math.isinf(self.norm_type.value): + dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group()) + else: + total_norm **= self.norm_type.value + dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group()) + total_norm **= 1.0 / self.norm_type.value return total_norm diff --git a/src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py b/src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py index 4b4dd807d..500d954d8 100644 --- a/src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py +++ b/src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py @@ -2,7 +2,7 @@ from pydantic import BaseModel, Field -from modalities.config.pydantic_if_types import PydanticPytorchModuleType +from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticPytorchModuleType from modalities.training.gradient_clipping.fsdp_gradient_clipper import GradientClippingMode @@ -24,6 +24,7 @@ class FSDPGradientClipperConfig(BaseModel): max_norm: Annotated[float, Field(strict=True, gt=0)] norm_type: GradientClippingMode wrapped_model: PydanticPytorchModuleType + device_mesh: PydanticDeviceMeshIFType | None = None class FSDPDummyGradientClipperConfig(BaseModel): @@ -41,6 +42,7 @@ class FSDPDummyGradientClipperConfig(BaseModel): wrapped_model: PydanticPytorchModuleType norm_type: GradientClippingMode + device_mesh: PydanticDeviceMeshIFType | None = None class DummyGradientClipperConfig(BaseModel): From cbddcbc8089fac92f38630c5399612c7d5a2d185 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Mon, 15 Sep 2025 10:39:05 +0200 Subject: [PATCH 30/60] test: Add new parameter num_data_parallel_ranks to Trainer --- tests/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/conftest.py b/tests/conftest.py index bc92e004b..9bf289122 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -196,6 +196,7 @@ def trainer(progress_publisher_mock, gradient_clipper_mock): global_num_seen_tokens=0, num_target_tokens=100, num_target_steps=10, + num_data_parallel_ranks=int(os.getenv("WORLD_SIZE")), ) From 56a917aca238471aa2057ab50b36b92b1faf9dc7 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Mon, 15 Sep 2025 11:06:40 +0200 Subject: [PATCH 31/60] fix: Make FSDP1GradientClipperConfig independent of device_mesh --- src/modalities/registry/components.py | 14 ++++--- .../fsdp_gradient_clipper_config.py | 38 ++++++++++++++++++- 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 167a29894..9a3a3c46a 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -112,8 +112,10 @@ ) from modalities.training.gradient_clipping.fsdp_gradient_clipper_config import ( DummyGradientClipperConfig, - FSDPDummyGradientClipperConfig, - FSDPGradientClipperConfig, + FSDP1DummyGradientClipperConfig, + FSDP1GradientClipperConfig, + FSDP2DummyGradientClipperConfig, + FSDP2GradientClipperConfig, ) from modalities.utils.mfu import GPT2MFUCalculator from modalities.utils.number_conversion import ( @@ -325,13 +327,13 @@ class ComponentEntity: ComponentEntity("layer_norm", "rms_norm", RMSLayerNorm, RMSLayerNormConfig), ComponentEntity("layer_norm", "layer_norm", nn.LayerNorm, LayerNormConfig), # gradient clippers - ComponentEntity("gradient_clipper", "fsdp1", FSDP1GradientClipper, FSDPGradientClipperConfig), + ComponentEntity("gradient_clipper", "fsdp1", FSDP1GradientClipper, FSDP1GradientClipperConfig), ComponentEntity( - "gradient_clipper", "fsdp1_logging_only", FSDP1LoggingOnlyGradientClipper, FSDPDummyGradientClipperConfig + "gradient_clipper", "fsdp1_logging_only", FSDP1LoggingOnlyGradientClipper, FSDP1DummyGradientClipperConfig ), - ComponentEntity("gradient_clipper", "fsdp2", FSDP2GradientClipper, FSDPGradientClipperConfig), + ComponentEntity("gradient_clipper", "fsdp2", FSDP2GradientClipper, FSDP2GradientClipperConfig), ComponentEntity( - "gradient_clipper", "fsdp2_logging_only", FSDP2LoggingOnlyGradientClipper, FSDPDummyGradientClipperConfig + "gradient_clipper", "fsdp2_logging_only", FSDP2LoggingOnlyGradientClipper, FSDP2DummyGradientClipperConfig ), ComponentEntity("gradient_clipper", "dummy", DummyGradientClipper, DummyGradientClipperConfig), # MFU calculators diff --git a/src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py b/src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py index 500d954d8..80ebee2a8 100644 --- a/src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py +++ b/src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py @@ -6,7 +6,7 @@ from modalities.training.gradient_clipping.fsdp_gradient_clipper import GradientClippingMode -class FSDPGradientClipperConfig(BaseModel): +class FSDP1GradientClipperConfig(BaseModel): """ Configuration class for FSDP gradient clipper. @@ -24,10 +24,44 @@ class FSDPGradientClipperConfig(BaseModel): max_norm: Annotated[float, Field(strict=True, gt=0)] norm_type: GradientClippingMode wrapped_model: PydanticPytorchModuleType + + +class FSDP2GradientClipperConfig(FSDP1GradientClipperConfig): + """ + Configuration class for FSDP gradient clipper. + + Args: + max_norm (float): The maximum norm value for gradient clipping. + norm_type (GradientClippingMode): The type of gradient clipping to be applied. + wrapped_model (PydanticPytorchModuleType): The wrapped PyTorch model. + + Attributes: + max_norm (float): The maximum norm value for gradient clipping. + norm_type (GradientClippingMode): The type of gradient clipping to be applied. + wrapped_model (PydanticPytorchModuleType): The wrapped PyTorch model. + """ + device_mesh: PydanticDeviceMeshIFType | None = None -class FSDPDummyGradientClipperConfig(BaseModel): +class FSDP1DummyGradientClipperConfig(BaseModel): + """ + Configuration class for FSDP dummy gradient clipper. + + Args: + wrapped_model (PydanticPytorchModuleType): The wrapped PyTorch model. + norm_type (GradientClippingMode): The type of gradient clipping to be applied. + + Attributes: + wrapped_model (PydanticPytorchModuleType): The wrapped PyTorch model. + norm_type (GradientClippingMode): The type of gradient clipping to be applied. + """ + + wrapped_model: PydanticPytorchModuleType + norm_type: GradientClippingMode + + +class FSDP2DummyGradientClipperConfig(FSDP1DummyGradientClipperConfig): """ Configuration class for FSDP dummy gradient clipper. From eb47aa965f99a5cc02dafdb2ec7da6b879a9086d Mon Sep 17 00:00:00 2001 From: rrutmann Date: Mon, 15 Sep 2025 11:07:13 +0200 Subject: [PATCH 32/60] fix: Handle optional device_mesh correctly --- src/modalities/config/instantiation_models.py | 2 +- src/modalities/main.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/modalities/config/instantiation_models.py b/src/modalities/config/instantiation_models.py index 6e3b12d8a..c4bb105e1 100644 --- a/src/modalities/config/instantiation_models.py +++ b/src/modalities/config/instantiation_models.py @@ -181,7 +181,7 @@ def _check_last_step_checkpointed(self) -> "TrainingComponentsInstantiationModel gradient_clipper: PydanticGradientClipperIFType mfu_calculator: Optional[PydanticMFUCalculatorABCType] = None scheduled_pipeline: Optional[PydanticPipelineType] = None - device_mesh: PydanticDeviceMeshIFType + device_mesh: Optional[PydanticDeviceMeshIFType] = None model_raw: PydanticPytorchModuleType @model_validator(mode="after") diff --git a/src/modalities/main.py b/src/modalities/main.py index f64ea16bf..9836a4d8e 100644 --- a/src/modalities/main.py +++ b/src/modalities/main.py @@ -117,7 +117,10 @@ def run(self, components: TrainingComponentsInstantiationModel): * components.settings.step_profile.gradient_accumulation_steps * components.settings.cuda_env.world_size ) - num_data_parallel_ranks = get_num_data_parallel_ranks(components.device_mesh) + if components.device_mesh is None: + num_data_parallel_ranks = 1 + else: + num_data_parallel_ranks = get_num_data_parallel_ranks(components.device_mesh) trainer = Trainer( global_rank=components.settings.cuda_env.global_rank, progress_publisher=progress_publisher, From d228351f7e400f7966adbc33ddc1c9d581a497d8 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Wed, 17 Sep 2025 10:36:18 +0200 Subject: [PATCH 33/60] feat: Consider pipeline parallelism in tensor pallelization --- src/modalities/models/model_factory.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py index 877c9cbdc..95c72c338 100644 --- a/src/modalities/models/model_factory.py +++ b/src/modalities/models/model_factory.py @@ -631,7 +631,7 @@ def get_gpt2_tensor_parallelized_model(model: GPT2LLM, device_mesh: DeviceMesh) ), } - if isinstance(model.transformer.wpe, nn.Embedding): + if hasattr(model.transformer, "wpe") and isinstance(model.transformer.wpe, nn.Embedding): # If the position embedding is an nn.Embedding, we can shard it on the sequence dimension # to enable sequence parallelism in the downstream transformer blocks. # Note, for RoPE the wpe layer is an identity operation, which cannnot be sharded. @@ -640,11 +640,14 @@ def get_gpt2_tensor_parallelized_model(model: GPT2LLM, device_mesh: DeviceMesh) output_layouts=Shard(0), ) - parallelize_module( - module=model, - device_mesh=tp_mesh, - parallelize_plan=model_tp_plan, - ) + # only keep the relevant parts of the model parallel plan + model_tp_plan = {k: v for k, v in model_tp_plan.items() if hasattr(model.transformer, k.split(".")[1])} + if model_tp_plan: + parallelize_module( + module=model, + device_mesh=tp_mesh, + parallelize_plan=model_tp_plan, + ) transformer_block_tp_plan = { "attention_norm": SequenceParallel(), @@ -703,6 +706,16 @@ def get_gpt2_tensor_parallelized_model(model: GPT2LLM, device_mesh: DeviceMesh) ) transformer_block.attn.n_head_q = transformer_block.attn.n_head_q // tp_mesh.size() transformer_block.attn.n_head_kv = transformer_block.attn.n_head_kv // tp_mesh.size() + # only keep the relevant parts of the model parallel plan + transformer_block_tp_plan = { + k: v + for k, v in transformer_block_tp_plan.items() + if ( + hasattr(transformer_block, k) + or hasattr(transformer_block.attn, k.split(".")[1]) + or hasattr(transformer_block.mlp, k.split(".")[1]) + ) + } parallelize_module( module=transformer_block, device_mesh=tp_mesh, From 55dad72bf973fef5c20d113851165a3b6e4a445a Mon Sep 17 00:00:00 2001 From: rrutmann Date: Wed, 17 Sep 2025 10:58:57 +0200 Subject: [PATCH 34/60] test: Use the same data on each rank & test tensor parallelism --- .../test_pp_fwd_bwd_pass.py | 34 +++++++++++++------ 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py index 47e1fe990..9014f164e 100644 --- a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py +++ b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py @@ -30,12 +30,11 @@ def temp_file_path() -> Path: class ComponentsInstantiationPPModel(BaseModel): - initialized_model: PydanticFSDP2ModuleType scheduled_pipeline: PydanticPipelineType class ComponentsInstantiationModel(BaseModel): - initialized_model: PydanticFSDP2ModuleType + fsdp_model: PydanticFSDP2ModuleType loss_fn: PydanticLossIFType @@ -48,7 +47,10 @@ def _get_tmp_sharding_config_path( self, sharding_degree: int, tp_degree: int, pp_degree: int, temp_file_path: Path ) -> Path: working_dir = Path(os.path.dirname(__file__)) - config_file_path = working_dir / "configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml" + if tp_degree > 1: + config_file_path = working_dir / "configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml" + else: + config_file_path = working_dir / "configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml" with open(config_file_path, "r") as file: config_string = file.read() @@ -76,9 +78,9 @@ def _get_components(self, config_file_path: Path, use_pp: bool) -> ComponentsIns @pytest.mark.parametrize( "sharding_degree, tp_degree, pp_degree, world_size", [ - (2, 1, 2, 4), + # (2, 1, 2, 4), # (2, 1, 4, 8), - # (2, 2, 2, 8), # TODO need to support this case + (2, 2, 2, 8), # TODO need to support this case ], ) def test_pp(self, sharding_degree: int, tp_degree: int, pp_degree: int, world_size: int, temp_file_path: Path): @@ -107,11 +109,12 @@ def _test_pp_impl( global_rank=process_id, local_rank=process_id, world_size=world_size, - rdvz_port=22356, + rdvz_port=22359, ): vocab_size = 50304 - sequence_length = 256 + sequence_length = 4 batch_size = 4 + torch.manual_seed(42) sequences = torch.randint(0, vocab_size, (batch_size, sequence_length)) targets = sequences[:, 1:].contiguous() inputs = sequences[:, :-1].contiguous() @@ -127,13 +130,21 @@ def _test_pp_impl( def _forward_step_with_pp( self, pp_model_config_path: Path, inputs: torch.Tensor, targets: torch.Tensor ) -> tuple[bool, torch.Tensor]: - components = self._get_components(pp_model_config_path, use_pp=True) - scheduled_pipeline = components.scheduled_pipeline - loss_pp = self._forward_step(scheduled_pipeline, inputs, targets) + try: + components = self._get_components(pp_model_config_path, use_pp=True) + scheduled_pipeline = components.scheduled_pipeline + loss_pp = self._forward_step(scheduled_pipeline, inputs, targets) + except Exception as e: + import traceback + + print(f"Exception in _forward_step_with_pp: {e}") + traceback.print_exc() # <-- Add this line to print the full stack trace + raise e return scheduled_pipeline.is_last_pp_stage, loss_pp def _forward_step(self, scheduled_pipeline: Pipeline, inputs: torch.Tensor, targets: torch.Tensor): """Runs a forward step on the model.""" + os.environ["MODEL_TYPE"] = "PP" pp_schedule = scheduled_pipeline.pp_schedule targets, losses = (targets, []) if scheduled_pipeline.is_last_pp_stage else (None, None) if scheduled_pipeline.is_first_pp_stage: @@ -150,10 +161,11 @@ def _forward_step(self, scheduled_pipeline: Pipeline, inputs: torch.Tensor, targ ) def _forward_step_without_pp(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + os.environ["MODEL_TYPE"] = "NOPP" working_dir = Path(os.path.dirname(__file__)) fsdp2_model_config_path = working_dir / "configs/config_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml" fsdp2_components = self._get_components(fsdp2_model_config_path, use_pp=False) - fsdp2_model = fsdp2_components.initialized_model + fsdp2_model = fsdp2_components.fsdp_model fsdp2_loss_fn = fsdp2_components.loss_fn input_dict = {"input_ids": inputs} From b6a1e2d87968bfeb186027797c7ae0ee2caf906c Mon Sep 17 00:00:00 2001 From: Timm Ruland Date: Wed, 17 Sep 2025 19:47:25 +0200 Subject: [PATCH 35/60] refactor(parallelism): Some clean-up. --- src/modalities/evaluator.py | 15 ++++++++--- src/modalities/gym.py | 7 +++-- src/modalities/models/gpt2/gpt2_model.py | 4 +-- src/modalities/models/model_factory.py | 13 +++++----- src/modalities/trainer.py | 15 ++++++++--- .../test_pp_fwd_bwd_pass.py | 26 +++++++------------ 6 files changed, 45 insertions(+), 35 deletions(-) diff --git a/src/modalities/evaluator.py b/src/modalities/evaluator.py index 5d56bb90c..3f9f8f343 100644 --- a/src/modalities/evaluator.py +++ b/src/modalities/evaluator.py @@ -9,6 +9,7 @@ from modalities.logging_broker.messages import ExperimentStatus, MessageTypes, ProgressUpdate from modalities.logging_broker.publisher import MessagePublisher from modalities.models.model import model_predict_batch +from modalities.models.parallelism.pipeline_parallelism import Pipeline from modalities.running_env.fsdp.reducer import Reducer from modalities.trainer import ThroughputAggregationKeys from modalities.util import Aggregator, TimeRecorder @@ -36,17 +37,20 @@ def evaluate_batch( batch: DatasetBatch, model: nn.Module, loss_fun: Callable[[InferenceResultBatch], torch.Tensor], - scheduled_pipeline=None, # TODO set type - ) -> torch.Tensor: + scheduled_pipeline: Pipeline | None = None, + ) -> torch.Tensor | None: """Evaluate a single batch by forwarding it through the model and calculating the loss. Args: batch (DatasetBatch): The batch to evaluate model (nn.Module): The model to evaluate loss_fun (Callable[[InferenceResultBatch], torch.Tensor]): The loss function to calculate the loss + scheduled_pipeline (Pipeline | None, optional): In case of pipeline parallelism, this is used to + operate the model. Defaults to None. Returns: - torch.Tensor: The loss of the batch + torch.Tensor | None: The loss of the batch + None, if a non-last stage was processed in pipeline parallelism """ with torch.no_grad(): if scheduled_pipeline is not None: @@ -77,7 +81,7 @@ def evaluate( data_loaders: list[LLMDataLoader], loss_fun: Callable[[InferenceResultBatch], torch.Tensor], num_train_steps_done: int, - scheduled_pipeline=None, # TODO set type + scheduled_pipeline: Pipeline | None = None, ) -> dict[str, EvaluationResultBatch]: """Evaluate the model on a set of datasets. @@ -86,6 +90,8 @@ def evaluate( data_loaders (list[LLMDataLoader]): List of dataloaders to evaluate the model on loss_fun (Callable[[InferenceResultBatch], torch.Tensor]): The loss function to calculate the loss num_train_steps_done (int): The number of training steps done so far for logging purposes + scheduled_pipeline (Pipeline | None, optional): In case of pipeline parallelism, this is used to + operate the model. Defaults to None. Returns: dict[str, EvaluationResultBatch]: A dictionary containing the evaluation results for each dataloader @@ -113,6 +119,7 @@ def evaluate( scheduled_pipeline=scheduled_pipeline, ) + # The batch_loss might be None if we use pipeline parallelism and are not the last stage. if batch_loss is not None: cumulated_loss[0] += batch_loss.item() # sum up batch loss cumulated_loss[1] += 1 diff --git a/src/modalities/gym.py b/src/modalities/gym.py index 30f7a820f..7ea5e660f 100644 --- a/src/modalities/gym.py +++ b/src/modalities/gym.py @@ -9,6 +9,7 @@ from modalities.dataloader.dataloader import LLMDataLoader from modalities.evaluator import Evaluator from modalities.loss_functions import Loss +from modalities.models.parallelism.pipeline_parallelism import Pipeline from modalities.trainer import Trainer from modalities.training.training_progress import TrainingProgress from modalities.util import print_rank_0 @@ -40,7 +41,7 @@ def run( train_data_loader: LLMDataLoader, evaluation_data_loaders: list[LLMDataLoader], checkpoint_saving: CheckpointSaving, - scheduled_pipeline=None, # TODO set type + scheduled_pipeline: Pipeline | None = None, ): """Runs the model training, including evaluation and checkpointing. @@ -52,6 +53,8 @@ def run( train_data_loader (LLMDataLoader): Data loader with the training data. evaluation_data_loaders (list[LLMDataLoader]): List of data loaders with the evaluation data. checkpoint_saving (CheckpointSaving): Routine for saving checkpoints. + scheduled_pipeline (Pipeline | None, optional): In case of pipeline parallelism, this is used to + operate the model. Defaults to None. """ evaluation_callback: Callable[[int], None] = partial( self._run_evaluation, @@ -104,7 +107,7 @@ def _run_evaluation( num_train_steps_done: int, evaluation_data_loaders: list[LLMDataLoader], evaluation_interval_in_steps: int, - scheduled_pipeline=None, # TODO set type + scheduled_pipeline: Pipeline | None = None, ): if num_train_steps_done > 0 and num_train_steps_done % evaluation_interval_in_steps == 0: self.evaluator.evaluate( diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index 1de776101..168c6ae26 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -781,7 +781,7 @@ def __init__( ffn_norm_config: LayerNormWrapperConfig, lm_head_norm_config: LayerNormWrapperConfig, use_weight_tying: bool, - seed: int = None, + seed: int | None = None, ): """ Initializes the GPT2LLM object. @@ -805,8 +805,8 @@ def __init__( attention_norm_config (LayerNormWrapperConfig): Config for the attention normalization module. ffn_norm_config (LayerNormWrapperConfig): Config for the feed-forward network normalization module. lm_head_norm_config (LayerNormWrapperConfig): Config for the language model head normalization module. - seed (int, optional): The random seed. Defaults to None. use_weight_tying (bool): Whether to use weight tying. + seed (int, optional): The random seed. Defaults to None. """ weight_decay_groups = { "linear": [".attn", ".mlp", ".lm_head.weight"], diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py index 95c72c338..d463161f2 100644 --- a/src/modalities/models/model_factory.py +++ b/src/modalities/models/model_factory.py @@ -567,7 +567,7 @@ def get_gpt2_model( lm_head_norm_config: LayerNormWrapperConfig, use_weight_tying: bool, use_meta_device: Optional[bool] = False, - seed: int = None, + seed: int | None = None, ) -> GPT2LLM: config = dict( sample_key=sample_key, @@ -716,10 +716,11 @@ def get_gpt2_tensor_parallelized_model(model: GPT2LLM, device_mesh: DeviceMesh) or hasattr(transformer_block.mlp, k.split(".")[1]) ) } - parallelize_module( - module=transformer_block, - device_mesh=tp_mesh, - parallelize_plan=transformer_block_tp_plan, - ) + if transformer_block_tp_plan: + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=transformer_block_tp_plan, + ) return model diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index 9960920ba..c60dd8542 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -14,6 +14,7 @@ from modalities.logging_broker.publisher import MessagePublisher from modalities.loss_functions import Loss from modalities.models.model import model_predict_batch +from modalities.models.parallelism.pipeline_parallelism import Pipeline from modalities.running_env.fsdp.reducer import Reducer from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF from modalities.training.training_progress import TrainingProgress @@ -97,8 +98,8 @@ def _train_batch( scheduler: LRScheduler, loss_fun: Loss, micro_batch_id: int, - scheduled_pipeline=None, # TODO set type - ) -> tuple[bool, int, torch.Tensor, Optional[torch.Tensor]]: + scheduled_pipeline: Optional[Pipeline] = None, + ) -> tuple[bool, int, Optional[torch.Tensor], Optional[torch.Tensor]]: """ Conducts a training step on batch of data. @@ -109,13 +110,16 @@ def _train_batch( scheduler (LRScheduler): The learning rate scheduler. loss_fun (Loss): The loss function used for training. micro_batch_id (int): The ID of the micro batch. + scheduled_pipeline (Optional[Pipeline], optional): In case of pipeline parallelism, this is used to + operate the model. Defaults to None. Returns: tuple[bool, int, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple containing the following: - step_performed (bool): Indicates whether a training step was performed. - num_train_steps_done (int): The number of training steps done. - - loss (torch.Tensor): The computed loss. + - loss (Optional[torch.Tensor]): The computed loss. + None, if a non-last stage was processes in pipeline parallelism. - gradient_norm_score (Optional[torch.Tensor]): The gradient norm score, if a training step was performed otherwise return None. """ @@ -164,7 +168,7 @@ def train( training_log_interval_in_steps: int, evaluation_callback: Callable[[TrainingProgress], None], checkpointing_callback: Callable[[TrainingProgress], None], - scheduled_pipeline=None, # TODO set type + scheduled_pipeline: Pipeline | None = None, ): """ Trains the model. @@ -176,6 +180,8 @@ def train( training_log_interval_in_steps (int): The interval at which training progress is logged. evaluation_callback (Callable[[TrainingProgress], None]): A callback function for evaluation. checkpointing_callback (Callable[[TrainingProgress], None]): A callback function for checkpointing. + scheduled_pipeline (Pipeline | None, optional): In case of pipeline parallelism, this is used to + operate the model. Defaults to None. Returns: None @@ -234,6 +240,7 @@ def train( training_progress.num_seen_steps_current_run = num_train_steps_done training_progress.num_seen_tokens_current_run = self.global_num_tokens_per_train_step * num_train_steps_done + # The batch_loss might be None if we use pipeline parallelism and are not the last stage. if batch_loss is not None: # Save the batch loss cumulated_losses[0] += batch_loss.item() diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py index 47e1fe990..73eb5863e 100644 --- a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py +++ b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py @@ -18,15 +18,8 @@ @pytest.fixture def temp_file_path() -> Path: - # Create a NamedTemporaryFile that persists after closing (delete=False) - with tempfile.NamedTemporaryFile(delete=False) as tf: - file_path = tf.name - try: - yield Path(file_path) - finally: - # Clean up the file after the test - if os.path.exists(file_path): - os.remove(file_path) + with tempfile.NamedTemporaryFile() as tf: + yield tf.name class ComponentsInstantiationPPModel(BaseModel): @@ -63,14 +56,14 @@ def _get_tmp_sharding_config_path( return temp_file_path - def _get_components(self, config_file_path: Path, use_pp: bool) -> ComponentsInstantiationPPModel: + def _get_components( + self, config_file_path: Path, use_pp: bool + ) -> ComponentsInstantiationPPModel | ComponentsInstantiationModel: torch.manual_seed(42) main_obj = Main(config_file_path) - if use_pp: - components_model_type = ComponentsInstantiationPPModel - else: - components_model_type = ComponentsInstantiationModel - components: components_model_type = main_obj.build_components(components_model_type=components_model_type) + components_model_type = ComponentsInstantiationPPModel if use_pp else ComponentsInstantiationModel + components = main_obj.build_components(components_model_type=components_model_type) + assert isinstance(components, components_model_type) return components @pytest.mark.parametrize( @@ -141,8 +134,7 @@ def _forward_step(self, scheduled_pipeline: Pipeline, inputs: torch.Tensor, targ else: pp_schedule.step(target=targets, losses=losses) - # accumulate losses across pipeline microbatches - # TODO: PP+FSDP unexpectedly puts the loss back to the CPU + # accumulate losses across pipeline microbatchess return ( torch.mean(torch.stack(losses)).to(losses[0].device) if scheduled_pipeline.is_last_pp_stage From c49895a987467475a90033c98d6b97ad256ff27d Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 19 Sep 2025 11:34:45 +0200 Subject: [PATCH 36/60] test: Update configs for parallelization testing --- ...g_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml | 34 ++-- ...orem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml | 41 ++-- ...m_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml | 177 ++++++++++++++++++ 3 files changed, 212 insertions(+), 40 deletions(-) create mode 100644 tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml index 6603b1850..bdf991173 100644 --- a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml +++ b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml @@ -29,13 +29,28 @@ device_mesh: data_parallel_replicate_degree: 1 data_parallel_shard_degree: -1 world_size: ${settings.cuda_env.world_size} + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] initialized_model: component_key: model variant_key: model_initialized config: model: - instance_key: fsdp_model + instance_key: model_raw pass_type: BY_REFERENCE model_initializer: component_key: model_initialization @@ -44,23 +59,8 @@ initialized_model: model_type: gpt2 weight_init_type: scaled mean: 0.0 - std: 0.02 + std: 1.02 num_layers: ${model_raw.config.n_layer} - -fsdp_model: - component_key: model - variant_key: fsdp2_wrapped - config: - model: - instance_key: model_raw - pass_type: BY_REFERENCE - device_mesh: - instance_key: device_mesh - pass_type: BY_REFERENCE - mixed_precision_settings: - param_dtype: BF_16 - reduce_dtype: BF_16 - block_names: [GPT2Block] model_raw: component_key: model diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml index 0ceb02a53..2ffbe4cc5 100644 --- a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml +++ b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml @@ -31,28 +31,6 @@ device_mesh: data_parallel_shard_degree: -1 world_size: ${settings.cuda_env.world_size} -initialized_model: - component_key: model - variant_key: model_initialized - config: - model: - component_key: pipeline - variant_key: selector - config: - pipeline: - instance_key: scheduled_pipeline - pass_type: BY_REFERENCE - selection_type: MODEL_PART - model_initializer: - component_key: model_initialization - variant_key: composed - config: - model_type: gpt2 - weight_init_type: scaled - mean: 0.0 - std: 0.02 - num_layers: ${model_raw.config.n_layer} - scheduled_pipeline: component_key: pipeline variant_key: scheduled @@ -109,7 +87,7 @@ staged_pipeline: variant_key: staged config: whole_model: - instance_key: model_raw + instance_key: initialized_model pass_type: BY_REFERENCE stages_generator: component_key: stages_generator @@ -125,6 +103,23 @@ staged_pipeline: pp_schedule_name: gpipe num_layers_per_stage: 4 +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: model_raw + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 1.02 + num_layers: ${model_raw.config.n_layer} + model_raw: component_key: model variant_key: gpt2 diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml new file mode 100644 index 000000000..fb8ee5f7d --- /dev/null +++ b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml @@ -0,0 +1,177 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 4 + sequence_length: 256 + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + pipeline_parallel_degree: 2 + data_parallel_shard_degree: -1 + world_size: ${settings.cuda_env.world_size} + +scheduled_pipeline: + component_key: pipeline + variant_key: scheduled + config: + loss_fn: + instance_key: loss_fn + pass_type: BY_REFERENCE + pp_schedule_name: gpipe + batch_size: ${settings.step_profile.local_train_micro_batch_size} + microbatch_size: 2 + pp_degree: ${device_mesh.config.pipeline_parallel_degree} + pipeline: + component_key: pipeline + variant_key: builder + config: + pp_stage: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: PP_STAGE + model_part: + instance_key: fsdp_model + pass_type: BY_REFERENCE + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: gpt2_tp_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +gpt2_tp_model: + component_key: model + variant_key: gpt2_tp + config: + model: + instance_key: model_part + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +model_part: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + +staged_pipeline: + component_key: pipeline + variant_key: staged + config: + whole_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + stages_generator: + component_key: stages_generator + variant_key: gpt2_stages_generator + config: + num_model_layers: ${model_raw.config.n_layer} + input_layer_equivalence: 1 + output_layer_equivalence: 1 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + local_rank: ${settings.cuda_env.local_rank} + pp_schedule_name: gpipe + num_layers_per_stage: 4 + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: model_raw + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + +model_raw: + component_key: model + variant_key: gpt2 + config: + use_meta_device: true + use_weight_tying: false + sample_key: ${settings.referencing_keys.sample_key} + poe_type: NOPE + sequence_length: ${settings.step_profile.sequence_length} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 6 + n_head_q: 8 + n_head_kv: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} #it has to be head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: manual + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + From f685fc5d88b9a524dd7a18609ec6341946487f67 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 19 Sep 2025 12:43:22 +0200 Subject: [PATCH 37/60] test: Use correct length to create test sequences --- .../pipeline_parallelism/test_pp_fwd_bwd_pass.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py index 8906c5f86..dabd731b9 100644 --- a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py +++ b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py @@ -71,9 +71,9 @@ def _get_components( @pytest.mark.parametrize( "sharding_degree, tp_degree, pp_degree, world_size", [ - # (2, 1, 2, 4), - # (2, 1, 4, 8), - (2, 2, 2, 8), # TODO need to support this case + (2, 1, 2, 4), + (2, 1, 4, 8), + (2, 2, 2, 8), ], ) def test_pp(self, sharding_degree: int, tp_degree: int, pp_degree: int, world_size: int, temp_file_path: Path): @@ -105,10 +105,10 @@ def _test_pp_impl( rdvz_port=22359, ): vocab_size = 50304 - sequence_length = 4 + sequence_length = 256 batch_size = 4 torch.manual_seed(42) - sequences = torch.randint(0, vocab_size, (batch_size, sequence_length)) + sequences = torch.randint(0, vocab_size, (batch_size, sequence_length + 1)) targets = sequences[:, 1:].contiguous() inputs = sequences[:, :-1].contiguous() From c07fcf6b2a077aa5ddc5d13d7cc5fedd39486fd2 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 19 Sep 2025 12:45:10 +0200 Subject: [PATCH 38/60] test: Use realistic std for model initialization --- .../configs/config_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml | 2 +- .../configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml index bdf991173..988e70eba 100644 --- a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml +++ b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml @@ -59,7 +59,7 @@ initialized_model: model_type: gpt2 weight_init_type: scaled mean: 0.0 - std: 1.02 + std: 0.02 num_layers: ${model_raw.config.n_layer} model_raw: diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml index 2ffbe4cc5..f41e912bc 100644 --- a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml +++ b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml @@ -117,7 +117,7 @@ initialized_model: model_type: gpt2 weight_init_type: scaled mean: 0.0 - std: 1.02 + std: 0.02 num_layers: ${model_raw.config.n_layer} model_raw: From 5019bbb869621b09d492cb1deeca1ed6473662b5 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 19 Sep 2025 12:45:53 +0200 Subject: [PATCH 39/60] fix: Remove unused third dimension for reduced_losses --- src/modalities/trainer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index c60dd8542..536f98409 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -282,9 +282,7 @@ def train( operation=dist.ReduceOp.SUM, # 1.) summed batch loss / (num batches * world size) # 2.) last batch loss / world size - post_processing_fun=lambda t: torch.stack( - [t[0] / t[-1], t[1] / self.num_data_parallel_ranks, t[-1]] - ), + post_processing_fun=lambda t: torch.stack([t[0] / t[-1], t[1] / self.num_data_parallel_ranks]), ) train_loss_avg, train_loss_last_batch = ( From a08e555af1db303efc45c0f33144005b39e64e64 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 19 Sep 2025 12:46:36 +0200 Subject: [PATCH 40/60] refactor: Remove unused filtering --- src/modalities/models/model_factory.py | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py index d463161f2..a96a9bc2d 100644 --- a/src/modalities/models/model_factory.py +++ b/src/modalities/models/model_factory.py @@ -706,21 +706,10 @@ def get_gpt2_tensor_parallelized_model(model: GPT2LLM, device_mesh: DeviceMesh) ) transformer_block.attn.n_head_q = transformer_block.attn.n_head_q // tp_mesh.size() transformer_block.attn.n_head_kv = transformer_block.attn.n_head_kv // tp_mesh.size() - # only keep the relevant parts of the model parallel plan - transformer_block_tp_plan = { - k: v - for k, v in transformer_block_tp_plan.items() - if ( - hasattr(transformer_block, k) - or hasattr(transformer_block.attn, k.split(".")[1]) - or hasattr(transformer_block.mlp, k.split(".")[1]) - ) - } - if transformer_block_tp_plan: - parallelize_module( - module=transformer_block, - device_mesh=tp_mesh, - parallelize_plan=transformer_block_tp_plan, - ) + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=transformer_block_tp_plan, + ) return model From 45b54188a97c0fd7c0ffcca84e41c0176cec51a0 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Mon, 22 Sep 2025 14:44:39 +0200 Subject: [PATCH 41/60] fix: Aggregate loss of last train batch correct across pp ranks --- src/modalities/main.py | 8 ++++---- .../running_env/fsdp/device_mesh.py | 19 +++++-------------- src/modalities/trainer.py | 10 ++++++---- tests/conftest.py | 2 +- 4 files changed, 16 insertions(+), 23 deletions(-) diff --git a/src/modalities/main.py b/src/modalities/main.py index 9836a4d8e..2f680cf71 100644 --- a/src/modalities/main.py +++ b/src/modalities/main.py @@ -20,7 +20,7 @@ from modalities.logging_broker.subscriber import MessageSubscriberIF from modalities.registry.components import COMPONENTS from modalities.registry.registry import Registry -from modalities.running_env.fsdp.device_mesh import get_num_data_parallel_ranks +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_num_parallel_ranks from modalities.trainer import Trainer from modalities.util import get_synced_experiment_id_of_run, get_total_number_of_trainable_parameters, print_rank_0 @@ -118,9 +118,9 @@ def run(self, components: TrainingComponentsInstantiationModel): * components.settings.cuda_env.world_size ) if components.device_mesh is None: - num_data_parallel_ranks = 1 + num_pipeline_parallel_ranks = 1 else: - num_data_parallel_ranks = get_num_data_parallel_ranks(components.device_mesh) + num_pipeline_parallel_ranks = get_num_parallel_ranks(components.device_mesh, ParallelismDegrees.PP) trainer = Trainer( global_rank=components.settings.cuda_env.global_rank, progress_publisher=progress_publisher, @@ -133,7 +133,7 @@ def run(self, components: TrainingComponentsInstantiationModel): gradient_clipper=components.gradient_clipper, global_num_tokens_per_train_step=global_num_tokens_per_train_step, mfu_calculator=components.mfu_calculator, - num_data_parallel_ranks=num_data_parallel_ranks, + num_pipeline_parallel_ranks=num_pipeline_parallel_ranks, ) # Evaluator diff --git a/src/modalities/running_env/fsdp/device_mesh.py b/src/modalities/running_env/fsdp/device_mesh.py index c74751362..4770951f0 100644 --- a/src/modalities/running_env/fsdp/device_mesh.py +++ b/src/modalities/running_env/fsdp/device_mesh.py @@ -129,23 +129,14 @@ def get_device_mesh( return device_mesh -def get_num_data_parallel_ranks(device_mesh: DeviceMesh) -> int: - """Gets the number of data parallel ranks from the device mesh. +def get_num_parallel_ranks(device_mesh: DeviceMesh, parallelism_method: ParallelismDegrees) -> int: + """Gets the number of parallel ranks from the device mesh for a specific parallelism method. Args: device_mesh (DeviceMesh): The device mesh. + parallelism_method (ParallelismDegrees): The parallelism method. Returns: - int: The number of data parallel ranks. + int: The number of parallel ranks for the specified parallelism method. """ - world_size = device_mesh.size() - dp_size = world_size - for parallelism_degree in ( - ParallelismDegrees.TP.value, - ParallelismDegrees.PP.value, - ParallelismDegrees.CP.value, - ): - if parallelism_degree in device_mesh.mesh_dim_names: - dp_size //= device_mesh.size(device_mesh.mesh_dim_names.index(parallelism_degree)) - - return dp_size + return device_mesh.size(device_mesh.mesh_dim_names.index(parallelism_method.value)) diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index 536f98409..979e245dc 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -31,7 +31,7 @@ class Trainer: def __init__( self, global_rank: int, - num_data_parallel_ranks: int, + num_pipeline_parallel_ranks: int, progress_publisher: MessagePublisher[ProgressUpdate], evaluation_result_publisher: MessagePublisher[EvaluationResultBatch], gradient_acc_steps: int, @@ -64,7 +64,7 @@ def __init__( None """ self.global_rank = global_rank - self.num_data_parallel_ranks = num_data_parallel_ranks + self.num_pipeline_parallel_ranks = num_pipeline_parallel_ranks self.progress_publisher = progress_publisher self.evaluation_result_publisher = evaluation_result_publisher self.gradient_acc_steps = gradient_acc_steps @@ -281,8 +281,10 @@ def train( tensor=cumulated_losses, operation=dist.ReduceOp.SUM, # 1.) summed batch loss / (num batches * world size) - # 2.) last batch loss / world size - post_processing_fun=lambda t: torch.stack([t[0] / t[-1], t[1] / self.num_data_parallel_ranks]), + # 2.) last batch loss / (world size / num_pipeline_parallel_ranks) + post_processing_fun=lambda t: torch.stack( + [t[0] / t[-1], t[1] / dist.get_world_size() * self.num_pipeline_parallel_ranks] + ), ) train_loss_avg, train_loss_last_batch = ( diff --git a/tests/conftest.py b/tests/conftest.py index 9bf289122..9bcc5f1d6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -196,7 +196,7 @@ def trainer(progress_publisher_mock, gradient_clipper_mock): global_num_seen_tokens=0, num_target_tokens=100, num_target_steps=10, - num_data_parallel_ranks=int(os.getenv("WORLD_SIZE")), + num_pipeline_parallel_ranks=1, ) From a394ab0c6d43568d0f4a27baef8de2632cdc2d3e Mon Sep 17 00:00:00 2001 From: rrutmann Date: Mon, 22 Sep 2025 14:53:16 +0200 Subject: [PATCH 42/60] docs: Add example config for pipeline and tensor parallelism --- .../config_lorem_ipsum_long_fsdp2_pp_tp.yaml | 422 ++++++++++++++++++ 1 file changed, 422 insertions(+) create mode 100644 config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml diff --git a/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml b/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml new file mode 100644 index 000000000..f7b4835f6 --- /dev/null +++ b/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml @@ -0,0 +1,422 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: 8 + paths: + checkpoint_saving_path: data/checkpoints + train_dataset_path: /raid/s3/opengptx/user/richard-rutmann/data/modalities/gpt2_tokenized/000_00000.pbin + test_dataset_path: ./data/lorem_ipsum.pbin + intervals: + training_log_interval_in_steps: 2 + checkpointing_interval_in_steps: 100000 + evaluation_interval_in_steps: 15 + consistency_enforcement: + enforce_tokens_per_step_consistency: true + enforce_last_step_logged: false + enforce_last_step_evaluated: false + enforce_last_step_checkpointed: false + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 4 + sequence_length: 16 + training_target: + num_target_tokens: + component_key: number_conversion + variant_key: num_tokens_from_num_steps + config: + sequence_length: ${settings.step_profile.sequence_length} + num_ranks: ${settings.cuda_env.world_size} + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + num_steps: ${settings.training_target.num_target_steps} + num_target_steps: 20 + training_progress: + global_num_seen_tokens: 0 + num_seen_steps: 0 + num_seen_samples: 0 + last_step: -1 + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: train + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: resumable_distributed_multi_dim_sampler + config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + data_parallel_key: dp_shard + shuffle: true + seed: 42 + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +test_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.test_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +test_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: test + dataset: + instance_key: test_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: resumable_distributed_multi_dim_sampler + config: + dataset: + instance_key: test_dataset + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + data_parallel_key: dp_shard + shuffle: true + seed: 42 + drop_last: true + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: test_dataloader + pass_type: BY_REFERENCE + +checkpoint_saving: + component_key: checkpoint_saving + variant_key: default + config: + checkpoint_saving_strategy: + component_key: checkpoint_saving_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: dcp + config: + checkpoint_path: ${settings.paths.checkpoint_saving_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + pipeline_parallel_degree: 2 + tensor_parallel_degree: 2 + data_parallel_shard_degree: -1 + world_size: ${settings.cuda_env.world_size} + +app_state: + component_key: app_state + variant_key: raw + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + lr_scheduler: + instance_key: lr_scheduler + pass_type: BY_REFERENCE + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: scheduled_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + +scheduled_pipeline: + component_key: pipeline + variant_key: scheduled + config: + loss_fn: + instance_key: loss_fn + pass_type: BY_REFERENCE + pp_schedule_name: gpipe + batch_size: ${settings.step_profile.local_train_micro_batch_size} + microbatch_size: 2 + pp_degree: ${device_mesh.config.pipeline_parallel_degree} + # maybe better to use the fsdp model and the schedule here + # instead of passing in the staged pipeline? + # If fsdp_model creates a copy then this is not in the scope of + # the staged pipeline. + pipeline: + component_key: pipeline + variant_key: builder + config: + pp_stage: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: PP_STAGE + model_part: + instance_key: fsdp_model + pass_type: BY_REFERENCE + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: gpt2_tp_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +gpt2_tp_model: + component_key: model + variant_key: gpt2_tp + config: + model: + instance_key: model_part + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +model_part: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + +staged_pipeline: + component_key: pipeline + variant_key: staged + config: + whole_model: + instance_key: model_raw + pass_type: BY_REFERENCE + stages_generator: + component_key: stages_generator + variant_key: gpt2_stages_generator + config: + num_model_layers: ${model_raw.config.n_layer} + input_layer_equivalence: 1 + output_layer_equivalence: 1 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + local_rank: ${settings.cuda_env.local_rank} + pp_schedule_name: gpipe + num_layers_per_stage: 2 + +model_raw: + component_key: model + variant_key: gpt2 + config: + seed: 42 + use_meta_device: true + use_weight_tying: false + sample_key: ${settings.referencing_keys.sample_key} + poe_type: NOPE + sequence_length: ${settings.step_profile.sequence_length} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 2 + n_head_q: 8 + n_head_kv: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} #it has to be head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: manual + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + +lr_scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: ${settings.training_target.num_target_steps} + pct_start: 0.01 + anneal_strategy: cos + last_epoch: ${settings.training_progress.last_step} + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + weight_decay_groups_excluded: [embedding, layernorm] + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + +gradient_clipper: + component_key: gradient_clipper + variant_key: fsdp2 + config: + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + max_norm: 1.0 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + global_rank: ${settings.cuda_env.global_rank} + num_seen_steps: ${settings.training_progress.num_seen_steps} + num_target_steps: ${settings.training_target.num_target_steps} + train_dataloader_tag: ${train_dataloader.config.dataloader_tag} + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + global_rank: ${settings.cuda_env.global_rank} + project: modalities_dcp_tests + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: wandb_storage + config_file_path: ${settings.config_file_path} + +# mfu_calculator: +# component_key: mfu_calculator +# variant_key: gpt2 +# config: +# n_layer: ${model_raw.config.n_layer} +# sequence_length: ${settings.step_profile.sequence_length} +# n_embd: ${model_raw.config.n_embd} +# world_size: ${settings.cuda_env.world_size} +# raw_model: +# instance_key: model_raw +# pass_type: BY_REFERENCE +# wrapped_model: +# instance_key: initialized_model +# pass_type: BY_REFERENCE \ No newline at end of file From cae050ec1d37d8d1ddbb70b39cfd6880487fa084 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Mon, 22 Sep 2025 15:06:51 +0200 Subject: [PATCH 43/60] docs: Add docstrings and type hints --- src/modalities/models/gpt2/gpt2_model.py | 11 ++++++++++- src/modalities/models/model_factory.py | 9 +++++++++ src/modalities/running_env/fsdp/device_mesh.py | 3 ++- src/modalities/trainer.py | 18 ++++++++++++++++++ 4 files changed, 39 insertions(+), 2 deletions(-) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index 168c6ae26..a1ae9fd2e 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -806,7 +806,7 @@ def __init__( ffn_norm_config (LayerNormWrapperConfig): Config for the feed-forward network normalization module. lm_head_norm_config (LayerNormWrapperConfig): Config for the language model head normalization module. use_weight_tying (bool): Whether to use weight tying. - seed (int, optional): The random seed. Defaults to None. + seed (Optional[int]): The random seed. Defaults to None. """ weight_decay_groups = { "linear": [".attn", ".mlp", ".lm_head.weight"], @@ -910,6 +910,15 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: ... def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, torch.Tensor] | torch.Tensor: + """ + Forward pass of the GPT2LLM module. + + Args: + inputs (dict[str, torch.Tensor] | torch.Tensor): Input data. + + Returns: + dict[str, torch.Tensor] | torch.Tensor: Model output. + """ if isinstance(inputs, dict): return {self.prediction_key: self.forward_impl(inputs[self.sample_key])} else: diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py index a96a9bc2d..d889c213a 100644 --- a/src/modalities/models/model_factory.py +++ b/src/modalities/models/model_factory.py @@ -58,6 +58,15 @@ class ModelFactory: @staticmethod def _is_model_on_meta_device(model: nn.Module) -> bool: + """ + Checks if all parameters and buffers of the model are on the meta device. + + Args: + model (nn.Module): The model to check. + + Returns: + bool: True if all parameters and buffers are on meta device, False otherwise. + """ meta_counter = 0 param_counter = 0 for _, tensor in itertools.chain(model.named_parameters(), model.named_buffers()): diff --git a/src/modalities/running_env/fsdp/device_mesh.py b/src/modalities/running_env/fsdp/device_mesh.py index 4770951f0..1e70c5323 100644 --- a/src/modalities/running_env/fsdp/device_mesh.py +++ b/src/modalities/running_env/fsdp/device_mesh.py @@ -84,7 +84,8 @@ def get_device_mesh( enable_loss_parallel: bool, world_size: int, ) -> DeviceMesh: - """Gets the device mesh for the specified parallelism degrees. + """ + Gets the device mesh for the specified parallelism degrees. Args: device_type (str): The device type. diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index 979e245dc..4b79dbedc 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -28,6 +28,24 @@ class ThroughputAggregationKeys(Enum): class Trainer: + """ + Trainer class for model training. + + Args: + global_rank (int): The global rank. + num_pipeline_parallel_ranks (int): Number of pipeline parallel ranks. + progress_publisher (MessagePublisher[ProgressUpdate]): Progress publisher. + evaluation_result_publisher (MessagePublisher[EvaluationResultBatch]): Evaluation result publisher. + gradient_acc_steps (int): Gradient accumulation steps. + global_num_tokens_per_train_step (int): Global number of tokens per train step. + num_seen_train_steps (int): Number of seen train steps. + global_num_seen_tokens (int): Global number of seen tokens. + num_target_steps (int): Number of target steps. + num_target_tokens (int): Number of target tokens. + gradient_clipper (GradientClipperIF): Gradient clipper. + mfu_calculator (Optional[MFUCalculatorABC]): MFU calculator. + """ + def __init__( self, global_rank: int, From 695223059f15bbcaefc682e10f40322046279422 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Mon, 22 Sep 2025 15:09:25 +0200 Subject: [PATCH 44/60] docs: Add type hints and docstrings --- src/modalities/trainer.py | 42 +++++++++++---------------------------- 1 file changed, 12 insertions(+), 30 deletions(-) diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index 4b79dbedc..bff3c7b47 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -28,24 +28,6 @@ class ThroughputAggregationKeys(Enum): class Trainer: - """ - Trainer class for model training. - - Args: - global_rank (int): The global rank. - num_pipeline_parallel_ranks (int): Number of pipeline parallel ranks. - progress_publisher (MessagePublisher[ProgressUpdate]): Progress publisher. - evaluation_result_publisher (MessagePublisher[EvaluationResultBatch]): Evaluation result publisher. - gradient_acc_steps (int): Gradient accumulation steps. - global_num_tokens_per_train_step (int): Global number of tokens per train step. - num_seen_train_steps (int): Number of seen train steps. - global_num_seen_tokens (int): Global number of seen tokens. - num_target_steps (int): Number of target steps. - num_target_tokens (int): Number of target tokens. - gradient_clipper (GradientClipperIF): Gradient clipper. - mfu_calculator (Optional[MFUCalculatorABC]): MFU calculator. - """ - def __init__( self, global_rank: int, @@ -65,18 +47,18 @@ def __init__( Initializes the Trainer object. Args: - global_rank (int): The global rank to which operates the trainer object. - progress_publisher (MessagePublisher[ProgressUpdate]): The publisher for progress updates. - evaluation_result_publisher (MessagePublisher[EvaluationResultBatch]): - The publisher for evaluation result batches. - gradient_acc_steps (int): The number of gradient accumulation steps. - global_num_tokens_per_train_step (int): The number of global tokens per training step. - num_seen_train_steps (int): The number of training steps already seen. - global_num_seen_tokens (int): The number of tokens already seen. - num_target_steps (int): The target number of training steps. - num_target_tokens (int): The target number of tokens. - gradient_clipper (GradientClipperIF): The gradient clipper. - mfu_calculator (Optional[MFUCalculatorABC]): The MFU calculator. + global_rank (int): The global rank. + num_pipeline_parallel_ranks (int): Number of pipeline parallel ranks. + progress_publisher (MessagePublisher[ProgressUpdate]): Progress publisher. + evaluation_result_publisher (MessagePublisher[EvaluationResultBatch]): Evaluation result publisher. + gradient_acc_steps (int): Gradient accumulation steps. + global_num_tokens_per_train_step (int): Global number of tokens per train step. + num_seen_train_steps (int): Number of seen train steps. + global_num_seen_tokens (int): Global number of seen tokens. + num_target_steps (int): Number of target steps. + num_target_tokens (int): Number of target tokens. + gradient_clipper (GradientClipperIF): Gradient clipper. + mfu_calculator (Optional[MFUCalculatorABC]): MFU calculator. Returns: None From ffa032cd82de6bdb1df4916c1d43baf9bca8b47d Mon Sep 17 00:00:00 2001 From: rrutmann Date: Mon, 22 Sep 2025 17:55:24 +0200 Subject: [PATCH 45/60] fix: Check if parallelism method is initialized --- src/modalities/running_env/fsdp/device_mesh.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/modalities/running_env/fsdp/device_mesh.py b/src/modalities/running_env/fsdp/device_mesh.py index 1e70c5323..5afa7d872 100644 --- a/src/modalities/running_env/fsdp/device_mesh.py +++ b/src/modalities/running_env/fsdp/device_mesh.py @@ -140,4 +140,7 @@ def get_num_parallel_ranks(device_mesh: DeviceMesh, parallelism_method: Parallel Returns: int: The number of parallel ranks for the specified parallelism method. """ - return device_mesh.size(device_mesh.mesh_dim_names.index(parallelism_method.value)) + if parallelism_method.value not in device_mesh.mesh_dim_names: + return 1 + else: + return device_mesh.size(device_mesh.mesh_dim_names.index(parallelism_method.value)) From 8d418a1b8ce90a4747a462a89c2bfd3b36e1a3c2 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Mon, 22 Sep 2025 17:55:44 +0200 Subject: [PATCH 46/60] docs: Add new parameter in docstring --- .../training/gradient_clipping/fsdp_gradient_clipper.py | 2 ++ .../gradient_clipping/fsdp_gradient_clipper_config.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py b/src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py index d4b280a32..faeef0035 100644 --- a/src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py +++ b/src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py @@ -109,6 +109,7 @@ def __init__( wrapped_model (FSDP2): The wrapped model. max_norm (float): The maximum norm value for gradient clipping. norm_type (GradientClippingMode, optional): The type of gradient clipping. Defaults to GradientClippingMode. + device_mesh (DeviceMesh, optional): The device mesh used for distributed training. Defaults to None. Returns: None @@ -215,6 +216,7 @@ def __init__( Args: wrapped_model (FSDP2): The wrapped FSDP2 model. norm_type (GradientClippingMode, optional): The type of gradient clipping. Defaults to GradientClippingMode. + device_mesh (DeviceMesh, optional): The device mesh used for distributed training. Defaults to None. Returns: None diff --git a/src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py b/src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py index 80ebee2a8..310fb9b60 100644 --- a/src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py +++ b/src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py @@ -34,11 +34,13 @@ class FSDP2GradientClipperConfig(FSDP1GradientClipperConfig): max_norm (float): The maximum norm value for gradient clipping. norm_type (GradientClippingMode): The type of gradient clipping to be applied. wrapped_model (PydanticPytorchModuleType): The wrapped PyTorch model. + device_mesh (PydanticDeviceMeshIFType | None): The device mesh configuration. Attributes: max_norm (float): The maximum norm value for gradient clipping. norm_type (GradientClippingMode): The type of gradient clipping to be applied. wrapped_model (PydanticPytorchModuleType): The wrapped PyTorch model. + device_mesh (PydanticDeviceMeshIFType | None): The device mesh configuration. """ device_mesh: PydanticDeviceMeshIFType | None = None @@ -68,10 +70,12 @@ class FSDP2DummyGradientClipperConfig(FSDP1DummyGradientClipperConfig): Args: wrapped_model (PydanticPytorchModuleType): The wrapped PyTorch model. norm_type (GradientClippingMode): The type of gradient clipping to be applied. + device_mesh (PydanticDeviceMeshIFType | None): The device mesh configuration. Attributes: wrapped_model (PydanticPytorchModuleType): The wrapped PyTorch model. norm_type (GradientClippingMode): The type of gradient clipping to be applied. + device_mesh (PydanticDeviceMeshIFType | None): The device mesh configuration. """ wrapped_model: PydanticPytorchModuleType From fffd0a1fc210fa3e2a1f31b7773c14c73d710303 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 23 Sep 2025 11:40:38 +0200 Subject: [PATCH 47/60] test: Run only one PP only test --- .../pipeline_parallelism/test_pp_fwd_bwd_pass.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py index dabd731b9..d255d62e0 100644 --- a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py +++ b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py @@ -72,7 +72,6 @@ def _get_components( "sharding_degree, tp_degree, pp_degree, world_size", [ (2, 1, 2, 4), - (2, 1, 4, 8), (2, 2, 2, 8), ], ) From 049472f9bf68733b1b922ac6f3a36cf0ba68efe7 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Wed, 24 Sep 2025 13:07:20 +0200 Subject: [PATCH 48/60] refactor: Addressed copilot review --- src/modalities/running_env/fsdp/device_mesh.py | 7 +++++++ src/modalities/trainer.py | 1 - .../gradient_clipping/fsdp_gradient_clipper.py | 14 ++++---------- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/modalities/running_env/fsdp/device_mesh.py b/src/modalities/running_env/fsdp/device_mesh.py index 5afa7d872..add2cc1da 100644 --- a/src/modalities/running_env/fsdp/device_mesh.py +++ b/src/modalities/running_env/fsdp/device_mesh.py @@ -144,3 +144,10 @@ def get_num_parallel_ranks(device_mesh: DeviceMesh, parallelism_method: Parallel return 1 else: return device_mesh.size(device_mesh.mesh_dim_names.index(parallelism_method.value)) + + +def get_mesh_for_parallelism_method(device_mesh: DeviceMesh | None, parallelism_method: ParallelismDegrees): + if device_mesh is not None and parallelism_method.value in device_mesh.mesh_dim_names: + return device_mesh[parallelism_method.value] + else: + return None diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index bff3c7b47..f81407f02 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -125,7 +125,6 @@ def _train_batch( """ if scheduled_pipeline is not None: pp_schedule = scheduled_pipeline.pp_schedule - # TODO: handle loss and backward in pp # Pipeline Parallel forward / backward inside step() call # with self.train_context(optional_context_parallel_ctx): targets, losses = ( diff --git a/src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py b/src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py index faeef0035..c4009cf41 100644 --- a/src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py +++ b/src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py @@ -9,7 +9,7 @@ from torch.distributed.tensor import DTensor from modalities.config.lookup_enum import LookupEnum -from modalities.running_env.fsdp.device_mesh import ParallelismDegrees +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_mesh_for_parallelism_method from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF @@ -186,11 +186,7 @@ def clip_grad_norm_( total_norm = total_norm.full_tensor() - pp_mesh = ( - device_mesh[ParallelismDegrees.PP.value] - if device_mesh is not None and ParallelismDegrees.PP.value in device_mesh.mesh_dim_names - else None - ) + pp_mesh = get_mesh_for_parallelism_method(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP) if pp_mesh is not None: if math.isinf(norm_type): dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group()) @@ -242,10 +238,8 @@ def clip_gradients(self) -> torch.Tensor: # If only using PP, total_norm will be a local tensor. total_norm = total_norm.full_tensor() - pp_mesh = ( - self.device_mesh[ParallelismDegrees.PP.value] - if self.device_mesh is not None and ParallelismDegrees.PP.value in self.device_mesh.mesh_dim_names - else None + pp_mesh = get_mesh_for_parallelism_method( + device_mesh=self.device_mesh, parallelism_method=ParallelismDegrees.PP ) if pp_mesh is not None: if math.isinf(self.norm_type.value): From 608c7fc9b6d7d4300a62d4548cbbaf1b3059c4ab Mon Sep 17 00:00:00 2001 From: rrutmann Date: Wed, 15 Oct 2025 17:11:41 +0200 Subject: [PATCH 49/60] chore: Remove requirements for python and torch --- pyproject.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5a3c84bf1..1396b0cf5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,12 +1,10 @@ [project] name = "modalities" version = "0.3.2" -requires-python = ">=3.10,<3.12" description = "Modalities, a PyTorch-native framework for distributed and reproducible foundation model training." readme = "README.md" dependencies = [ "numpy<2.0", - "torch==2.6.0", "packaging", "tqdm", "pyyaml", From 16c4bc479309baaa7da1eece431df6c5d0ccaf6a Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 17 Oct 2025 14:15:58 +0200 Subject: [PATCH 50/60] fix: Allow dp shard degree 1 --- src/modalities/running_env/fsdp/device_mesh.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modalities/running_env/fsdp/device_mesh.py b/src/modalities/running_env/fsdp/device_mesh.py index add2cc1da..e9f1f3e95 100644 --- a/src/modalities/running_env/fsdp/device_mesh.py +++ b/src/modalities/running_env/fsdp/device_mesh.py @@ -119,7 +119,7 @@ def get_device_mesh( ], strict=True, ): - if dim > 1: + if dim > 1 or name == ParallelismDegrees.DP_SHARD.value: dims.append(dim) names.append(name) names = tuple(names) From f5a10205b6a1568a9748f8931211f1cc425e2398 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 17 Oct 2025 14:17:34 +0200 Subject: [PATCH 51/60] test: Add test for checkpointing with pipeline parallelism --- .../checkpointing/checkpointing_test_utils.py | 36 +++- tests/checkpointing/fsdp2_pp_gpt2_config.yaml | 194 ++++++++++++++++++ ...fsdp2_dcp_checkpoint_loading_and_saving.py | 164 ++++++++++----- 3 files changed, 343 insertions(+), 51 deletions(-) create mode 100644 tests/checkpointing/fsdp2_pp_gpt2_config.yaml diff --git a/tests/checkpointing/checkpointing_test_utils.py b/tests/checkpointing/checkpointing_test_utils.py index 21c4caabe..7a3f241bf 100644 --- a/tests/checkpointing/checkpointing_test_utils.py +++ b/tests/checkpointing/checkpointing_test_utils.py @@ -15,10 +15,17 @@ class CheckpointingTestUtils: @staticmethod def generate_batch(gpt2_model_config: dict): # prepare input and targets + if "settings" in gpt2_model_config: + batch_size = gpt2_model_config["settings"]["step_profile"]["local_train_micro_batch_size"] + else: + batch_size = 8 data = torch.randint( 0, # lowest token_id gpt2_model_config["model_raw"]["config"]["vocab_size"], # highest token_id + 1, i.e. vocab_size - (8, gpt2_model_config["model_raw"]["config"]["sequence_length"] + 1), # (batch_size, sequence_length + 1) + ( + batch_size, + gpt2_model_config["model_raw"]["config"]["sequence_length"] + 1, + ), # (batch_size, sequence_length + 1) ).cuda() batch_input_ids_dict = {gpt2_model_config["model_raw"]["config"]["sample_key"]: data[:, :-1]} batch_target_ids = data[:, 1:] @@ -49,6 +56,33 @@ def forward_backward_pass( optimizer.step() return loss + @staticmethod + def forward_backward_pp_pass( + scheduled_pipeline, + optimizer: Optimizer, + batch_input_ids_dict: dict, + batch_target_ids: torch.Tensor, + ): + pp_schedule = scheduled_pipeline.pp_schedule + # Pipeline Parallel forward / backward inside step() call + # with self.train_context(optional_context_parallel_ctx): + targets, losses = (batch_target_ids.contiguous(), []) if scheduled_pipeline.is_last_pp_stage else (None, None) + + if scheduled_pipeline.is_first_pp_stage: + pp_schedule.step( + batch_input_ids_dict[scheduled_pipeline.model_part.sample_key].contiguous(), + target=targets, + losses=losses, + ) + else: + pp_schedule.step(target=targets, losses=losses) + loss = torch.mean(torch.stack(losses)).to(losses[0].device) if scheduled_pipeline.is_last_pp_stage else None + optimizer.step() + # clear the gradients + optimizer.zero_grad() + + return loss + @staticmethod def get_gpt2_model_from_config(gpt2_model_config_dict: dict) -> GPT2LLM: class GPT2InstantationModel(BaseModel): diff --git a/tests/checkpointing/fsdp2_pp_gpt2_config.yaml b/tests/checkpointing/fsdp2_pp_gpt2_config.yaml new file mode 100644 index 000000000..4a02aa6b2 --- /dev/null +++ b/tests/checkpointing/fsdp2_pp_gpt2_config.yaml @@ -0,0 +1,194 @@ +settings: + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + step_profile: + local_train_micro_batch_size: 8 + +app_state: + component_key: app_state + variant_key: raw + config: + model: + instance_key: fsdp_model + pass_type: BY_REFERENCE + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + +scheduled_pipeline: + component_key: pipeline + variant_key: scheduled + config: + loss_fn: + instance_key: loss_fn + pass_type: BY_REFERENCE + pp_schedule_name: gpipe + batch_size: ${settings.step_profile.local_train_micro_batch_size} + microbatch_size: 4 + pp_degree: ${device_mesh.config.pipeline_parallel_degree} + # maybe better to use the fsdp model and the schedule here + # instead of passing in the staged pipeline? + # If fsdp_model creates a copy then this is not in the scope of + # the staged pipeline. + pipeline: + component_key: pipeline + variant_key: builder + config: + pp_stage: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: PP_STAGE + model_part: + instance_key: fsdp_model + pass_type: BY_REFERENCE + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: model_part + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +model_part: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + +staged_pipeline: + component_key: pipeline + variant_key: staged + config: + whole_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + stages_generator: + component_key: stages_generator + variant_key: gpt2_stages_generator + config: + num_model_layers: ${model_raw.config.n_layer} + input_layer_equivalence: 1 + output_layer_equivalence: 1 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + local_rank: ${cuda_env:LOCAL_RANK} + pp_schedule_name: gpipe + num_layers_per_stage: 2 + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: model_raw + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + +model_raw: + component_key: model + variant_key: gpt2 + config: + sample_key: "input_ids" # TODO reference this + poe_type: NOPE + prediction_key: "logits" # TODO reference this + sequence_length: 256 # TODO reference this (same as sequence length) + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 2 + n_head_q: 4 + n_head_kv: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} #it has to be head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: manual + activation_type: gelu + attention_norm_config: + norm_type: rms_norm + config: + ndim: ${model_raw.config.n_embd} + bias: true + epsilon: 1e-5 + ffn_norm_config: + norm_type: rms_norm + config: + ndim: ${model_raw.config.n_embd} + bias: true + epsilon: 1e-5 + lm_head_norm_config: + norm_type: rms_norm + config: + ndim: ${model_raw.config.n_embd} + bias: true + epsilon: 1e-5 + use_weight_tying: false + use_meta_device: true + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0003 + betas: + - 0.9 + - 0.95 + eps: 1.0e-08 + weight_decay: 0.1 + weight_decay_groups_excluded: + - embedding + - layernorm + wrapped_model: + instance_key: fsdp_model + pass_type: BY_REFERENCE + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + pipeline_parallel_degree: 2 + data_parallel_shard_degree: -1 + world_size: ${cuda_env:WORLD_SIZE} diff --git a/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py b/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py index e1f0f349c..dc82f19d5 100644 --- a/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py +++ b/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py @@ -1,6 +1,7 @@ import json import os import tempfile +import traceback from copy import deepcopy from pathlib import Path @@ -16,7 +17,8 @@ from modalities.checkpointing.fsdp.fsdp_checkpoint_saving import DCPCheckpointSaving from modalities.checkpointing.stateful.app_state import AppState from modalities.config.config import ProcessGroupBackendType, load_app_config_dict -from modalities.config.pydantic_if_types import PydanticAppStateType +from modalities.config.pydantic_if_types import PydanticAppStateType, PydanticPipelineType +from modalities.models.parallelism.pipeline_parallelism import Pipeline from modalities.training.training_progress import TrainingProgress from tests.checkpointing.checkpointing_test_utils import CheckpointingTestUtils from tests.end2end_tests.custom_components import MultiProcessingCudaEnv @@ -41,8 +43,8 @@ def get_gpt2_model_config_dict(gpt2_model_config_path: Path) -> dict: @pytest.mark.skipif( - torch.cuda.device_count() < 2, - reason="This e2e test requires 2 GPUs", + torch.cuda.device_count() < 4, + reason="This e2e test requires 4 GPUs", ) class TestFSDP2DCPCheckpointing: @staticmethod @@ -57,11 +59,32 @@ class ComponentsInstantiationModel(BaseModel): return components.app_state @staticmethod - def test_save_checkpoint_after_backward_pass(temporary_checkpoint_folder_path: Path, gpt2_model_config_path: Path): - world_size = 2 + def _get_scheduled_pipeline(config_file_path: Path) -> Pipeline: + class ComponentsInstantiationModel(BaseModel): + scheduled_pipeline: PydanticPipelineType + + main_obj = Main(config_file_path) + components: ComponentsInstantiationModel = main_obj.build_components( + components_model_type=ComponentsInstantiationModel + ) + return components.scheduled_pipeline + + @staticmethod + @pytest.mark.parametrize( + "config_filename,world_size,use_pp", + [ + ("fsdp2_gpt2_config.yaml", 2, False), + ("fsdp2_pp_gpt2_config.yaml", 2, True), + ], + ) + def test_save_checkpoint_after_backward_pass( + temporary_checkpoint_folder_path: Path, config_filename: str, world_size: int, use_pp: bool + ): + working_dir = Path(os.path.dirname(__file__)) + config_file_path = working_dir / config_filename mp.spawn( TestFSDP2DCPCheckpointing._test_save_checkpoint_after_backward_pass_impl_wrapper, - args=(world_size, temporary_checkpoint_folder_path, gpt2_model_config_path), + args=(world_size, temporary_checkpoint_folder_path, config_file_path, use_pp), nprocs=world_size, join=True, ) @@ -72,6 +95,7 @@ def _test_save_checkpoint_after_backward_pass_impl_wrapper( world_size: int, temporary_checkpoint_folder_path: Path, gpt2_model_config_path: Path, + use_pp: bool, ): # wraps the actual test function to be able to run it in a distributed multiprocessing setup with MultiProcessingCudaEnv( @@ -79,31 +103,44 @@ def _test_save_checkpoint_after_backward_pass_impl_wrapper( global_rank=process_id, local_rank=process_id, world_size=world_size, - rdvz_port=22356, + rdvz_port=22358, ): - # build all the components for the test - app_state1 = TestFSDP2DCPCheckpointing._get_app_state(config_file_path=gpt2_model_config_path) - app_state2 = TestFSDP2DCPCheckpointing._get_app_state(config_file_path=gpt2_model_config_path) - - gpt2_model_config_dict = get_gpt2_model_config_dict(gpt2_model_config_path=gpt2_model_config_path) - experiment_id = "0" - checkpoint_loading = DCPCheckpointLoading(global_rank=process_id) - checkpoint_saving = DCPCheckpointSaving( - checkpoint_path=temporary_checkpoint_folder_path, - experiment_id=experiment_id, - global_rank=process_id, - ) + try: + # build all the components for the test + app_state1 = TestFSDP2DCPCheckpointing._get_app_state(config_file_path=gpt2_model_config_path) + app_state2 = TestFSDP2DCPCheckpointing._get_app_state(config_file_path=gpt2_model_config_path) - # run the test - TestFSDP2DCPCheckpointing._test_save_checkpoint_after_backward_pass_impl( - app_state1=app_state1, - app_state2=app_state2, - gpt2_model_config_dict=gpt2_model_config_dict, - checkpoint_loading=checkpoint_loading, - checkpoint_saving=checkpoint_saving, - temporary_checkpoint_folder_path=temporary_checkpoint_folder_path, - experiment_id=experiment_id, - ) + if use_pp: + app_state1.scheduled_pipeline = TestFSDP2DCPCheckpointing._get_scheduled_pipeline( + config_file_path=gpt2_model_config_path + ) + app_state2.scheduled_pipeline = TestFSDP2DCPCheckpointing._get_scheduled_pipeline( + config_file_path=gpt2_model_config_path + ) + + gpt2_model_config_dict = get_gpt2_model_config_dict(gpt2_model_config_path=gpt2_model_config_path) + experiment_id = "0" + checkpoint_loading = DCPCheckpointLoading(global_rank=process_id) + checkpoint_saving = DCPCheckpointSaving( + checkpoint_path=temporary_checkpoint_folder_path, + experiment_id=experiment_id, + global_rank=process_id, + ) + + # run the test + TestFSDP2DCPCheckpointing._test_save_checkpoint_after_backward_pass_impl( + app_state1=app_state1, + app_state2=app_state2, + gpt2_model_config_dict=gpt2_model_config_dict, + checkpoint_loading=checkpoint_loading, + checkpoint_saving=checkpoint_saving, + temporary_checkpoint_folder_path=temporary_checkpoint_folder_path, + experiment_id=experiment_id, + ) + except Exception as e: + print(f"Exception in _forward_step_with_pp: {e}") + traceback.print_exc() # <-- Add this line to print the full stack trace + raise e @staticmethod def _test_save_checkpoint_after_backward_pass_impl( @@ -139,13 +176,21 @@ def _test_save_checkpoint_after_backward_pass_impl( # run backward pass batch_input_ids_dict, batch_target_ids = CheckpointingTestUtils.generate_batch(gpt2_model_config_dict) - loss_0 = CheckpointingTestUtils.forward_backward_pass( - prediction_key=prediction_key, - model=app_state1.model, - optimizer=app_state1.optimizer, - batch_input_ids_dict=batch_input_ids_dict, - batch_target_ids=batch_target_ids, - ) + if hasattr(app_state1, "scheduled_pipeline"): + loss_0 = CheckpointingTestUtils.forward_backward_pp_pass( + scheduled_pipeline=app_state1.scheduled_pipeline, + optimizer=app_state1.optimizer, + batch_input_ids_dict=batch_input_ids_dict, + batch_target_ids=batch_target_ids, + ) + else: + loss_0 = CheckpointingTestUtils.forward_backward_pass( + prediction_key=prediction_key, + model=app_state1.model, + optimizer=app_state1.optimizer, + batch_input_ids_dict=batch_input_ids_dict, + batch_target_ids=batch_target_ids, + ) # save the updated model and optimizer states for later comparisons updated_model_parameters = CheckpointingTestUtils.clone_parameters(app_state1.model) @@ -198,21 +243,40 @@ def _test_save_checkpoint_after_backward_pass_impl( loaded_and_updated_optimizer_state_dict = deepcopy(app_state1.optimizer.state_dict()) # perform another forward pass and backward pass for the previous and the loaded model - loss_1 = CheckpointingTestUtils.forward_backward_pass( - prediction_key=prediction_key, - model=app_state1.model, - optimizer=app_state1.optimizer, - batch_input_ids_dict=batch_input_ids_dict, - batch_target_ids=batch_target_ids, - ) + if hasattr(app_state1, "scheduled_pipeline"): + try: + # loss_1 = CheckpointingTestUtils.forward_backward_pp_pass( + # scheduled_pipeline=app_state1.scheduled_pipeline, + # optimizer=app_state1.optimizer, + # batch_input_ids_dict=batch_input_ids_dict, + # batch_target_ids=batch_target_ids, + # ) + loss_2 = CheckpointingTestUtils.forward_backward_pp_pass( + scheduled_pipeline=app_state2.scheduled_pipeline, + optimizer=app_state2.optimizer, + batch_input_ids_dict=batch_input_ids_dict, + batch_target_ids=batch_target_ids, + ) + except Exception as e: + print(f"Exception in _forward_step_with_pp: {e}") + traceback.print_exc() # <-- Add this line to print the full stack trace + raise e + else: + loss_1 = CheckpointingTestUtils.forward_backward_pass( + prediction_key=prediction_key, + model=app_state1.model, + optimizer=app_state1.optimizer, + batch_input_ids_dict=batch_input_ids_dict, + batch_target_ids=batch_target_ids, + ) - loss_2 = CheckpointingTestUtils.forward_backward_pass( - prediction_key=prediction_key, - model=app_state2.model, - optimizer=app_state2.optimizer, - batch_input_ids_dict=batch_input_ids_dict, - batch_target_ids=batch_target_ids, - ) + loss_2 = CheckpointingTestUtils.forward_backward_pass( + prediction_key=prediction_key, + model=app_state2.model, + optimizer=app_state2.optimizer, + batch_input_ids_dict=batch_input_ids_dict, + batch_target_ids=batch_target_ids, + ) assert loss_1 == loss_2, f"loss_1 = {loss_1} does not equal loss_2 = {loss_2}" assert loss_1 < loss_0, f"loss_1 = {loss_1} is not less than loss_0 = {loss_0}" From 9d1f107d587cc6d6709609c41c20afe2bd0a410e Mon Sep 17 00:00:00 2001 From: Timm Ruland Date: Fri, 17 Oct 2025 18:44:05 +0200 Subject: [PATCH 52/60] fix(parallelism): Building model stages in PP now also filters the model's weight_decay_groups. --- .../parallelism/pipeline_parallelism.py | 31 +++++++++++++++---- .../optimizers/optimizer_factory.py | 12 +++---- 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/src/modalities/models/parallelism/pipeline_parallelism.py b/src/modalities/models/parallelism/pipeline_parallelism.py index 006d97a55..9d7e97718 100644 --- a/src/modalities/models/parallelism/pipeline_parallelism.py +++ b/src/modalities/models/parallelism/pipeline_parallelism.py @@ -3,8 +3,9 @@ # licensed under the BSD 3-Clause License. import copy +import re from enum import Enum -from typing import Any, Optional, Type +from typing import Any, Optional, Type, cast import torch import torch.nn as nn @@ -13,6 +14,7 @@ from torch.distributed.pipelining.schedules import PipelineScheduleSingle, get_schedule_class from modalities.loss_functions import Loss +from modalities.models.model import NNModel from modalities.models.parallelism.stages_generator import StagesGenerator from modalities.running_env.fsdp.device_mesh import ParallelismDegrees from modalities.utils.logger_utils import get_logger @@ -83,13 +85,13 @@ class PipelineFactory: @staticmethod def get_pipeline( - pp_stage: PipelineStage, model_part: nn.Module, pp_schedule: Optional[PipelineScheduleSingle] = None + pp_stage: PipelineStage, model_part: NNModel, pp_schedule: Optional[PipelineScheduleSingle] = None ) -> Pipeline: return Pipeline(pp_stage=pp_stage, model_part=model_part, pp_schedule=pp_schedule) @staticmethod def get_staged_pipeline( - whole_model: nn.Module, + whole_model: NNModel, stages_generator: StagesGenerator, device_mesh: DeviceMesh, local_rank: int, @@ -128,12 +130,12 @@ def get_staged_pipeline( @staticmethod def _get_split_model( - whole_model: nn.Module, + whole_model: NNModel, schedule_class: Type[PipelineScheduleSingle], pp_mesh: DeviceMesh, device: torch.device, fqns_per_stage: list[list[str]], - ) -> tuple[PipelineStage, nn.Module]: + ) -> tuple[PipelineStage, NNModel]: def get_stage_id_of_pp_rank(pp_mesh: DeviceMesh): # NOTE: torch titan a more complicated way to get the stage id of pp rank # since they also allow for multi-stage schedules @@ -164,7 +166,7 @@ def _get_fqn_tree(fqns: list[str]) -> dict[str, Any]: def _build_stage_from_modules( fqn_tree: dict[str, Any], module: nn.Module, module_name: Optional[str] = None - ) -> tuple[PipelineStage, nn.Module]: + ) -> nn.Module: if isinstance(module, nn.ModuleDict): if module_name not in fqn_tree: dict_modules = nn.ModuleDict({}) @@ -239,6 +241,8 @@ def _build_stage_from_modules( whole_model = copy.deepcopy(whole_model) fqn_tree = _get_fqn_tree(module_names) stage_modules = _build_stage_from_modules(fqn_tree, whole_model) + stage_modules = cast(NNModel, stage_modules) + PipelineFactory._filter_weight_decay_groups_(stage_modules) stage = PipelineStage( submodule=stage_modules, stage_index=stage_idx, @@ -248,6 +252,21 @@ def _build_stage_from_modules( ) return stage, stage_modules + @staticmethod + def _filter_weight_decay_groups_(stage_modules: NNModel): + params = {name for name, parameter in stage_modules.named_parameters() if parameter.requires_grad} + for group_list in stage_modules.weight_decay_groups.values(): + remove_from_group = [ + group_entry + for group_entry in group_list + if all([not bool(re.search(group_entry, name)) for name in params]) + ] + for remove in remove_from_group: + group_list.remove(remove) + empty_group_keys = [k for k, v in stage_modules.weight_decay_groups.items() if len(v) == 0] + for key in empty_group_keys: + del stage_modules.weight_decay_groups[key] + @staticmethod def get_scheduled_pipeline( loss_fn: Loss, pp_schedule_name: str, batch_size: int, microbatch_size: int, pp_degree: int, pipeline: Pipeline diff --git a/src/modalities/optimizers/optimizer_factory.py b/src/modalities/optimizers/optimizer_factory.py index c430e82a1..5a0ae2bdc 100644 --- a/src/modalities/optimizers/optimizer_factory.py +++ b/src/modalities/optimizers/optimizer_factory.py @@ -12,6 +12,7 @@ from modalities.exceptions import OptimizerError from modalities.models.model import NNModel from modalities.util import get_local_number_of_trainable_parameters, print_rank_0 +from modalities.utils.logger_utils import get_logger from modalities.utils.typing_utils import FSDPX OptimizerGroups = list[dict[str, list[nn.Parameter] | float]] @@ -80,7 +81,7 @@ def get_optimizer_groups(model: FSDP, weight_decay: float, weight_decay_groups_e optimizer_groups_names = ["all"] else: # there will be N optimizer groups, i.e. one for each model parameter group - _assert_existence_of_weight_decay_groups_excluded(model, weight_decay_groups_excluded) + _check_existence_of_weight_decay_groups_excluded(model, weight_decay_groups_excluded) optimizer_groups, optimizer_groups_names = _create_optimizer_groups( model, weight_decay, weight_decay_groups_excluded ) @@ -90,9 +91,7 @@ def get_optimizer_groups(model: FSDP, weight_decay: float, weight_decay_groups_e return optimizer_groups -def _assert_existence_of_weight_decay_groups_excluded( - model: nn.Module, weight_decay_groups_excluded: list[str] -) -> None: +def _check_existence_of_weight_decay_groups_excluded(model: nn.Module, weight_decay_groups_excluded: list[str]) -> None: """ checks the existence of all groups that are to be excluded from weight decay @@ -113,9 +112,10 @@ def _assert_existence_of_weight_decay_groups_excluded( weight_decay_groups = nn_model.weight_decay_groups for group in weight_decay_groups_excluded: if group not in weight_decay_groups.keys(): - raise OptimizerError( + get_logger(name="optimizer_factory").warning( f"group = {group} specified in weight_decay_groups_excluded is not " - + f"in models optimizer_module_groups = {list(weight_decay_groups.keys())}" + + f"in models optimizer_module_groups = {list(weight_decay_groups.keys())}. " + + "(This might be due to pipeline parallelism and is not necessarily an error.)" ) From dfc1bdebbcac047daf9f74f7d521d219a6cb48cb Mon Sep 17 00:00:00 2001 From: Timm Ruland Date: Fri, 17 Oct 2025 18:46:02 +0200 Subject: [PATCH 53/60] test(checkpointing): Some fixes for pp checkpointing test. --- ...fsdp2_dcp_checkpoint_loading_and_saving.py | 56 ++++++++----------- 1 file changed, 24 insertions(+), 32 deletions(-) diff --git a/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py b/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py index dc82f19d5..77fcd3edb 100644 --- a/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py +++ b/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py @@ -18,7 +18,6 @@ from modalities.checkpointing.stateful.app_state import AppState from modalities.config.config import ProcessGroupBackendType, load_app_config_dict from modalities.config.pydantic_if_types import PydanticAppStateType, PydanticPipelineType -from modalities.models.parallelism.pipeline_parallelism import Pipeline from modalities.training.training_progress import TrainingProgress from tests.checkpointing.checkpointing_test_utils import CheckpointingTestUtils from tests.end2end_tests.custom_components import MultiProcessingCudaEnv @@ -48,26 +47,26 @@ def get_gpt2_model_config_dict(gpt2_model_config_path: Path) -> dict: ) class TestFSDP2DCPCheckpointing: @staticmethod - def _get_app_state(config_file_path: Path) -> AppState: - class ComponentsInstantiationModel(BaseModel): - app_state: PydanticAppStateType + def _get_app_state(config_file_path: Path, use_pp: bool = False) -> AppState: + if use_pp: - main_obj = Main(config_file_path) - components: ComponentsInstantiationModel = main_obj.build_components( - components_model_type=ComponentsInstantiationModel - ) - return components.app_state + class ComponentsInstantiationModel(BaseModel): + app_state: PydanticAppStateType + scheduled_pipeline: PydanticPipelineType - @staticmethod - def _get_scheduled_pipeline(config_file_path: Path) -> Pipeline: - class ComponentsInstantiationModel(BaseModel): - scheduled_pipeline: PydanticPipelineType + else: + + class ComponentsInstantiationModel(BaseModel): + app_state: PydanticAppStateType main_obj = Main(config_file_path) components: ComponentsInstantiationModel = main_obj.build_components( components_model_type=ComponentsInstantiationModel ) - return components.scheduled_pipeline + app_state = components.app_state + if use_pp: + app_state.scheduled_pipeline = components.scheduled_pipeline + return app_state @staticmethod @pytest.mark.parametrize( @@ -103,20 +102,12 @@ def _test_save_checkpoint_after_backward_pass_impl_wrapper( global_rank=process_id, local_rank=process_id, world_size=world_size, - rdvz_port=22358, + rdvz_port=22355, ): try: # build all the components for the test - app_state1 = TestFSDP2DCPCheckpointing._get_app_state(config_file_path=gpt2_model_config_path) - app_state2 = TestFSDP2DCPCheckpointing._get_app_state(config_file_path=gpt2_model_config_path) - - if use_pp: - app_state1.scheduled_pipeline = TestFSDP2DCPCheckpointing._get_scheduled_pipeline( - config_file_path=gpt2_model_config_path - ) - app_state2.scheduled_pipeline = TestFSDP2DCPCheckpointing._get_scheduled_pipeline( - config_file_path=gpt2_model_config_path - ) + app_state1 = TestFSDP2DCPCheckpointing._get_app_state(gpt2_model_config_path, use_pp) + app_state2 = TestFSDP2DCPCheckpointing._get_app_state(gpt2_model_config_path, use_pp) gpt2_model_config_dict = get_gpt2_model_config_dict(gpt2_model_config_path=gpt2_model_config_path) experiment_id = "0" @@ -245,12 +236,12 @@ def _test_save_checkpoint_after_backward_pass_impl( # perform another forward pass and backward pass for the previous and the loaded model if hasattr(app_state1, "scheduled_pipeline"): try: - # loss_1 = CheckpointingTestUtils.forward_backward_pp_pass( - # scheduled_pipeline=app_state1.scheduled_pipeline, - # optimizer=app_state1.optimizer, - # batch_input_ids_dict=batch_input_ids_dict, - # batch_target_ids=batch_target_ids, - # ) + loss_1 = CheckpointingTestUtils.forward_backward_pp_pass( + scheduled_pipeline=app_state1.scheduled_pipeline, + optimizer=app_state1.optimizer, + batch_input_ids_dict=batch_input_ids_dict, + batch_target_ids=batch_target_ids, + ) loss_2 = CheckpointingTestUtils.forward_backward_pp_pass( scheduled_pipeline=app_state2.scheduled_pipeline, optimizer=app_state2.optimizer, @@ -278,7 +269,8 @@ def _test_save_checkpoint_after_backward_pass_impl( batch_target_ids=batch_target_ids, ) assert loss_1 == loss_2, f"loss_1 = {loss_1} does not equal loss_2 = {loss_2}" - assert loss_1 < loss_0, f"loss_1 = {loss_1} is not less than loss_0 = {loss_0}" + if loss_1 is not None: + assert loss_1 < loss_0, f"loss_1 = {loss_1} is not less than loss_0 = {loss_0}" # check that the model and optimizer states after each backward pass are as expected # model weights From cd9f5951c64c06035149a0e75454a4b0eed611e0 Mon Sep 17 00:00:00 2001 From: Timm Ruland Date: Mon, 20 Oct 2025 11:38:47 +0200 Subject: [PATCH 54/60] test(checkpointing): Made dcp checkpointing test terminate correctly when failing in one of multiple subprocesses. --- ...fsdp2_dcp_checkpoint_loading_and_saving.py | 136 ++++++++++++++++-- 1 file changed, 127 insertions(+), 9 deletions(-) diff --git a/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py b/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py index 77fcd3edb..482d0cde6 100644 --- a/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py +++ b/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py @@ -1,6 +1,9 @@ import json +import logging +import multiprocessing as py_mp import os import tempfile +import time import traceback from copy import deepcopy from pathlib import Path @@ -81,13 +84,21 @@ def test_save_checkpoint_after_backward_pass( ): working_dir = Path(os.path.dirname(__file__)) config_file_path = working_dir / config_filename - mp.spawn( + # Use a Manager queue so child processes can report exceptions to the parent. + manager = py_mp.Manager() + error_queue = manager.Queue() + + # Start child processes without joining so the parent can monitor a shared queue + # and terminate remaining workers immediately if any child fails. + proc_ctx = mp.spawn( TestFSDP2DCPCheckpointing._test_save_checkpoint_after_backward_pass_impl_wrapper, - args=(world_size, temporary_checkpoint_folder_path, config_file_path, use_pp), + args=(world_size, temporary_checkpoint_folder_path, config_file_path, use_pp, error_queue), nprocs=world_size, - join=True, + join=False, ) + TestFSDP2DCPCheckpointing._monitor_child_processes(manager, error_queue, proc_ctx) + @staticmethod def _test_save_checkpoint_after_backward_pass_impl_wrapper( process_id: int, @@ -95,6 +106,7 @@ def _test_save_checkpoint_after_backward_pass_impl_wrapper( temporary_checkpoint_folder_path: Path, gpt2_model_config_path: Path, use_pp: bool, + error_queue: "py_mp.managers.SyncManager.Queue", ): # wraps the actual test function to be able to run it in a distributed multiprocessing setup with MultiProcessingCudaEnv( @@ -102,7 +114,7 @@ def _test_save_checkpoint_after_backward_pass_impl_wrapper( global_rank=process_id, local_rank=process_id, world_size=world_size, - rdvz_port=22355, + rdvz_port=22353, ): try: # build all the components for the test @@ -129,9 +141,14 @@ def _test_save_checkpoint_after_backward_pass_impl_wrapper( experiment_id=experiment_id, ) except Exception as e: - print(f"Exception in _forward_step_with_pp: {e}") - traceback.print_exc() # <-- Add this line to print the full stack trace - raise e + tb = traceback.format_exc() + logging.error(f"Process {process_id} encountered an error:\n{e}") + logging.error(tb) + try: + error_queue.put((process_id, tb)) + except Exception: + logging.error("Failed to put exception info into error queue.") + os._exit(1) @staticmethod def _test_save_checkpoint_after_backward_pass_impl( @@ -250,8 +267,8 @@ def _test_save_checkpoint_after_backward_pass_impl( ) except Exception as e: print(f"Exception in _forward_step_with_pp: {e}") - traceback.print_exc() # <-- Add this line to print the full stack trace - raise e + traceback.print_exc() + raise else: loss_1 = CheckpointingTestUtils.forward_backward_pass( prediction_key=prediction_key, @@ -307,3 +324,104 @@ def _test_save_checkpoint_after_backward_pass_impl( CheckpointingTestUtils.assert_equality_optimizer_state( app_state1.optimizer.state_dict(), updated_optimizer_state_dict, must_be_equal=False ) + + @staticmethod + def _monitor_child_processes(manager, error_queue, proc_ctx): + # Normalize the return value from mp.spawn. When join=False it often + # returns a ProcessContext-like object that may expose a `processes` + # attribute. Other implementations may return an iterable of Process + # objects. Build a `processes` list defensively so we can monitor and + # terminate child processes below without assuming a particular type. + processes = [] + if proc_ctx is None: + processes = [] + else: + # common attribute names that might hold the list of processes + candidate_attrs = ["processes", "_processes", "workers", "process_list", "processes_"] + found = False + for attr in candidate_attrs: + if hasattr(proc_ctx, attr): + ps = getattr(proc_ctx, attr) + try: + processes = list(ps) + except Exception: + processes = [ps] + found = True + break + if not found: + # If proc_ctx itself is iterable, exhaust it into a list + try: + processes = list(proc_ctx) + except Exception: + # Fallback: if proc_ctx behaves like a single process-like + # object (has terminate/is_alive/join), wrap it in a list. + if hasattr(proc_ctx, "terminate") or hasattr(proc_ctx, "is_alive") or hasattr(proc_ctx, "join"): + processes = [proc_ctx] + else: + processes = [] + + # Monitor the error queue and child processes. If any child reports an exception, + # terminate the other workers and raise the error in the parent to fail the test fast. + try: + # Loop until all processes finished or an error is reported + while True: + # If an error was reported by any child process, terminate remaining children + if not error_queue.empty(): + proc_id, tb = error_queue.get() + # terminate and join all processes (or the proc_ctx wrapper) + for p in processes: + try: + if hasattr(p, "is_alive"): + alive = p.is_alive() + elif hasattr(p, "exitcode"): + alive = getattr(p, "exitcode") is None + else: + alive = True + if alive and hasattr(p, "terminate"): + p.terminate() + except Exception: + pass + # If we didn't find individual process objects but proc_ctx + # exposes a terminate method, call it as a fallback. + try: + if not processes and hasattr(proc_ctx, "terminate"): + proc_ctx.terminate() + except Exception: + pass + + for p in processes: + try: + if hasattr(p, "join"): + p.join(timeout=5) + except Exception: + pass + try: + if hasattr(proc_ctx, "join"): + proc_ctx.join(timeout=1) + except Exception: + pass + raise AssertionError(f"Child process {proc_id} raised an exception:\n{tb}") + + # If all processes have finished, break + all_finished = all((not p.is_alive()) for p in processes) + if all_finished: + # join them to collect exitcodes + for p in processes: + try: + p.join() + except Exception: + pass + # If we have a ProcessContext, call its join to clean up as well + try: + if hasattr(proc_ctx, "join"): + proc_ctx.join(timeout=1) + except Exception: + pass + break + + time.sleep(0.05) + finally: + try: + manager.shutdown() + except Exception: + pass From edf7a4e5b818cd9916c114e31d9264d69884725a Mon Sep 17 00:00:00 2001 From: Timm Ruland Date: Tue, 21 Oct 2025 11:39:02 +0200 Subject: [PATCH 55/60] test(checkpointing): Checkpointing equality tests now explicitly only check the local tensor. --- .../checkpointing/checkpointing_test_utils.py | 38 +++++++++++++------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/tests/checkpointing/checkpointing_test_utils.py b/tests/checkpointing/checkpointing_test_utils.py index 7a3f241bf..c350ccbc8 100644 --- a/tests/checkpointing/checkpointing_test_utils.py +++ b/tests/checkpointing/checkpointing_test_utils.py @@ -1,5 +1,6 @@ import torch from pydantic import BaseModel +from torch.distributed.tensor import DTensor from torch.nn import CrossEntropyLoss from torch.optim import Optimizer @@ -128,19 +129,32 @@ def assert_equality_optimizer_state( state_2 = optimizer_2_state[param_group_id] assert set(state_1.keys()) == set(state_2.keys()) for state_key in state_1.keys(): - if must_be_equal: - assert torch.equal( - state_1[state_key], state_2[state_key] - ), "_assert_equality_optimizer_state failed (must_be_equal = True)" - else: - assert not torch.equal( - state_1[state_key], state_2[state_key] - ), "_assert_equality_optimizer_state failed (must_be_equal = False)" + CheckpointingTestUtils.assert_equality_two_tensors( + tensor_1=state_1[state_key], + tensor_2=state_2[state_key], + must_be_equal=must_be_equal, + msg_on_failure="_assert_equality_optimizer_state failed", + ) @staticmethod def assert_equality_two_models(params_1: list[torch.Tensor], params_2: list[torch.Tensor], must_be_equal: bool): for p1, p2 in zip(params_1, params_2): - if must_be_equal: - assert torch.equal(p1, p2), "_assert_equality_two_models failed (must_be_equal = True)" - else: - assert not torch.equal(p1, p2), "_assert_equality_two_models failed (must_be_equal = False)" + CheckpointingTestUtils.assert_equality_two_tensors( + tensor_1=p1, + tensor_2=p2, + must_be_equal=must_be_equal, + msg_on_failure="_assert_equality_two_models failed", + ) + + @staticmethod + def assert_equality_two_tensors( + tensor_1: torch.Tensor, tensor_2: torch.Tensor, must_be_equal: bool, msg_on_failure: str = "" + ): + if isinstance(tensor_1, DTensor): + assert isinstance(tensor_2, DTensor), f"{msg_on_failure} (type mismatch with DTensor)" + tensor_1 = tensor_1.to_local() + tensor_2 = tensor_2.to_local() + if must_be_equal: + assert torch.equal(tensor_1, tensor_2), f"{msg_on_failure} (must_be_equal = True)" + else: + assert not torch.equal(tensor_1, tensor_2), f"{msg_on_failure} (must_be_equal = False)" From abcf235d0824ac620e135c27628fdc87f38b417d Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 21 Oct 2025 12:06:51 +0200 Subject: [PATCH 56/60] fix: Use ModuleDict for transformer layers for correct checkpointing with pp --- src/modalities/models/gpt2/gpt2_model.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index a1ae9fd2e..f50a146e0 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -845,9 +845,9 @@ def __init__( wte=nn.Embedding(num_embeddings=vocab_size, embedding_dim=n_embd), wpe=wpe, drop=nn.Dropout(dropout), - h=nn.ModuleList( - [ - GPT2Block( + h=nn.ModuleDict( + { + str(layer_id): GPT2Block( n_embd=n_embd, bias=bias, n_head_q=n_head_q, @@ -863,8 +863,8 @@ def __init__( attention_norm=attention_norm_config.norm_type.value(**dict(attention_norm_config.config)), ffn_norm=ffn_norm_config.norm_type.value(**dict(ffn_norm_config.config)), ) - for _ in range(n_layer) - ] + for layer_id in range(n_layer) + } ), lm_head_norm=lm_head_norm_config.norm_type.value(**dict(lm_head_norm_config.config)), # NOTE: If we make the bias configurable, we must update the number of parameters calculation @@ -952,8 +952,8 @@ def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor: # TODO: use drop out also without absolute position embedding? h = self.transformer.drop(h) if hasattr(self.transformer, "drop") else h - for block in self.transformer.h: - h = block(h) + for layer_id in self.transformer.h: + h = self.transformer.h[layer_id](h) h = self.transformer.lm_head_norm(h) if hasattr(self.transformer, "lm_head_norm") else h h = self.transformer.lm_head(h) if hasattr(self.transformer, "lm_head") else h return h From 554cd3943ad6c5832b842e3b738f5958c29f3fd8 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 21 Oct 2025 16:09:13 +0200 Subject: [PATCH 57/60] chore: Rename layer_id to layer_idx --- src/modalities/models/gpt2/gpt2_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index f50a146e0..3e27ec5d5 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -847,7 +847,7 @@ def __init__( drop=nn.Dropout(dropout), h=nn.ModuleDict( { - str(layer_id): GPT2Block( + str(layer_idx): GPT2Block( n_embd=n_embd, bias=bias, n_head_q=n_head_q, @@ -863,7 +863,7 @@ def __init__( attention_norm=attention_norm_config.norm_type.value(**dict(attention_norm_config.config)), ffn_norm=ffn_norm_config.norm_type.value(**dict(ffn_norm_config.config)), ) - for layer_id in range(n_layer) + for layer_idx in range(n_layer) } ), lm_head_norm=lm_head_norm_config.norm_type.value(**dict(lm_head_norm_config.config)), @@ -952,8 +952,8 @@ def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor: # TODO: use drop out also without absolute position embedding? h = self.transformer.drop(h) if hasattr(self.transformer, "drop") else h - for layer_id in self.transformer.h: - h = self.transformer.h[layer_id](h) + for layer_idx in self.transformer.h: + h = self.transformer.h[layer_idx](h) h = self.transformer.lm_head_norm(h) if hasattr(self.transformer, "lm_head_norm") else h h = self.transformer.lm_head(h) if hasattr(self.transformer, "lm_head") else h return h From 484815e5ae8765f53015ab836b8ce54f1cbbf0c2 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 21 Oct 2025 16:52:49 +0200 Subject: [PATCH 58/60] test: Adapt tests to new gpt2 model structure --- src/modalities/conversion/gpt2/conversion_model.py | 8 ++++---- .../activation_checkpointing.py | 4 ++-- .../test_fsdp2_dcp_checkpoint_loading_and_saving.py | 6 +++--- tests/conversion/gpt2/helper.py | 10 +++++----- tests/test_torch_compile.py | 11 +++++++---- 5 files changed, 21 insertions(+), 18 deletions(-) diff --git a/src/modalities/conversion/gpt2/conversion_model.py b/src/modalities/conversion/gpt2/conversion_model.py index 7b06e3ec0..89fbf194a 100644 --- a/src/modalities/conversion/gpt2/conversion_model.py +++ b/src/modalities/conversion/gpt2/conversion_model.py @@ -136,10 +136,10 @@ def _copy_weights_model(hf_model: GPT2ForCausalLM, modalities_model: GPT2LLM): modalities_model (GPT2LLM): The modalities model from which the weights will be copied. """ hf_model.model.embed_tokens.weight.data.copy_(modalities_model.transformer.wte.weight.data) - for hf_layer, modalities_layer in zip(hf_model.model.layers, modalities_model.transformer.h): - _copy_weights_attention(hf_layer, modalities_layer) - _copy_weights_mlp(hf_layer, modalities_layer) - _copy_weights_layer_norms(hf_layer, modalities_layer) + for hf_layer, modalities_layer_idx in zip(hf_model.model.layers, modalities_model.transformer.h): + _copy_weights_attention(hf_layer, modalities_model.transformer.h[modalities_layer_idx]) + _copy_weights_mlp(hf_layer, modalities_model.transformer.h[modalities_layer_idx]) + _copy_weights_layer_norms(hf_layer, modalities_model.transformer.h[modalities_layer_idx]) _copy_weights_base_modules(hf_model.lm_head, modalities_model.transformer.lm_head) _copy_weights_base_modules(hf_model.model.norm, modalities_model.transformer.lm_head_norm) diff --git a/src/modalities/training/activation_checkpointing/activation_checkpointing.py b/src/modalities/training/activation_checkpointing/activation_checkpointing.py index 3cecf192d..0c194c350 100644 --- a/src/modalities/training/activation_checkpointing/activation_checkpointing.py +++ b/src/modalities/training/activation_checkpointing/activation_checkpointing.py @@ -135,8 +135,8 @@ def apply_activation_checkpointing_( raise ValueError(f"Unknown activation checkpointing variant: {ac_variant}") layers = model.get_submodule(layers_fqn) - if not isinstance(layers, nn.ModuleList): - raise ValueError(f"layers_fqn {layers_fqn} does not reference a ModuleList") + if not isinstance(layers, nn.ModuleDict): + raise ValueError(f"layers_fqn {layers_fqn} does not reference a ModuleDict") print_rank_0(f"Applying activation checkpointing to {len(list(layers.named_children()))} layers...") diff --git a/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py b/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py index 482d0cde6..bcbdbd32b 100644 --- a/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py +++ b/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py @@ -114,7 +114,7 @@ def _test_save_checkpoint_after_backward_pass_impl_wrapper( global_rank=process_id, local_rank=process_id, world_size=world_size, - rdvz_port=22353, + rdvz_port=22354, ): try: # build all the components for the test @@ -248,8 +248,8 @@ def _test_save_checkpoint_after_backward_pass_impl( ) loaded_and_updated_model_parameters = CheckpointingTestUtils.clone_parameters(app_state1.model) - loaded_and_updated_optimizer_state_dict = deepcopy(app_state1.optimizer.state_dict()) - + loaded_and_updated_optimizer_state_dict = deepcopy(app_state1.optimizer.state_dict()) + # perform another forward pass and backward pass for the previous and the loaded model if hasattr(app_state1, "scheduled_pipeline"): try: diff --git a/tests/conversion/gpt2/helper.py b/tests/conversion/gpt2/helper.py index 328633ccb..99adbacbc 100644 --- a/tests/conversion/gpt2/helper.py +++ b/tests/conversion/gpt2/helper.py @@ -6,14 +6,14 @@ def check_same_weight_model(converted_model: GPT2ForCausalLM, modalities_model: GPT2LLM): - converted_model.to(device=modalities_model.transformer.h[0].attn.q_attn.weight.device) + converted_model.to(device=modalities_model.transformer.h["0"].attn.q_attn.weight.device) assert torch.equal(converted_model.model.embed_tokens.weight, modalities_model.transformer.wte.weight) - for i, (llama_layer, modalities_layer) in enumerate( + for i, (llama_layer, modalities_layer_idx) in enumerate( zip(converted_model.model.layers, modalities_model.transformer.h) ): - check_same_weight_attention(llama_layer, modalities_layer) - check_same_weight_mlp(llama_layer, modalities_layer) - check_same_weight_layer_norms(llama_layer, modalities_layer) + check_same_weight_attention(llama_layer, modalities_model.transformer.h[modalities_layer_idx]) + check_same_weight_mlp(llama_layer, modalities_model.transformer.h[modalities_layer_idx]) + check_same_weight_layer_norms(llama_layer, modalities_model.transformer.h[modalities_layer_idx]) check_same_weight_base_modules(converted_model.lm_head, modalities_model.transformer.lm_head) check_same_weight_base_modules(converted_model.model.norm, modalities_model.transformer.lm_head_norm) diff --git a/tests/test_torch_compile.py b/tests/test_torch_compile.py index fab2ed217..59ae6ecb9 100644 --- a/tests/test_torch_compile.py +++ b/tests/test_torch_compile.py @@ -1,3 +1,6 @@ + +import copy + import pytest import torch.nn as nn @@ -57,7 +60,7 @@ def gpt2_model(): def test_get_compiled_model_compiles_blocks(gpt2_model): - original_blocks = list(gpt2_model.transformer.h) + original_model = copy.deepcopy(gpt2_model) original_wte = gpt2_model.transformer.wte original_lm_head = gpt2_model.transformer.lm_head @@ -65,9 +68,9 @@ def test_get_compiled_model_compiles_blocks(gpt2_model): result_model = ModelFactory.get_compiled_model(gpt2_model, block_names, fullgraph=True) assert len(result_model.transformer.h) == 4, "Should still have four blocks" - for i, (original_block, new_block) in enumerate(zip(original_blocks, result_model.transformer.h)): - assert new_block is not original_block, f"Block {i} should be a compiled version" - assert isinstance(new_block, nn.Module), f"Block {i} should be an nn.Module" + for i, (original_block_idx, new_block_idx) in enumerate(zip(original_model.transformer.h, result_model.transformer.h)): + assert result_model.transformer.h[new_block_idx] is not original_model.transformer.h[original_block_idx], f"Block {i} should be a compiled version" + assert isinstance(result_model.transformer.h[new_block_idx], nn.Module), f"Block {i} should be an nn.Module" assert result_model.transformer.wte is original_wte, "Embedding layer should remain unchanged" assert result_model.transformer.lm_head is original_lm_head, "LM head should remain unchanged" assert result_model is gpt2_model, "Should return the same model instance" From ddb249b194e60ae8c277935853a81a3388902038 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 21 Oct 2025 17:31:22 +0200 Subject: [PATCH 59/60] test: Adapt code to latest changes to pass tests --- src/modalities/models/model_factory.py | 6 +++--- tests/fsdp2_parallelization/test_tensor_parallelism.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py index d889c213a..7df9ba258 100644 --- a/src/modalities/models/model_factory.py +++ b/src/modalities/models/model_factory.py @@ -683,13 +683,13 @@ def get_gpt2_tensor_parallelized_model(model: GPT2LLM, device_mesh: DeviceMesh) desired_input_layouts=(Replicate(),), ), } - if isinstance(model.transformer.h[0].mlp, SwiGLU): + if isinstance(list(model.transformer.h.values())[0].mlp, SwiGLU): mlp_plan = { "mlp.W": ColwiseParallel(), "mlp.W_2": RowwiseParallel(output_layouts=Shard(1)), "mlp.V": ColwiseParallel(), } - elif isinstance(model.transformer.h[0].mlp, TransformerMLP): + elif isinstance(list(model.transformer.h.values())[0].mlp, TransformerMLP): mlp_plan = { "mlp.c_fc": ColwiseParallel(), "mlp.c_proj": RowwiseParallel(output_layouts=Shard(1)), @@ -701,7 +701,7 @@ def get_gpt2_tensor_parallelized_model(model: GPT2LLM, device_mesh: DeviceMesh) ) transformer_block_tp_plan.update(mlp_plan) - for transformer_block in model.transformer.h: + for transformer_block in model.transformer.h.values(): # override the number of q and kv heads if transformer_block.attn.n_head_q % tp_mesh.size() != 0: raise ValueError( diff --git a/tests/fsdp2_parallelization/test_tensor_parallelism.py b/tests/fsdp2_parallelization/test_tensor_parallelism.py index 449fdb996..ac6554124 100644 --- a/tests/fsdp2_parallelization/test_tensor_parallelism.py +++ b/tests/fsdp2_parallelization/test_tensor_parallelism.py @@ -117,11 +117,11 @@ def _test_tp_sharding_impl( # Ensure models use the correct MLP if activation_type == "gelu": - assert isinstance(fsdp2_model.transformer.h[0].mlp, TransformerMLP) - assert isinstance(tp_model.transformer.h[0].mlp, TransformerMLP) + assert isinstance(fsdp2_model.transformer.h["0"].mlp, TransformerMLP) + assert isinstance(tp_model.transformer.h["0"].mlp, TransformerMLP) elif activation_type == "swiglu": - assert isinstance(fsdp2_model.transformer.h[0].mlp, SwiGLU) - assert isinstance(tp_model.transformer.h[0].mlp, SwiGLU) + assert isinstance(fsdp2_model.transformer.h["0"].mlp, SwiGLU) + assert isinstance(tp_model.transformer.h["0"].mlp, SwiGLU) # Ensure models are sharded correctly assert "tp" in tp_model.transformer.wte.weight.device_mesh.mesh_dim_names From 51b7db4205401735d1f859b9e18312f8a919b928 Mon Sep 17 00:00:00 2001 From: Timm Ruland Date: Tue, 21 Oct 2025 17:34:01 +0200 Subject: [PATCH 60/60] test(data): Added tests for distributed multi dim data sampling. --- tests/dataloader/distributed/mocks.py | 42 ++++++++++ .../test_distributed_multidim_dataloader.py | 84 +++++++++++++++++++ 2 files changed, 126 insertions(+) create mode 100644 tests/dataloader/distributed/mocks.py create mode 100644 tests/dataloader/distributed/test_distributed_multidim_dataloader.py diff --git a/tests/dataloader/distributed/mocks.py b/tests/dataloader/distributed/mocks.py new file mode 100644 index 000000000..cc3f044e2 --- /dev/null +++ b/tests/dataloader/distributed/mocks.py @@ -0,0 +1,42 @@ +import os + + +class MultiProcessingCudaEnvMock: + """Context manager to set the CUDA environment for distributed training.""" + + def __init__( + self, + global_rank: int, + local_rank: int, + world_size: int, + rdvz_port: int, + ) -> None: + self.global_rank = global_rank + self.local_rank = local_rank + self.world_size = world_size + self.rdvz_port = rdvz_port + self._original_env: dict[str, str | None] = {} + + def __enter__(self): + # Store original values + for key in ["MASTER_ADDR", "MASTER_PORT", "RANK", "LOCAL_RANK", "WORLD_SIZE"]: + self._original_env[key] = os.environ.get(key) + + # Set new environment variables + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(self.rdvz_port) + os.environ["RANK"] = str(self.global_rank) + os.environ["LOCAL_RANK"] = str(self.local_rank) + os.environ["WORLD_SIZE"] = str(self.world_size) + + # torch.cuda.set_device(local_rank) + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # Restore original environment variables + for key, value in self._original_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value diff --git a/tests/dataloader/distributed/test_distributed_multidim_dataloader.py b/tests/dataloader/distributed/test_distributed_multidim_dataloader.py new file mode 100644 index 000000000..e3546b00c --- /dev/null +++ b/tests/dataloader/distributed/test_distributed_multidim_dataloader.py @@ -0,0 +1,84 @@ +import os +from unittest.mock import MagicMock + +import pytest +from torch.utils.data import BatchSampler + +from modalities.dataloader.dataloader_factory import DataloaderFactory +from modalities.dataloader.sampler_factory import SamplerFactory +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees +from tests.dataloader.distributed.mocks import MultiProcessingCudaEnvMock +from tests.dataloader.dummy_sequential_dataset import TestDataset + + +@pytest.mark.parametrize("world_size, dp_degree", [(4, 2)]) +def test_distributed_multidim_dataloader_produces_same_data_on_connected_non_dp_ranks(world_size: int, dp_degree: int): + batches_on_rank = _build_batch_for_each_rank_combination(world_size, dp_degree) + + for dp_rank in range(dp_degree): + assert all( + batches_on_rank[(dp_rank, 0)] == batches_on_rank[(dp_rank, other_rank)] + for other_rank in range(1, world_size // dp_degree) + ), f"Batches on dp_rank {dp_rank} differ across other ranks." + + +@pytest.mark.parametrize("world_size, dp_degree", [(4, 2)]) +def test_distributed_multidim_dataloader_produces_different_data_on_different_dp_ranks(world_size: int, dp_degree: int): + batches_on_rank = _build_batch_for_each_rank_combination(world_size, dp_degree) + + for dp_rank1 in range(dp_degree): + for dp_rank2 in range(dp_rank1 + 1, dp_degree): + samples_dp_rank1 = sum(batches_on_rank[(dp_rank1, 0)], []) + samples_dp_rank2 = sum(batches_on_rank[(dp_rank2, 0)], []) + assert ( + len(set(samples_dp_rank1).intersection(samples_dp_rank2)) == 0 + ), f"Data samples on different data parallel ranks {dp_rank1} and {dp_rank2} should be disjoint." + + +def _build_batch_for_each_rank_combination(world_size: int, dp_degree: int): + return { + (dp_rank, other_rank): _load_data_for_ranks(dp_rank, other_rank, world_size, dp_degree) + for dp_rank, other_rank in _get_rank_combinations(world_size, dp_degree) + } + + +def _get_rank_combinations(world_size: int, dp_degree: int): + other_degree = world_size // dp_degree + return [(dp_rank, other_rank) for dp_rank in range(dp_degree) for other_rank in range(other_degree)] + + +def _load_data_for_ranks(dp_rank: int, other_rank: int, world_size: int, dp_degree: int): + global_rank = dp_rank * 2 + other_rank + with MultiProcessingCudaEnvMock( + global_rank=global_rank, + local_rank=other_rank, + world_size=world_size, + rdvz_port=22350, + ): + device_mesh = _build_device_mesh_mock(world_size, dp_degree, dp_rank, other_rank) + dataset = TestDataset(8) + sampler = SamplerFactory.create_resumable_distributed_multi_dim_sampler( + dataset=dataset, device_mesh=device_mesh, data_parallel_key=ParallelismDegrees.DP_SHARD + ) + batch_sampler = BatchSampler(sampler, batch_size=2, drop_last=True) + train_dataloader = DataloaderFactory.get_dataloader( + dataloader_tag="train", + dataset=dataset, + batch_sampler=batch_sampler, + collate_fn=None, + num_workers=2, + pin_memory=False, + ) + return [batch.tolist() for batch in train_dataloader] + + +def _build_device_mesh_mock(world_size: int, dp_degree: int, dp_rank: int, other_rank: int): + dp_device_mesh = MagicMock() + dp_device_mesh.size.return_value = dp_degree + dp_device_mesh.get_coordinate.return_value = [dp_rank] + other_device_mesh = MagicMock() + other_degree = world_size // dp_degree + other_device_mesh.size.return_value = int(os.environ["WORLD_SIZE"]) // other_degree + other_device_mesh.get_coordinate.return_value = [other_rank] + device_mesh_mock = {ParallelismDegrees.DP_SHARD.value: dp_device_mesh, "other": other_device_mesh} + return device_mesh_mock