
Conversation


@BlueCrescent BlueCrescent commented Sep 17, 2025

What does this PR do?

  • Issue: Dicts, as used in Modalities for model/loss inputs and outputs, are not supported by the torch PP code.
  • Issue: Loss accumulation (for logging) needs to consider the PP world size instead of the total world size
    • Solution:
      • For the averaged loss we only increment the batch count on the last stages of a PP schedule.
      • For the current batch loss we added the pp world size, computed from the device mesh, to the Trainer class.
  • Issue: Due to a bug in PyTorch, we cannot run evaluation before training begins.
  • Issue: Standard distributed dataloader/sampler usage would lose batches
    • Solution: Use (Resumable)DistributedMultiDimSamplerConfig
    • TODO: Implement DistributedMultiDimSamplerConfig for eval dataloader
    • TODO??: Remove data_parallel_key param from SamplerFactory.create_resumable_distributed_multi_dim_sampler()
      • Instead use BOTH ParallelismDegrees.DP_REPLICATE and ParallelismDegrees.DP_SHARD (multiplied)
  • Issue: Gradient clipping needs to sync over all stages
  • Issue: Instead of calling forward() and backward() on the model and executing the loss_fct(), PP requires calling step() on the pipeline schedule (see the sketch after this list).
    • Solution: Integrated scheduled pipeline into Trainer and Evaluator
  • Issue: Want to run evaluation with torch.no_grad() and without the backward pass performed by the pipeline schedule's step().
    • Solution: This is only supported from PyTorch 2.9 (currently nightly) onwards, where the pipeline schedule gains an eval() method to be used instead of step().
    • ⚠️Warning⚠️: Consequently, we can only use PP for training with PyTorch 2.9 nightly installed.
  • Note: Weight tying is currently not supported and probably not possible with PP.
  • Question: Do we need to consider something regarding model init so that all corresponding model copies are initialized the same?
    • I.e.: How can we ensure that parallel stages are initialized the same?
    • Other way around: Will stages containing similar model parts be initialized the same?
    • Answer: Since we first build the PP stages, then apply the FSDP2 parallelization, and only then initialize, this is fine.
    • Follow up question: What happens, if we use full replica data parallelization?
  • Issue: Want to compare non-PP forward pass with PP forward pass in unit test
    • Goal: Compare losses
    • Solution: Added a non-PP config used to compute a comparison loss
    • Issue: Need non-PP forward pass loss on all ranks that contain a last stage of the PP forward pass.
      • Solution: Run a normal FSDP2 forward pass on all ranks and ignore the output on non-last stage ranks.
      • Remark: Running the FSDP2 pass only on the last stage ranks led to hanging. This might be a known issue with FSDP2.
    • Issue: Need the model to effectively compute the same forward pass.
      • Solution:
        • In the test config, run model initialization before sharding/staging (don't do this in a real training run!).
        • Fix torch seed before each initialization.
        • Fix torch seed before generating each input sequence. Note that we thus have identical data parallel batches.
    • TODO: Numerical instability can still be observed and needs to be investigated further (also in FSDP-only runs).
  • TP + PP
    • Issue: What is the correct sharding/staging order?
      • Solution: 1. PP 2. TP 3. FSDP
    • Issue: TP initialization modifies the layers of the model and runs into problems if those have been deleted for PP.
      • Solution: Adapted TP code to be able to handle missing/None layers.
    • Issue: The PP + TP test hangs on the first stages' first microbatches at the start of attention.
      • Solution: The test used sequence length 255. It seems the sequence length not being divisible by two caused issues with TP's sequence parallelism.
      • ⚠️Warning⚠️: This needs to be kept in mind when using PP + TP.
  • TODO: Wherever "dist.get_world_size()" is used, check whether we should use the number of data parallel ranks instead.
  • TODO: MFU is reported as 0 on dgx2, but this also happens without PP
  • TODO: Integrate and test other pipeline schedules.
  • TODO: Check if checkpointing still works.
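
A minimal sketch of what the schedule-based train step looks like, loosely following torch.distributed.pipelining (illustrative only; the flags is_first_stage/is_last_stage and the way inputs and targets are passed are assumptions, not the exact Trainer code):

import torch
from torch.distributed.pipelining import ScheduleGPipe

def pp_train_step(
    schedule: ScheduleGPipe,
    inputs: torch.Tensor,
    targets: torch.Tensor,
    is_first_stage: bool,
    is_last_stage: bool,
) -> torch.Tensor | None:
    losses: list[torch.Tensor] = []
    if is_first_stage:
        schedule.step(inputs)  # feed the microbatches into the first stage
    elif is_last_stage:
        schedule.step(target=targets, losses=losses)  # the loss is computed only here
    else:
        schedule.step()  # intermediate stages only relay activations
    # Only last-stage ranks hold a loss, which is why the batch count for the
    # averaged loss is incremented on last stages and the current batch loss is
    # scaled using the PP world size taken from the device mesh.
    return torch.stack(losses).mean() if is_last_stage else None

From PyTorch 2.9 (nightly) onwards, evaluation uses the schedule's eval() method in the same spot instead of step(), as noted above.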

Checklist before submitting final PR

  • My PR is minimal and addresses one issue in isolation
  • I have merged the latest version of the target branch into this feature branch
  • I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
  • I have run a sample config for model training
  • I have checked that all tests run through (python tests/tests.py)
  • I have updated the internal changelog (CHANGELOG_DEV.md)


@Copilot Copilot AI left a comment


Pull Request Overview

This PR implements pipeline parallelism support in the Modalities framework by addressing PyTorch's lack of native support for dict-based model inputs/outputs and implementing proper loss accumulation, gradient clipping, and data loading for pipeline parallel training.

Key changes include:

  • Overloaded model forward() methods to support both dict and tensor inputs for pipeline parallelism compatibility (see the sketch after this list)
  • Updated gradient clippers to sync across pipeline stages and support device mesh configurations
  • Modified trainer and evaluator to integrate with pipeline schedules instead of direct model calls
  • Added proper loss accumulation using data parallel world size instead of total world size
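
As an illustration of the dict/tensor duality (a sketch under assumed names, not the PR's actual GPT-2 code; the "input_ids" key and the TinyLM class are made up for the example), a forward() serving both the regular dict-based interface and the tensor-only convention of pipeline stages could look like this:

import torch
import torch.nn as nn

class TinyLM(nn.Module):
    """Stand-in for the real GPT-2 model in Modalities (sketch only)."""

    def __init__(self, vocab_size: int = 128, hidden_dim: int = 32):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, torch.Tensor] | torch.Tensor:
        if isinstance(inputs, torch.Tensor):
            # PP path: pipeline stages exchange plain tensors
            return self.lm_head(self.embedding(inputs))
        # regular path: dict in, dict out
        return {"logits": self.lm_head(self.embedding(inputs["input_ids"]))}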

Reviewed Changes

Copilot reviewed 17 out of 17 changed files in this pull request and generated 6 comments.

File | Description
tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py | Adds comprehensive test comparing PP vs non-PP forward passes with loss validation
tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml | New FSDP2 config for PP testing without pipeline stages
src/modalities/models/gpt2/gpt2_model.py | Overloads forward() method to accept both dict and tensor inputs for PP compatibility
src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py | Implements cross-stage gradient norm synchronization for pipeline parallelism
src/modalities/trainer.py | Integrates pipeline schedule execution and fixes loss accumulation for PP
src/modalities/loss_functions.py | Overloads loss function to handle both InferenceResultBatch and tensor inputs
src/modalities/evaluator.py | Adds pipeline schedule support to evaluation process
src/modalities/models/model_factory.py | Fixes tensor parallelism to handle missing model layers in PP stages


def __call__(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
...

def __call__(self, *args, **kwargs) -> torch.Tensor:

Copilot AI Sep 17, 2025


Using *args and **kwargs instead of proper overloads makes the API less type-safe and harder to understand. Consider implementing proper method overloading with specific parameter types.
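
One way to keep the flexible runtime dispatch while giving type checkers precise signatures is typing.overload; a rough sketch, not the code in this PR (the import path of InferenceResultBatch is assumed):

from typing import overload

import torch

from modalities.batch import InferenceResultBatch  # import path assumed

class CLMCrossEntropyLoss:
    @overload
    def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: ...

    @overload
    def __call__(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: ...

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        # single runtime implementation (e.g. the existing _parse_arguments dispatch);
        # the @overload stubs above only serve type checkers
        ...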


@rrutmann rrutmann self-assigned this Sep 22, 2025
@rrutmann rrutmann requested a review from le1nux September 22, 2025 16:02
@le1nux le1nux added this to the 100B milestone Oct 8, 2025
@le1nux le1nux linked an issue Oct 8, 2025 that may be closed by this pull request

@le1nux le1nux left a comment


Awesome work with the PP integration! Functionality-wise everything looked correct (did not check the tests yet).

Regarding the integration from an architectural perspective, I left a couple of comments. I think we should do some refactoring here.

gradient_clipper=components.gradient_clipper,
global_num_tokens_per_train_step=global_num_tokens_per_train_step,
mfu_calculator=components.mfu_calculator,
num_pipeline_parallel_ranks=num_pipeline_parallel_ranks,


I would prefer if we kept the Trainer high-level and abstracted away specifics like PP.

checkpointing_interval_in_steps=components.settings.intervals.checkpointing_interval_in_steps,
evaluation_interval_in_steps=components.settings.intervals.evaluation_interval_in_steps,
training_log_interval_in_steps=components.settings.intervals.training_log_interval_in_steps,
scheduled_pipeline=components.scheduled_pipeline if components.scheduled_pipeline else None,


Same point as for the trainer. Could we wrap the scheduled pipeline instead and use the existing model interfaces?

Comment on lines +189 to +198
pp_mesh = get_mesh_for_parallelism_method(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP)
if pp_mesh is not None:
if math.isinf(norm_type):
dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group())
else:
total_norm **= norm_type
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group())
total_norm **= 1.0 / norm_type

torch.nn.utils.clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)


we should have a test for this
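
One cheap, single-process sanity check that could back this (hypothetical test, not part of the PR): combining per-stage gradient norms via the sum of p-th powers followed by the 1/p root must reproduce the norm over all gradients, which is exactly what the all_reduce(SUM) path relies on.

import torch

def test_cross_stage_norm_combination():
    norm_type = 2.0
    # gradients as they would live on three different PP stages
    stage_grads = [torch.randn(10), torch.randn(7), torch.randn(3)]
    per_stage_norms = torch.stack([g.norm(norm_type) for g in stage_grads])
    combined = (per_stage_norms**norm_type).sum() ** (1.0 / norm_type)
    expected = torch.cat(stage_grads).norm(norm_type)
    torch.testing.assert_close(combined, expected)

A multi-rank test exercising the actual dist.all_reduce over the PP group would of course be the stronger check.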

Comment on lines +240 to +249

pp_mesh = get_mesh_for_parallelism_method(
device_mesh=self.device_mesh, parallelism_method=ParallelismDegrees.PP
)
if pp_mesh is not None:
if math.isinf(self.norm_type.value):
dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group())
else:
total_norm **= self.norm_type.value
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group())


duplicated code

pp_schedule_name: gpipe
batch_size: ${settings.step_profile.local_train_micro_batch_size}
microbatch_size: 1
microbatch_size: 2


should we reference this from the top?

@@ -1,12 +1,10 @@
[project]
name = "modalities"
version = "0.3.2"


Why did we remove this? Our testing is always against 3.10 and 3.11. Do we need a more recent Python version?

...

def __call__(self, *args, **kwargs) -> torch.Tensor:
labels, lm_logits = self._parse_arguments(args, kwargs)


Could be improved from a software engineering point of view

Comment on lines +54 to +88
def _parse_arguments(
self,
args: list[torch.Tensor] | list[InferenceResultBatch],
kwargs: dict[str, torch.Tensor] | dict[str, InferenceResultBatch],
) -> tuple[torch.Tensor, torch.Tensor]:
if len(args) == 1 and isinstance(args[0], InferenceResultBatch):
forward_batch = args[0]
labels = forward_batch.get_targets(self.target_key)
lm_logits = forward_batch.get_predictions(self.prediction_key)
elif "forward_batch" in kwargs and isinstance(kwargs["forward_batch"], InferenceResultBatch):
forward_batch = kwargs["forward_batch"]
labels = forward_batch.get_targets(self.target_key)
lm_logits = forward_batch.get_predictions(self.prediction_key)
elif len(args) == 2 and all(isinstance(arg, torch.Tensor) for arg in args):
lm_logits, labels = args
elif (
"outputs" in kwargs
and "targets" in kwargs
and isinstance(kwargs["outputs"], torch.Tensor)
and isinstance(kwargs["targets"], torch.Tensor)
):
lm_logits = kwargs["outputs"]
labels = kwargs["targets"]
elif (
len(args) == 1
and "targets" in kwargs
and isinstance(args[0], torch.Tensor)
and isinstance(kwargs["targets"], torch.Tensor)
):
lm_logits = args[0]
labels = kwargs["targets"]
else:
raise TypeError("Invalid arguments for CLMCrossEntropyLoss.__call__")
return labels, lm_logits



Idea: What about defining a new component "pp-loss", which takes a normal loss function and handles the PP-specific part?

Generally, I think this parsing function could be improved.
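
A rough sketch of that idea (hypothetical code, not part of this PR; the InferenceResultBatch construction is schematic and its real constructor may differ): a thin adapter receives the bare tensors from the pipeline schedule and hands a regular InferenceResultBatch to the wrapped loss, so the wrapped loss keeps its single-signature interface.

import torch

from modalities.batch import InferenceResultBatch  # import path assumed

class PPLossAdapter:
    """Hypothetical "pp-loss" component that wraps an ordinary Modalities loss."""

    def __init__(self, loss_fn, prediction_key: str, target_key: str):
        self._loss_fn = loss_fn
        self._prediction_key = prediction_key
        self._target_key = target_key

    def __call__(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # The pipeline schedule calls the loss with plain tensors; rebuild the
        # batch structure the wrapped loss expects (constructor args assumed).
        batch = InferenceResultBatch(
            predictions={self._prediction_key: outputs},
            targets={self._target_key: targets},
        )
        return self._loss_fn(batch)

With such a wrapper, the tensor-handling branches in _parse_arguments could likely be dropped.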

device_mesh: PydanticDeviceMeshIFType | None = None


class DummyGradientClipperConfig(BaseModel):


can we remove this class now?

with torch.no_grad():
result_batch = model_predict_batch(model=model, batch=batch)
loss = loss_fun(result_batch)
if scheduled_pipeline is not None:


Basically code duplication from the trainer.
Also, I'm not a big fan of passing the scheduled_pipeline in here.



Development

Successfully merging this pull request may close these issues.

Epic: Pipeline Parallelism
