Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions neurons/miner.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import uvloop
from torch.amp.grad_scaler import GradScaler
from torch.distributed.tensor import DTensor as DT
from torchtitan.distributed import ParallelDims

import tplr
from neurons import BaseNode, Trainer
Expand Down Expand Up @@ -210,6 +211,28 @@ def __init__(self):
self.dp_replicate = int(getattr(tt, "dp_replicate", 1))
self.dp_shard = int(getattr(tt, "dp_shard", 1))

self.parallel_dims = ParallelDims(
dp_replicate=self.dp_replicate,
dp_shard=self.dp_shard,
cp=self.cp_degree,
tp=self.tp_degree,
pp=self.pp_degree,
ep=1,
etp=1,
world_size=self.world_size,
)

world_mesh = self.parallel_dims.world_mesh

if self.parallel_dims.tp_enabled:
from torchtitan.models.llama3.infra.parallelize import apply_tp
apply_tp(
self.model,
world_mesh["tp"],
loss_parallel=False,
)


# Init compression
self.transformer = tplr.compress.ChunkingTransformer(
self.model, target_chunk=self.hparams.target_chunk
Expand Down
210 changes: 77 additions & 133 deletions src/tplr/neurons.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,170 +55,114 @@ def prepare_gradient_dict(miner: "Miner", step_window: int, null_round: bool = F
step_window: Current window number
null_round: If True, this is a null/warmup round and error feedback should be cleared
"""

# ------------ helpers ------------
def ddp_initialized():
    """Report whether torch.distributed can be used right now.

    Returns:
        A falsy value when the distributed package is unavailable in this
        build; otherwise whether the default process group has been
        initialized.
    """
    if not dist.is_available():
        return False
    return dist.is_initialized()

def is_dtensor(x):
    """Return True when ``x`` is a torch ``DTensor``.

    The ``DTensor`` class is looked up in ``torch.distributed.tensor`` first
    and, if that import fails, in the legacy ``torch.distributed._tensor``
    module. When neither import succeeds the check degrades to comparing the
    class name of ``x`` against a small set of known names.

    Args:
        x: Any object (tensor, gradient, ``None``, ...).

    Returns:
        bool: Whether ``x`` is recognised as a distributed tensor.
    """
    try:
        from torch.distributed.tensor import DTensor
    except Exception:
        # Deliberately broad: this helper is best-effort and must not raise
        # just because the distributed package cannot be imported here.
        try:
            from torch.distributed._tensor import DTensor  # type: ignore[attr-defined]
        except Exception:
            # No importable DTensor class at all; fall back to a name check.
            return type(x).__name__ in {"DTensor", "DistributedTensor", "DT"}

    # Kept outside the try blocks so only import failures are swallowed.
    return isinstance(x, DTensor)

def get_mesh_group(x):
    """Pick the process group to synchronise on for ``x``.

    Returns:
        ``None`` when ``x`` is not a DTensor. Otherwise the result of
        ``get_group()`` on the tensor's device mesh (read from
        ``x.device_mesh`` or, failing that, ``x._spec.mesh``). If no mesh is
        found or ``get_group()`` raises, the world group is returned when
        torch.distributed is initialized, else ``None``.
    """
    if not is_dtensor(x):
        return None

    device_mesh = getattr(x, "device_mesh", None)
    if device_mesh is None:
        # Fall back to the mesh stored on the tensor's private spec, if any.
        device_mesh = getattr(getattr(x, "_spec", None), "mesh", None)

    if device_mesh is not None:
        try:
            return device_mesh.get_group()
        except Exception:
            # Best-effort: any failure falls through to the world group below.
            pass

    if ddp_initialized():
        return dist.group.WORLD
    return None

def barrier(group=None):
    """Run ``dist.barrier`` on ``group``; do nothing when it cannot apply.

    Args:
        group: Process group to synchronise. With ``None``, or when
            torch.distributed is not initialized, the call is a no-op.
    """
    if not ddp_initialized() or group is None:
        return
    dist.barrier(group=group)


def get_local_tensor(tensor):
    """Return ``tensor.to_local()`` for a DTensor; pass anything else through unchanged."""
    return tensor.to_local() if is_dtensor(tensor) else tensor

# ------------ start ------------
gradient, xshapes, totalks = {}, {}, {}
lr = float(miner.hparams.outer_learning_rate)
use_dct = getattr(miner.hparams, "use_dct", False)
topk = getattr(miner.hparams, "topk_compression", 32)

if isinstance(miner.model, torch.nn.parallel.DistributedDataParallel):
model_iterator = miner.model.module.named_parameters()
else:
model_iterator = miner.model.named_parameters()

# Build params dict once to avoid repeated iteration
params_dict = dict(miner.model.named_parameters())

# Batch load all error feedback tensors to GPU
for n in miner.owned_params:
if miner.error_feedback.get(n, None) is not None:
if miner.error_feedback[n].is_cuda:
continue
# Get the device from the corresponding parameter
param = params_dict.get(n)
if param is not None:
miner.error_feedback[n] = miner.error_feedback[n].to(
param.device, non_blocking=True
)

for _, (n, p) in enumerate(model_iterator, 1):
owned = n in miner.owned_params
p_is_dt = is_dtensor(p)
g = getattr(p, "grad", None)
g_is_dt = is_dtensor(g)

# --- 1) Grad full_tensor rendezvous (GFULL) ---
if g_is_dt:
grp_g = get_mesh_group(g)
barrier(grp_g)
assert g is not None
grad_full = g.full_tensor().to(p.device)
barrier(grp_g)
else:
if g is None and not p_is_dt:

model_iterator = (
miner.model.module.named_parameters()
if isinstance(miner.model, torch.nn.parallel.DistributedDataParallel)
else miner.model.named_parameters()
)

rank = miner.rank if hasattr(miner, 'rank') else 0
world_size = miner.world_size if hasattr(miner, 'world_size') else 1

for idx, (name, param) in enumerate(model_iterator):
param_is_dtensor = is_dtensor(param)

if not param_is_dtensor:
if idx % world_size != rank:
continue
assert g is not None, f"p.grad is None for {n}"
grad_full = g.to(p.device)

# Non-owners: after participating in grad collective, drop grad and continue.
if not owned:
p.grad = None

grad = getattr(param, "grad", None)
if grad is None:
continue

# --- 3) Momentum buffer update (owner only) ---
# Handle DTensor error feedback by creating new regular tensor if needed
error_feedback = miner.error_feedback[n]
if error_feedback is None:
error_feedback = torch.zeros_like(grad_full, device=p.device)
elif error_feedback.device != p.device:
# Should already be on GPU from batch load, but handle edge cases
error_feedback = error_feedback.to(p.device)

# Clear error feedback during null rounds to prevent accumulation of invalid gradients

local_grad = get_local_tensor(grad)
local_param = get_local_tensor(param)

if name not in miner.error_feedback or miner.error_feedback[name] is None:
miner.error_feedback[name] = torch.zeros_like(
local_grad, device=local_param.device
)

error_feedback = miner.error_feedback[name]

if error_feedback.shape != local_grad.shape:
tplr.logger.warning(
f"Error feedback shape mismatch for {name}: "
f"{error_feedback.shape} vs {local_grad.shape}, reinitializing"
)
error_feedback = torch.zeros_like(local_grad, device=local_param.device)
miner.error_feedback[name] = error_feedback

if error_feedback.device != local_param.device:
error_feedback = error_feedback.to(local_param.device)
miner.error_feedback[name] = error_feedback

if null_round:
error_feedback.zero_()
else:
error_feedback.mul_(miner.hparams.momentum_decay)
error_feedback.add_(grad_full, alpha=lr)

# --- 4) Encode & compress (owner only) ---
error_feedback.add_(local_grad, alpha=lr)

encoded = miner.transformer.encode(error_feedback, use_dct=use_dct)

idxs, vals, xshape, totalk, quant_params = miner.compressor.compress(
encoded, topk
)
del encoded

# --- 5) Decompress reference (owner only) ---
# Pass p directly - decompress only uses p.device and p.dtype

decompressed = miner.compressor.decompress(
p, idxs, vals, xshape, totalk, quant_params
local_param, idxs, vals, xshape, totalk, quant_params
)

# --- 6) Decode & error-feedback update (owner only) ---

transmit_grad = miner.transformer.decode(decompressed, use_dct=use_dct)
del decompressed
error_feedback.sub_(transmit_grad)
# Keep error feedback on GPU for now, batch offload later
miner.error_feedback[n] = error_feedback
del transmit_grad, error_feedback

# --- 7) Pack outputs (move compressed artifacts to CPU asynchronously) ---
# Using non_blocking=True for async D2H transfers when CUDA is available
if isinstance(idxs, torch.Tensor):
if torch.cuda.is_available():
cpu_idxs = torch.empty_like(idxs, device="cpu", pin_memory=True)
cpu_idxs.copy_(idxs, non_blocking=True)
gradient[n + "idxs"] = cpu_idxs
else:
gradient[n + "idxs"] = idxs.cpu()
else:
gradient[n + "idxs"] = idxs

if isinstance(vals, torch.Tensor):
if torch.cuda.is_available():
cpu_vals = torch.empty_like(vals, device="cpu", pin_memory=True)
cpu_vals.copy_(vals, non_blocking=True)
gradient[n + "vals"] = cpu_vals
else:
gradient[n + "vals"] = vals.cpu()
else:
gradient[n + "vals"] = vals
gradient[n + "quant_params"] = quant_params
xshapes[n] = xshape
totalks[n] = totalk

# Clear per-param grad
p.grad = None

# Batch offload all error feedback tensors to CPU with pinned memory
for name in miner.error_feedback:
if (
miner.error_feedback[name] is not None
and miner.error_feedback[name].is_cuda
):
# Copy to the pre-allocated pinned buffer
miner.error_feedback_cpu_buffers[name].copy_(
miner.error_feedback[name], non_blocking=True
)
miner.error_feedback[name] = miner.error_feedback_cpu_buffers[name]

# Single synchronization at the end for all async operations
del transmit_grad

param_key = name if not param_is_dtensor else f"{name}_tp{rank}"

gradient[param_key + "idxs"] = (
idxs.cpu() if isinstance(idxs, torch.Tensor) else idxs
)
gradient[param_key + "vals"] = (
vals.cpu() if isinstance(vals, torch.Tensor) else vals
)
gradient[param_key + "quant_params"] = quant_params
xshapes[param_key] = xshape
totalks[param_key] = totalk

param.grad = None

if torch.cuda.is_available():
torch.cuda.synchronize()
torch.cuda.empty_cache()
gradient["metadata"] = {"window": step_window}

gradient["metadata"] = {
"window": step_window,
"rank": rank,
"world_size": world_size,
}

return gradient, xshapes, totalks


Expand Down
Loading