From 6998abcf9ec0a08ce0710c2d853cc574f5d88fce Mon Sep 17 00:00:00 2001
From: Quentin Anthony
Date: Tue, 30 Sep 2025 23:09:16 -0700
Subject: [PATCH 1/2] tp rough draft

---
 neurons/miner.py    |  24 +++++
 src/tplr/neurons.py | 210 ++++++++++++++++----------------------------
 2 files changed, 101 insertions(+), 133 deletions(-)

diff --git a/neurons/miner.py b/neurons/miner.py
index 0011d33d8..b8b992ca4 100644
--- a/neurons/miner.py
+++ b/neurons/miner.py
@@ -37,6 +37,7 @@
 import uvloop
 from torch.amp.grad_scaler import GradScaler
 from torch.distributed.tensor import DTensor as DT
+from torchtitan.distributed import ParallelDims
 
 import tplr
 from neurons import BaseNode, Trainer
@@ -210,6 +211,29 @@ def __init__(self):
         self.dp_replicate = int(getattr(tt, "dp_replicate", 1))
         self.dp_shard = int(getattr(tt, "dp_shard", 1))
 
+        self.parallel_dims = ParallelDims(
+            dp_replicate=self.dp_replicate,
+            dp_shard=self.dp_shard,
+            cp=self.cp_degree,
+            tp=self.tp_degree,
+            pp=self.pp_degree,
+            ep=1,
+            etp=1,
+            world_size=self.world_size,
+        )
+
+        world_mesh = self.parallel_dims.world_mesh
+
+        if self.parallel_dims.tp_enabled:
+            from torchtitan.models.llama3.infra.parallelize import apply_tp
+            apply_tp(
+                self.model,
+                world_mesh["tp"],
+                loss_parallel=False,  # Set based on your needs
+                enable_float8_tensorwise_tp=False,
+            )
+
+
         # Init compression
         self.transformer = tplr.compress.ChunkingTransformer(
             self.model, target_chunk=self.hparams.target_chunk
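The degrees passed to ParallelDims must factor the world size (dp_replicate * dp_shard * cp * tp * pp == world_size), and world_mesh is the named DeviceMesh torchtitan builds from them, so world_mesh["tp"] is the one-dimensional tensor-parallel submesh that apply_tp shards the attention and MLP blocks over. A minimal sketch of what that mesh amounts to in plain PyTorch, assuming 8 GPUs with dp_shard=2 and tp=4 (the degrees are illustrative, not values from this patch):

    # Sketch only: the layout parallel_dims.world_mesh produces for a
    # dp_shard x tp run; launch with `torchrun --nproc_per_node=8`.
    from torch.distributed.device_mesh import init_device_mesh

    world_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp_shard", "tp"))
    tp_mesh = world_mesh["tp"]  # 1-D submesh of 4 ranks, as passed to apply_tp
    assert tp_mesh.size() == 4
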
diff --git a/src/tplr/neurons.py b/src/tplr/neurons.py
index adf3e8bb7..aafc6ecc1 100644
--- a/src/tplr/neurons.py
+++ b/src/tplr/neurons.py
@@ -55,11 +55,8 @@ def prepare_gradient_dict(miner: "Miner", step_window: int, null_round: bool = F
         step_window: Current window number
         null_round: If True, this is a null/warmup round and error feedback should be cleared
     """
-    
+
     # ------------ helpers ------------
-    def ddp_initialized():
-        return dist.is_available() and dist.is_initialized()
-
     def is_dtensor(x):
         try:
             from torch.distributed._tensor import DTensor  # type: ignore[attr-defined]
@@ -67,158 +64,105 @@ def is_dtensor(x):
             return isinstance(x, DTensor)
         except Exception:
             return type(x).__name__ in {"DTensor", "DistributedTensor", "DT"}
-    
-    def get_mesh_group(x):
-        if not is_dtensor(x):
-            return None
-        mesh = getattr(x, "device_mesh", None)
-        if mesh is None:
-            spec = getattr(x, "_spec", None)
-            mesh = getattr(spec, "mesh", None)
-        if mesh is not None:
-            try:
-                return mesh.get_group()
-            except Exception:
-                pass
-        return dist.group.WORLD if ddp_initialized() else None
-
-    def barrier(group=None):
-        if ddp_initialized() and group is not None:
-            dist.barrier(group=group)
-    
+
+    def get_local_tensor(tensor):
+        """Get local shard if DTensor, otherwise return as-is"""
+        if is_dtensor(tensor):
+            return tensor.to_local()
+        return tensor
+
     # ------------ start ------------
     gradient, xshapes, totalks = {}, {}, {}
     lr = float(miner.hparams.outer_learning_rate)
     use_dct = getattr(miner.hparams, "use_dct", False)
     topk = getattr(miner.hparams, "topk_compression", 32)
-    
-    if isinstance(miner.model, torch.nn.parallel.DistributedDataParallel):
-        model_iterator = miner.model.module.named_parameters()
-    else:
-        model_iterator = miner.model.named_parameters()
-
-    # Build params dict once to avoid repeated iteration
-    params_dict = dict(miner.model.named_parameters())
-
-    # Batch load all error feedback tensors to GPU
-    for n in miner.owned_params:
-        if miner.error_feedback.get(n, None) is not None:
-            if miner.error_feedback[n].is_cuda:
-                continue
-            # Get the device from the corresponding parameter
-            param = params_dict.get(n)
-            if param is not None:
-                miner.error_feedback[n] = miner.error_feedback[n].to(
-                    param.device, non_blocking=True
-                )
-
-    for _, (n, p) in enumerate(model_iterator, 1):
-        owned = n in miner.owned_params
-        p_is_dt = is_dtensor(p)
-        g = getattr(p, "grad", None)
-        g_is_dt = is_dtensor(g)
-
-        # --- 1) Grad full_tensor rendezvous (GFULL) ---
-        if g_is_dt:
-            grp_g = get_mesh_group(g)
-            barrier(grp_g)
-            assert g is not None
-            grad_full = g.full_tensor().to(p.device)
-            barrier(grp_g)
-        else:
-            if g is None and not p_is_dt:
+
+    model_iterator = (
+        miner.model.module.named_parameters()
+        if isinstance(miner.model, torch.nn.parallel.DistributedDataParallel)
+        else miner.model.named_parameters()
+    )
+
+    rank = miner.rank if hasattr(miner, 'rank') else 0
+    world_size = miner.world_size if hasattr(miner, 'world_size') else 1
+
+    for idx, (name, param) in enumerate(model_iterator):
+        param_is_dtensor = is_dtensor(param)
+
+        if not param_is_dtensor:
+            if idx % world_size != rank:
                 continue
-            assert g is not None, f"p.grad is None for {n}"
-            grad_full = g.to(p.device)
-
-        # Non-owners: after participating in grad collective, drop grad and continue.
-        if not owned:
-            p.grad = None
+
+        grad = getattr(param, "grad", None)
+        if grad is None:
             continue
-
-        # --- 3) Momentum buffer update (owner only) ---
-        # Handle DTensor error feedback by creating new regular tensor if needed
-        error_feedback = miner.error_feedback[n]
-        if error_feedback is None:
-            error_feedback = torch.zeros_like(grad_full, device=p.device)
-        elif error_feedback.device != p.device:
-            # Should already be on GPU from batch load, but handle edge cases
-            error_feedback = error_feedback.to(p.device)
-
-        # Clear error feedback during null rounds to prevent accumulation of invalid gradients
+
+        local_grad = get_local_tensor(grad)
+        local_param = get_local_tensor(param)
+
+        if name not in miner.error_feedback or miner.error_feedback[name] is None:
+            miner.error_feedback[name] = torch.zeros_like(
+                local_grad, device=local_param.device
+            )
+
+        error_feedback = miner.error_feedback[name]
+
+        if error_feedback.shape != local_grad.shape:
+            tplr.logger.warning(
+                f"Error feedback shape mismatch for {name}: "
+                f"{error_feedback.shape} vs {local_grad.shape}, reinitializing"
+            )
+            error_feedback = torch.zeros_like(local_grad, device=local_param.device)
+            miner.error_feedback[name] = error_feedback
+
+        if error_feedback.device != local_param.device:
+            error_feedback = error_feedback.to(local_param.device)
+            miner.error_feedback[name] = error_feedback
+
         if null_round:
             error_feedback.zero_()
         else:
             error_feedback.mul_(miner.hparams.momentum_decay)
-            error_feedback.add_(grad_full, alpha=lr)
-
-        # --- 4) Encode & compress (owner only) ---
+            error_feedback.add_(local_grad, alpha=lr)
+
         encoded = miner.transformer.encode(error_feedback, use_dct=use_dct)
-
         idxs, vals, xshape, totalk, quant_params = miner.compressor.compress(
             encoded, topk
         )
         del encoded
-
-        # --- 5) Decompress reference (owner only) ---
-        # Pass p directly - decompress only uses p.device and p.dtype
+
         decompressed = miner.compressor.decompress(
-            p, idxs, vals, xshape, totalk, quant_params
+            local_param, idxs, vals, xshape, totalk, quant_params
         )
-
-        # --- 6) Decode & error-feedback update (owner only) ---
+
         transmit_grad = miner.transformer.decode(decompressed, use_dct=use_dct)
         del decompressed
         error_feedback.sub_(transmit_grad)
-        # Keep error feedback on GPU for now, batch offload later
-        miner.error_feedback[n] = error_feedback
-        del transmit_grad, error_feedback
-
-        # --- 7) Pack outputs (move compressed artifacts to CPU asynchronously) ---
-        # Using non_blocking=True for async D2H transfers when CUDA is available
-        if isinstance(idxs, torch.Tensor):
-            if torch.cuda.is_available():
-                cpu_idxs = torch.empty_like(idxs, device="cpu", pin_memory=True)
-                cpu_idxs.copy_(idxs, non_blocking=True)
-                gradient[n + "idxs"] = cpu_idxs
-            else:
-                gradient[n + "idxs"] = idxs.cpu()
-        else:
-            gradient[n + "idxs"] = idxs
-
-        if isinstance(vals, torch.Tensor):
-            if torch.cuda.is_available():
-                cpu_vals = torch.empty_like(vals, device="cpu", pin_memory=True)
-                cpu_vals.copy_(vals, non_blocking=True)
-                gradient[n + "vals"] = cpu_vals
-            else:
-                gradient[n + "vals"] = vals.cpu()
-        else:
-            gradient[n + "vals"] = vals
-        gradient[n + "quant_params"] = quant_params
-        xshapes[n] = xshape
-        totalks[n] = totalk
-
-        # Clear per-param grad
-        p.grad = None
-
-    # Batch offload all error feedback tensors to CPU with pinned memory
-    for name in miner.error_feedback:
-        if (
-            miner.error_feedback[name] is not None
-            and miner.error_feedback[name].is_cuda
-        ):
-            # Copy to the pre-allocated pinned buffer
-            miner.error_feedback_cpu_buffers[name].copy_(
-                miner.error_feedback[name], non_blocking=True
-            )
-            miner.error_feedback[name] = miner.error_feedback_cpu_buffers[name]
-
-    # Single synchronization at the end for all async operations
+        del transmit_grad
+
+        param_key = name if not param_is_dtensor else f"{name}_tp{rank}"
+
+        gradient[param_key + "idxs"] = (
+            idxs.cpu() if isinstance(idxs, torch.Tensor) else idxs
+        )
+        gradient[param_key + "vals"] = (
+            vals.cpu() if isinstance(vals, torch.Tensor) else vals
+        )
+        gradient[param_key + "quant_params"] = quant_params
+        xshapes[param_key] = xshape
+        totalks[param_key] = totalk
+
+        param.grad = None
+
     if torch.cuda.is_available():
-        torch.cuda.synchronize()
         torch.cuda.empty_cache()
-    gradient["metadata"] = {"window": step_window}
+
+    gradient["metadata"] = {
+        "window": step_window,
+        "rank": rank,
+        "world_size": world_size,
+    }
+
     return gradient, xshapes, totalks
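The heart of this change is the swap from the full_tensor() rendezvous (an all-gather per parameter across the TP mesh) to to_local(): each rank now compresses only the shard it already holds, with no collective at all. A runnable sketch of the two code paths, assuming two GPUs under torchrun --nproc_per_node=2 (shapes are illustrative):

    import os

    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Shard, distribute_tensor

    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    dist.init_process_group("nccl")
    mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))

    torch.manual_seed(0)  # same logical tensor on every rank
    full = torch.randn(8, 4, device="cuda")
    dt = distribute_tensor(full, mesh, [Shard(0)])  # row-sharded DTensor

    local = dt.to_local()     # this rank's (4, 4) shard, no communication
    whole = dt.full_tensor()  # all-gathers back to the full (8, 4) tensor

The trade-off: compression work and payload per rank shrink by the TP degree, but the compressed artifacts are now per-shard, which is why the dict keys and metadata above carry the rank.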
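Since DTensor parameters are published under per-rank keys (f"{name}_tp{rank}") with rank and world_size recorded in the metadata, whatever aggregates these dicts has to regroup shards by parameter before decompressing. A hypothetical consumer-side helper; gather_tp_shards and its key layout are inferred from the packing code above, not an existing tplr API:

    def gather_tp_shards(per_rank_grads: list[dict], name: str) -> dict[int, dict]:
        """Group one parameter's compressed shards by TP rank.

        Hypothetical: assumes each dict came from prepare_gradient_dict
        above, with keys like f"{name}_tp{rank}idxs" for DTensor params.
        """
        shards: dict[int, dict] = {}
        for grads in per_rank_grads:
            rank = grads["metadata"]["rank"]
            key = f"{name}_tp{rank}"
            if key + "idxs" not in grads:
                continue  # this rank held no shard of `name`
            shards[rank] = {
                "idxs": grads[key + "idxs"],
                "vals": grads[key + "vals"],
                "quant_params": grads[key + "quant_params"],
            }
        return shards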
From f0982904587c5273a34e0f25877b003386c48a68 Mon Sep 17 00:00:00 2001
From: Quentin Anthony
Date: Tue, 30 Sep 2025 23:10:32 -0700
Subject: [PATCH 2/2] remove testing args

---
 neurons/miner.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/neurons/miner.py b/neurons/miner.py
index b8b992ca4..a36dc15b8 100644
--- a/neurons/miner.py
+++ b/neurons/miner.py
@@ -229,8 +229,7 @@ def __init__(self):
             apply_tp(
                 self.model,
                 world_mesh["tp"],
-                loss_parallel=False,  # Set based on your needs
-                enable_float8_tensorwise_tp=False,
+                loss_parallel=False,
             )
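loss_parallel=False keeps the TP output projection producing replicated logits, so the existing loss code works unchanged; that is the default this patch settles on. If it were flipped to True later, the loss would have to be computed under PyTorch's loss_parallel context so cross-entropy runs directly on vocab-sharded logits; a sketch (logits and labels are placeholder names):

    # Only needed if apply_tp were called with loss_parallel=True.
    import torch.nn.functional as F
    from torch.distributed.tensor.parallel import loss_parallel

    with loss_parallel():
        loss = F.cross_entropy(logits, labels)  # logits stay sharded on vocab dim
        loss.backward()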