Pipeline parallelism continued #399

base: pipeline_parallelism
Conversation
… classes for pipeline parallelism.
Pull Request Overview
This PR implements pipeline parallelism support in the Modalities framework by addressing PyTorch's lack of native support for dict-based model inputs/outputs and implementing proper loss accumulation, gradient clipping, and data loading for pipeline parallel training.
Key changes include:
- Overloaded model forward() methods to support both dict and tensor inputs for pipeline parallelism compatibility
- Updated gradient clippers to sync across pipeline stages and support device mesh configurations
- Modified trainer and evaluator to integrate with pipeline schedules instead of direct model calls
- Added proper loss accumulation using data parallel world size instead of total world size
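As a rough illustration of the last point (names such as `reduce_loss_over_dp_ranks`, `dp_group`, and `dp_world_size` are illustrative, not the actual Modalities API): the locally accumulated loss is reduced over the data-parallel group only and averaged by the DP world size, since pipeline-parallel ranks other than the last stage do not hold a loss.

```python
import torch
import torch.distributed as dist


def reduce_loss_over_dp_ranks(
    local_loss_sum: torch.Tensor, dp_group: dist.ProcessGroup, dp_world_size: int
) -> torch.Tensor:
    # Sum over the data-parallel group and divide by the DP world size
    # (not the total world size, which would also count PP ranks).
    reduced = local_loss_sum.clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM, group=dp_group)
    return reduced / dp_world_size
```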
Reviewed Changes
Copilot reviewed 17 out of 17 changed files in this pull request and generated 6 comments.
File | Description |
---|---|
tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py | Adds comprehensive test comparing PP vs non-PP forward passes with loss validation |
tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml | New FSDP2 config for PP testing without pipeline stages |
src/modalities/models/gpt2/gpt2_model.py | Overloads forward() method to accept both dict and tensor inputs for PP compatibility |
src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py | Implements cross-stage gradient norm synchronization for pipeline parallelism |
src/modalities/trainer.py | Integrates pipeline schedule execution and fixes loss accumulation for PP |
src/modalities/loss_functions.py | Overloads loss function to handle both InferenceResultBatch and tensor inputs |
src/modalities/evaluator.py | Adds pipeline schedule support to evaluation process |
src/modalities/models/model_factory.py | Fixes tensor parallelism to handle missing model layers in PP stages |
def __call__(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    ...

def __call__(self, *args, **kwargs) -> torch.Tensor:
Copilot AI, Sep 17, 2025
Using *args and **kwargs instead of proper overloads makes the API less type-safe and harder to understand. Consider implementing proper method overloading with specific parameter types.
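One possible shape for this, using `typing.overload` so type checkers still see the two supported call signatures. This is a sketch only; the class name is a placeholder and `InferenceResultBatch` is quoted so the snippet runs without importing Modalities.

```python
from typing import overload

import torch


class CLMCrossEntropyLossSketch:
    @overload
    def __call__(self, forward_batch: "InferenceResultBatch") -> torch.Tensor: ...

    @overload
    def __call__(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: ...

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        # Single runtime implementation; the overloads above only describe the two
        # supported call shapes to type checkers and IDEs.
        ...
```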
…es/modalities into pipeline_parallelism_fix
Awesome work with the PP integration! Functionality-wise everything looked correct (did not check the tests yet).
Regarding the integration from an architectural perspective, I left a couple of comments. I think we should do some refactoring here.
gradient_clipper=components.gradient_clipper,
global_num_tokens_per_train_step=global_num_tokens_per_train_step,
mfu_calculator=components.mfu_calculator,
num_pipeline_parallel_ranks=num_pipeline_parallel_ranks,
I would prefer if we kept the Trainer high-level and abstracted away specifics like PP.
checkpointing_interval_in_steps=components.settings.intervals.checkpointing_interval_in_steps,
evaluation_interval_in_steps=components.settings.intervals.evaluation_interval_in_steps,
training_log_interval_in_steps=components.settings.intervals.training_log_interval_in_steps,
scheduled_pipeline=components.scheduled_pipeline if components.scheduled_pipeline else None,
Same point as for the trainer. Could we wrap the scheduled pipeline instead and use the existing model interfaces?
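One hedged way to do that (class name and constructor are hypothetical, not existing Modalities interfaces): wrap the pipeline schedule in a model-shaped adapter, so the Trainer and Evaluator keep calling a single `nn.Module` and never see the schedule.

```python
import torch
import torch.nn as nn


class ScheduledPipelineModel(nn.Module):
    # Hypothetical adapter: callers keep using the usual model interface,
    # and this wrapper decides whether to run the plain module or the pipeline schedule.
    def __init__(self, stage_module: nn.Module, schedule=None):
        super().__init__()
        self.stage_module = stage_module
        self.schedule = schedule  # e.g. a torch.distributed.pipelining schedule when PP is enabled

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor | None = None):
        if self.schedule is None:
            return self.stage_module(inputs)
        # Simplified: with a real schedule, only the first stage feeds inputs and only
        # the last stage receives outputs/losses; step() handles the micro-batching.
        return self.schedule.step(inputs, target=targets)
```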
pp_mesh = get_mesh_for_parallelism_method(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP)
if pp_mesh is not None:
    if math.isinf(norm_type):
        dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group())
    else:
        total_norm **= norm_type
        dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group())
        total_norm **= 1.0 / norm_type

torch.nn.utils.clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
We should have a test for this.
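A possible single-process test of the reduction math (not the actual test in this PR): it checks that combining per-stage p-norms the way the clipper does matches the norm computed over all gradients at once.

```python
import math

import torch


def combine_stage_norms(stage_norms: list[torch.Tensor], norm_type: float) -> torch.Tensor:
    # Mirror the clipper's cross-stage reduction: max for the inf-norm,
    # p-norm recombination otherwise.
    if math.isinf(norm_type):
        return torch.stack(stage_norms).max()
    total = torch.stack([n**norm_type for n in stage_norms]).sum()
    return total ** (1.0 / norm_type)


def test_pp_norm_reduction_matches_global_norm():
    torch.manual_seed(0)
    norm_type = 2.0
    # Pretend each pipeline stage holds a disjoint set of gradients.
    stage_grads = [torch.randn(10), torch.randn(7), torch.randn(3)]
    stage_norms = [torch.linalg.vector_norm(g, ord=norm_type) for g in stage_grads]
    combined = combine_stage_norms(stage_norms, norm_type)
    expected = torch.linalg.vector_norm(torch.cat(stage_grads), ord=norm_type)
    torch.testing.assert_close(combined, expected)
```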
pp_mesh = get_mesh_for_parallelism_method(
    device_mesh=self.device_mesh, parallelism_method=ParallelismDegrees.PP
)
if pp_mesh is not None:
    if math.isinf(self.norm_type.value):
        dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group())
    else:
        total_norm **= self.norm_type.value
        dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group())
Duplicated code: this is the same cross-stage reduction as in the other clipper.
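One way to remove the duplication (helper name is hypothetical): factor the reduction into a shared function that both clippers call with their own norm type and PP group. The body mirrors the snippets above.

```python
import math

import torch
import torch.distributed as dist


def sync_grad_norm_across_pp_stages(
    total_norm: torch.Tensor, pp_group: dist.ProcessGroup, norm_type: float
) -> torch.Tensor:
    # Combine per-stage gradient norms into the global norm across the PP group.
    if math.isinf(norm_type):
        dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_group)
    else:
        total_norm **= norm_type
        dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_group)
        total_norm **= 1.0 / norm_type
    return total_norm
```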
pp_schedule_name: gpipe
batch_size: ${settings.step_profile.local_train_micro_batch_size}
microbatch_size: 1
microbatch_size: 2
Should we reference this from the top?
@@ -1,12 +1,10 @@
[project]
name = "modalities"
version = "0.3.2"
Why did we remove this? Our testing is always against 3.10 and 3.11. Do we need a more recent Python version?
...

def __call__(self, *args, **kwargs) -> torch.Tensor:
    labels, lm_logits = self._parse_arguments(args, kwargs)
Could be improved from a software engineering point of view.
def _parse_arguments(
    self,
    args: list[torch.Tensor] | list[InferenceResultBatch],
    kwargs: dict[str, torch.Tensor] | dict[str, InferenceResultBatch],
) -> tuple[torch.Tensor, torch.Tensor]:
    if len(args) == 1 and isinstance(args[0], InferenceResultBatch):
        forward_batch = args[0]
        labels = forward_batch.get_targets(self.target_key)
        lm_logits = forward_batch.get_predictions(self.prediction_key)
    elif "forward_batch" in kwargs and isinstance(kwargs["forward_batch"], InferenceResultBatch):
        forward_batch = kwargs["forward_batch"]
        labels = forward_batch.get_targets(self.target_key)
        lm_logits = forward_batch.get_predictions(self.prediction_key)
    elif len(args) == 2 and all(isinstance(arg, torch.Tensor) for arg in args):
        lm_logits, labels = args
    elif (
        "outputs" in kwargs
        and "targets" in kwargs
        and isinstance(kwargs["outputs"], torch.Tensor)
        and isinstance(kwargs["targets"], torch.Tensor)
    ):
        lm_logits = kwargs["outputs"]
        labels = kwargs["targets"]
    elif (
        len(args) == 1
        and "targets" in kwargs
        and isinstance(args[0], torch.Tensor)
        and isinstance(kwargs["targets"], torch.Tensor)
    ):
        lm_logits = args[0]
        labels = kwargs["targets"]
    else:
        raise TypeError("Invalid arguments for CLMCrossEntropyLoss.__call__")
    return labels, lm_logits
Idea: What about defining a new component "pp-loss", which takes a normal loss function and handles the PP-specific part?
Generally, I think this parsing function could be improved.
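A rough sketch of that idea (all names here are made up, and the exact way the batch object is constructed is deliberately left to a caller-supplied builder rather than assumed): a thin adapter owns the tensor-based `(outputs, targets)` signature the pipeline schedule calls with and delegates to an unchanged loss, which would make `_parse_arguments` unnecessary.

```python
import torch


class PPLossAdapter:
    # Hypothetical "pp-loss" component wrapping an existing InferenceResultBatch-based loss.
    def __init__(self, wrapped_loss, prediction_key: str, target_key: str, batch_builder):
        self.wrapped_loss = wrapped_loss
        self.prediction_key = prediction_key
        self.target_key = target_key
        # batch_builder rebuilds whatever batch object wrapped_loss expects
        # (e.g. an InferenceResultBatch); its constructor is not assumed here.
        self.batch_builder = batch_builder

    def __call__(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        forward_batch = self.batch_builder(
            predictions={self.prediction_key: outputs},
            targets={self.target_key: targets},
        )
        return self.wrapped_loss(forward_batch)
```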
device_mesh: PydanticDeviceMeshIFType | None = None


class DummyGradientClipperConfig(BaseModel):
Can we remove this class now?
with torch.no_grad():
    result_batch = model_predict_batch(model=model, batch=batch)
    loss = loss_fun(result_batch)
    if scheduled_pipeline is not None:
This is basically code duplication from the trainer. Also, I'm not a big fan of passing the scheduled_pipeline in here.
What does this PR do?

Checklist before submitting final PR
- Run the tests (python tests/tests.py)
- Update CHANGELOG_DEV.md