
[Feature] Make HeavyBall compatible with FSDP2 (DTensor) #15


Description

@casper-hansen

Here is an example traceback from trying to use heavyball in TorchTitan. It seems that heavyball is not yet compatible with FSDP2, because some of its utilities operate on plain torch.Tensors instead of DTensors.

Code:

import heavyball
import torch
from torchtitan.config_manager import JobConfig

# consider split between PP and non-PP
def build_optimizers(model_parts, job_config: JobConfig):
    """Wrap one optimizer per model part in an OptimizersContainer which provides a single
    step() and zero_grad() method for all the child optimizers.
    """

    def _build_optimizer(model):
        name = job_config.optimizer.name
        lr = job_config.optimizer.lr
        fused = job_config.optimizer.fused

        # Common parameters for both optimizers
        optimizer_kwargs = {
            "lr": lr,
            "betas": (0.9, 0.95),
            "weight_decay": 0.1,
            "fused": fused,
            "foreach": not fused,
        }
        if name == "Adam":
            # TODO: make the optimizer options configurable by toml/cmd args
            optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs)
        elif name == "AdamW":
            optimizer = torch.optim.AdamW(model.parameters(), **optimizer_kwargs)
        else:
            optimizer_kwargs = {
                "lr": lr,
                "betas": (0.9, 0.95),
                "weight_decay": 0.1,
                "foreach": True,
            }
            optimizer_cls = getattr(heavyball, name)
            optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)

        return optimizer

    class OptimizersContainer:
        """Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages"""

        def __init__(self, optimizers):
            self.optimizers = optimizers

        def step(self):
            for optimizer in self.optimizers:
                optimizer.step()

        def zero_grad(self):
            for optimizer in self.optimizers:
                optimizer.zero_grad()

    return OptimizersContainer([_build_optimizer(model) for model in model_parts])
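For context, a minimal sketch (my assumption, not part of TorchTitan or heavyball) of why the mixing happens: under FSDP2, fully_shard turns parameters and their gradients into DTensors, so any optimizer state allocated as a plain torch.Tensor will later meet a DTensor inside in-place ops such as lerp_. The import path for fully_shard is an assumption here (torch.distributed.fsdp in recent releases; older builds expose it from torch.distributed._composable.fsdp).

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # assumption: recent torch; older versions use torch.distributed._composable.fsdp
from torch.distributed.tensor import DTensor

def show_fsdp2_dtensors(world_size: int):
    # assumes launch via torchrun so a process group / device mesh can be built
    mesh = init_device_mesh("cuda", (world_size,))
    model = nn.Linear(16, 16).cuda()
    fully_shard(model, mesh=mesh)

    model(torch.randn(4, 16, device="cuda")).sum().backward()

    for p in model.parameters():
        # after fully_shard, both the parameter and its gradient are DTensors,
        # so plain-tensor optimizer state cannot be mixed with them in distributed ops
        assert isinstance(p, DTensor)
        assert isinstance(p.grad, DTensor)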

Traceback:

0: [rank0]:Traceback (most recent call last):
0: [rank0]:  File "/workspace/./nlp_train/torchtitan/train.py", line 540, in <module>
0: [rank0]:    main(job_config)
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
0: [rank0]:    return f(*args, **kwargs)
0: [rank0]:           ^^^^^^^^^^^^^^^^^^
0: [rank0]:  File "/workspace/./nlp_train/torchtitan/train.py", line 392, in main
0: [rank0]:    optimizers.step()
0: [rank0]:  File "/workspace/nlp_train/torchtitan/optimizer/precond_soap.py", line 55, in step
0: [rank0]:    optimizer.step()
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/optim/lr_scheduler.py", line 140, in wrapper
0: [rank0]:    return func.__get__(opt, opt.__class__)(*args, **kwargs)
0: [rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/optim/optimizer.py", line 494, in wrapper
0: [rank0]:    out = func(*args, **kwargs)
0: [rank0]:          ^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/heavyball/utils.py", line 446, in step
0: [rank0]:    self._step(group)
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/heavyball/precond_schedule_palm_foreach_soap.py", line 63, in _step
0: [rank0]:    update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/heavyball/utils.py", line 354, in update_preconditioner
0: [rank0]:    compute_ggt(grad, state['GG'], max_precond_dim, precondition_1d, beta)
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/heavyball/utils.py", line 23, in _fn
0: [rank0]:    return func(*args, **kwargs)
0: [rank0]:           ^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/heavyball/utils.py", line 331, in compute_ggt
0: [rank0]:    GG[idx].lerp_(promote(outer_product), 1 - beta)
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/_compile.py", line 32, in inner
0: [rank0]:    return disable_fn(*args, **kwargs)
0: [rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 721, in _fn
0: [rank0]:    return fn(*args, **kwargs)
0: [rank0]:           ^^^^^^^^^^^^^^^^^^^
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 343, in __torch_dispatch__
0: [rank0]:    return DTensor._op_dispatcher.dispatch(
0: [rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 166, in dispatch
0: [rank0]:    op_info = self.unwrap_to_op_info(op_call, args, kwargs)
0: [rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 371, in unwrap_to_op_info
0: [rank0]:    self._try_replicate_spec_for_scalar_tensor(op_call, arg, mesh)
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 470, in _try_replicate_spec_for_scalar_tensor
0: [rank0]:    raise RuntimeError(
0: [rank0]:RuntimeError: aten.lerp_.Scalar: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!
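For what it's worth, here is a minimal repro sketch (my guess at the failure mode, with made-up stand-in variables) of the error above: an in-place lerp_ between a plain torch.Tensor state buffer (like heavyball's GG) and a DTensor operand raises the same RuntimeError, while keeping the state as a DTensor with a matching placement is one possible direction for a fix.

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, distribute_tensor

def repro_mixed_lerp():
    # assumes launch via torchrun with CUDA available and a process group initialized
    mesh = init_device_mesh("cuda", (dist.get_world_size(),))

    # stand-in for an FSDP2 gradient: a DTensor
    grad = distribute_tensor(torch.randn(8, 4, device="cuda"), mesh, [Replicate()])
    # stand-in for heavyball's preconditioner state: a plain tensor
    gg = torch.zeros(4, 4, device="cuda")

    outer = grad.T @ grad  # still a DTensor
    try:
        gg.lerp_(outer, 0.1)  # mixed torch.Tensor / DTensor -> RuntimeError like the one above
    except RuntimeError as e:
        print("reproduced:", e)

    # possible fix direction: keep the optimizer state as a DTensor as well
    gg_dt = distribute_tensor(torch.zeros(4, 4, device="cuda"), mesh, [Replicate()])
    gg_dt.lerp_(outer, 0.1)  # both operands are DTensors, so this dispatches cleanly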
