Here is an example traceback from trying to use heavyball in TorchTitan. It seems that heavyball is not yet compatible with FSDP2, because some of its utilities operate on plain torch.Tensors instead of DTensors.
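For context: under FSDP2 (fully_shard) the sharded parameters and their gradients are DTensors, so heavyball's plain-tensor optimizer state ends up mixed with them. A quick way to confirm what the optimizer is handed (a hypothetical check, not part of the code below; model stands for one of the model_parts, and it assumes a PyTorch build where torch.distributed.tensor is public):

from torch.distributed.tensor import DTensor

# Print whether each parameter the optimizer will receive is a DTensor
# (as produced by fully_shard) or a plain torch.Tensor.
for name, param in model.named_parameters():
    kind = "DTensor" if isinstance(param, DTensor) else type(param).__name__
    print(f"{name}: {kind}")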
Code:
import heavyball
import torch

from torchtitan.config_manager import JobConfig


# consider split between PP and non-PP
def build_optimizers(model_parts, job_config: JobConfig):
    """Wrap one optimizer per model part in an OptimizersContainer which provides a single
    step() and zero_grad() method for all the child optimizers.
    """

    def _build_optimizer(model):
        name = job_config.optimizer.name
        lr = job_config.optimizer.lr
        fused = job_config.optimizer.fused

        # Common parameters for both optimizers
        optimizer_kwargs = {
            "lr": lr,
            "betas": (0.9, 0.95),
            "weight_decay": 0.1,
            "fused": fused,
            "foreach": not fused,
        }
        if name == "Adam":
            # TODO: make the optimizer options configurable by toml/cmd args
            optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs)
        elif name == "AdamW":
            optimizer = torch.optim.AdamW(model.parameters(), **optimizer_kwargs)
        else:
            optimizer_kwargs = {
                "lr": lr,
                "betas": (0.9, 0.95),
                "weight_decay": 0.1,
                "foreach": True,
            }
            optimizer_cls = getattr(heavyball, name)
            optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)
        return optimizer

    class OptimizersContainer:
        """Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages"""

        def __init__(self, optimizers):
            self.optimizers = optimizers

        def step(self):
            for optimizer in self.optimizers:
                optimizer.step()

        def zero_grad(self):
            for optimizer in self.optimizers:
                optimizer.zero_grad()

    return OptimizersContainer([_build_optimizer(model) for model in model_parts])

Traceback:
0: [rank0]:Traceback (most recent call last):
0: [rank0]: File "/workspace/./nlp_train/torchtitan/train.py", line 540, in <module>
0: [rank0]: main(job_config)
0: [rank0]: File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
0: [rank0]: return f(*args, **kwargs)
0: [rank0]: ^^^^^^^^^^^^^^^^^^
0: [rank0]: File "/workspace/./nlp_train/torchtitan/train.py", line 392, in main
0: [rank0]: optimizers.step()
0: [rank0]: File "/workspace/nlp_train/torchtitan/optimizer/precond_soap.py", line 55, in step
0: [rank0]: optimizer.step()
0: [rank0]: File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/optim/lr_scheduler.py", line 140, in wrapper
0: [rank0]: return func.__get__(opt, opt.__class__)(*args, **kwargs)
0: [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]: File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/optim/optimizer.py", line 494, in wrapper
0: [rank0]: out = func(*args, **kwargs)
0: [rank0]: ^^^^^^^^^^^^^^^^^^^^^
0: [rank0]: File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/heavyball/utils.py", line 446, in step
0: [rank0]: self._step(group)
0: [rank0]: File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/heavyball/precond_schedule_palm_foreach_soap.py", line 63, in _step
0: [rank0]: update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
0: [rank0]: File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/heavyball/utils.py", line 354, in update_preconditioner
0: [rank0]: compute_ggt(grad, state['GG'], max_precond_dim, precondition_1d, beta)
0: [rank0]: File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/heavyball/utils.py", line 23, in _fn
0: [rank0]: return func(*args, **kwargs)
0: [rank0]: ^^^^^^^^^^^^^^^^^^^^^
0: [rank0]: File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/heavyball/utils.py", line 331, in compute_ggt
0: [rank0]: GG[idx].lerp_(promote(outer_product), 1 - beta)
0: [rank0]: File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/_compile.py", line 32, in inner
0: [rank0]: return disable_fn(*args, **kwargs)
0: [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]: File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 721, in _fn
0: [rank0]: return fn(*args, **kwargs)
0: [rank0]: ^^^^^^^^^^^^^^^^^^^
0: [rank0]: File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 343, in __torch_dispatch__
0: [rank0]: return DTensor._op_dispatcher.dispatch(
0: [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]: File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 166, in dispatch
0: [rank0]: op_info = self.unwrap_to_op_info(op_call, args, kwargs)
0: [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]: File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 371, in unwrap_to_op_info
0: [rank0]: self._try_replicate_spec_for_scalar_tensor(op_call, arg, mesh)
0: [rank0]: File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 470, in _try_replicate_spec_for_scalar_tensor
0: [rank0]: raise RuntimeError(
0: [rank0]:RuntimeError: aten.lerp_.Scalar: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!
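The failing op is GG[idx].lerp_(promote(outer_product), 1 - beta) in heavyball's compute_ggt, i.e. an in-place lerp that mixes a plain torch.Tensor with a DTensor (the traceback doesn't show which side is the plain tensor). The same RuntimeError can be reproduced in isolation; below is a minimal sketch under the assumption that the preconditioner state is the plain-tensor side, using a single-rank gloo group and a replicated DTensor purely for illustration (in the real run the grads are sharded DTensors produced by fully_shard):

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, distribute_tensor

# One-rank process group so a DeviceMesh can be built (illustration only).
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

grad = distribute_tensor(torch.randn(8, 4), mesh, [Replicate()])  # DTensor, like an FSDP2 grad
gg = torch.zeros(4, 4)                                            # plain-tensor preconditioner state

outer = grad.transpose(0, 1) @ grad  # outer-product-style update, still a DTensor
gg.lerp_(outer, 0.05)                # RuntimeError: aten.lerp_.Scalar: got mixed torch.Tensor and DTensor ...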