From a4c9da2c011487df2131e56db64d08833cd759f0 Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Thu, 20 Nov 2025 16:09:53 -0800 Subject: [PATCH 01/11] Added first non-working version --- flashinfer/comm/__init__.py | 13 + flashinfer/comm/allreduce.py | 809 +++++++++++++++++++++++++++++++++++ 2 files changed, 822 insertions(+) create mode 100644 flashinfer/comm/allreduce.py diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py index f7ae3754ac..b0e9dfd0a4 100644 --- a/flashinfer/comm/__init__.py +++ b/flashinfer/comm/__init__.py @@ -39,4 +39,17 @@ from .vllm_ar import register_buffer as vllm_register_buffer from .vllm_ar import register_graph_buffers as vllm_register_graph_buffers +# Unified AllReduce Fusion API +from .allreduce import AllReduceFusionContext as AllReduceFusionContext +from .allreduce import AllReduceFusionWorkspace as AllReduceFusionWorkspace +from .allreduce import MNNVLAllReduceFusionWorkspace as MNNVLAllReduceFusionWorkspace +from .allreduce import TRTLLMAllReduceFusionWorkspace as TRTLLMAllReduceFusionWorkspace +from .allreduce import allreduce_fusion as allreduce_fusion +from .allreduce import ( + create_allreduce_fusion_workspace as create_allreduce_fusion_workspace, +) +from .allreduce import ( + destroy_allreduce_fusion_workspace as destroy_allreduce_fusion_workspace, +) + # from .mnnvl import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py new file mode 100644 index 0000000000..a8f6498c0c --- /dev/null +++ b/flashinfer/comm/allreduce.py @@ -0,0 +1,809 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +""" +Unified AllReduce Fusion API + +This module provides a unified interface for AllReduce + RMSNorm fusion operations +across different backends (TensorRT-LLM, MNNVL). + +Example usage: + >>> # Auto-select best backend based on topology + >>> workspace = create_allreduce_fusion_workspace( + ... backend="auto", + ... world_size=8, + ... rank=0, + ... max_token_num=2048, + ... hidden_dim=4096, + ... dtype=torch.bfloat16, + ... topology="single_node" + ... ) + >>> + >>> # Perform AllReduce + RMSNorm fusion + >>> prenorm = torch.empty_like(hidden_states) + >>> normed = torch.empty_like(hidden_states) + >>> output = allreduce_fusion( + ... input=hidden_states, + ... workspace=workspace, + ... launch_with_pdl=True, + ... residual_out=prenorm, + ... norm_out=normed, + ... residual_in=residual, + ... rms_gamma=norm_weight + ... 
) + >>> + >>> destroy_allreduce_fusion_workspace(workspace) +""" + +from typing import Union, Literal, Optional +from abc import ABC, abstractmethod + +import torch + +from ..utils import backend_requirement, supported_compute_capability + + +# ============================================================================ +# WORKSPACE BASE CLASS +# ============================================================================ + + +class AllReduceFusionWorkspace(ABC): + """Base class for AllReduce fusion workspaces.""" + + def __init__(self, world_size: int, rank: int): + self.world_size = world_size + self.rank = rank + + @property + @abstractmethod + def backend(self) -> str: + """Return backend name.""" + pass + + +class TRTLLMAllReduceFusionWorkspace(AllReduceFusionWorkspace): + """TensorRT-LLM workspace for AllReduce fusion.""" + + def __init__(self, world_size: int, rank: int, workspace_ptrs, metadata): + super().__init__(world_size, rank) + self.workspace_ptrs = workspace_ptrs + self.metadata = metadata + + @property + def backend(self) -> str: + return "trtllm" + + +class MNNVLAllReduceFusionWorkspace(AllReduceFusionWorkspace): + """MNNVL workspace for AllReduce fusion.""" + + def __init__( + self, + world_size: int, + rank: int, + multicast_buffer_ptr: int, + buffer_ptrs_dev: int, + unicast_ptr: int, + buffer_M: int, + buffer_flags, + ): + super().__init__(world_size, rank) + self.multicast_buffer_ptr = multicast_buffer_ptr + self.buffer_ptrs_dev = buffer_ptrs_dev + self.unicast_ptr = unicast_ptr + self.buffer_M = buffer_M + self.buffer_flags = buffer_flags + + @property + def backend(self) -> str: + return "mnnvl" + + +# ============================================================================ +# BACKEND CHECKS - Hard requirements for decorator +# ============================================================================ + + +@supported_compute_capability([80, 86, 89, 90, 100]) +def _trtllm_workspace_check( + backend: str, + world_size: int, + rank: int, + max_token_num: int, + hidden_dim: int, + dtype: torch.dtype, + device: Optional[torch.device], + topology: str, + **kwargs, +) -> bool: + """ + Check if trtllm backend CAN be used for workspace creation. + + Hard requirements: + - SM80+ compute capability (checked by decorator) + - Single-node topology + - Module availability + """ + # trtllm is optimized for single-node + if topology == "multi_node": + return False + + return True + + +@supported_compute_capability([90, 100]) +def _mnnvl_workspace_check( + backend: str, + world_size: int, + rank: int, + max_token_num: int, + hidden_dim: int, + dtype: torch.dtype, + device: Optional[torch.device], + topology: str, + **kwargs, +) -> bool: + """ + Check if mnnvl backend CAN be used for workspace creation. + + Hard requirements: + - SM90+ compute capability (checked by decorator) + - Multi-node topology + - Module availability + """ + # MNNVL is designed for multi-node + if topology == "single_node": + return False + + return True + + +# ============================================================================ +# HEURISTIC - Performance-based selection for decorator +# ============================================================================ + + +def _workspace_creation_heuristic( + suitable_backends: list[str], + backend: str, + world_size: int, + rank: int, + max_token_num: int, + hidden_dim: int, + dtype: torch.dtype, + device: Optional[torch.device], + topology: str, + **kwargs, +) -> list[str]: + """ + Select best backend for workspace creation based on performance. 
+ + Called by decorator after checking which backends pass requirements. + Uses benchmarking data to pick fastest option. + + Args: + suitable_backends: List of backends that passed hard requirement checks + backend: Requested backend ("auto", "trtllm", or "mnnvl") + world_size: Number of ranks + rank: Current rank + max_token_num: Maximum number of tokens + hidden_dim: Hidden dimension size + dtype: Data type + device: CUDA device + topology: Network topology ("single_node" or "multi_node") + **kwargs: Additional arguments + + Returns: + List containing the selected backend (single element) + """ + if not suitable_backends: + return [] + + if len(suitable_backends) == 1: + return suitable_backends + + # Decision tree based on benchmark data + # TODO: Replace with actual benchmarking results + + # Multi-node: MNNVL is designed for this + if topology == "multi_node": + if "mnnvl" in suitable_backends: + return ["mnnvl"] + + # Single-node scenarios + problem_size = max_token_num * hidden_dim + + # Large problems (>4M elements): trtllm optimized for throughput + if problem_size > 4 * 1024 * 1024: + if "trtllm" in suitable_backends: + return ["trtllm"] + + # Small token counts (<128): trtllm one-shot has better latency + if max_token_num < 128: + if "trtllm" in suitable_backends: + return ["trtllm"] + + # Small world sizes (<=4): trtllm one-shot efficient + if world_size <= 4: + if "trtllm" in suitable_backends: + return ["trtllm"] + + # Default: return first available + return [suitable_backends[0]] + + +# ============================================================================ +# WORKSPACE CREATION - Uses decorator for all validation +# ============================================================================ + + +@backend_requirement( + backend_checks={ + "trtllm": _trtllm_workspace_check, + "mnnvl": _mnnvl_workspace_check, + }, + heuristic_func=_workspace_creation_heuristic, +) +def create_allreduce_fusion_workspace( + backend: Literal["trtllm", "mnnvl", "auto"] = "auto", + world_size: int = None, + rank: int = None, + max_token_num: int = None, + hidden_dim: int = None, + dtype: torch.dtype = None, + device: Optional[torch.device] = None, + topology: str = "single_node", + process_group: Optional["torch.distributed.ProcessGroup"] = None, + **backend_kwargs, +) -> AllReduceFusionWorkspace: + """ + Create workspace for AllReduce fusion operations. + + Backend selection (checks + heuristics) handled by @backend_requirement decorator. 
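+    When backend="auto", the decorator first filters backends through the hard
+    requirement checks and then applies _workspace_creation_heuristic; the
+    resolved choice can be inspected afterwards via workspace.backend.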
+ + Args: + backend: Backend to use ("trtllm", "mnnvl", or "auto") + "auto" uses heuristic to select best backend based on topology + and problem size + world_size: Number of ranks in the process group + rank: Current rank ID + max_token_num: Maximum number of tokens to support + hidden_dim: Hidden dimension size + dtype: Data type for communication tensors + device: CUDA device (defaults to current CUDA device) + topology: Network topology hint for backend selection + "single_node" - All ranks on one node (default) + "multi_node" - Ranks span multiple nodes + process_group: PyTorch distributed process group + **backend_kwargs: Additional backend-specific arguments + + Returns: + Workspace object (TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace) + The workspace type determines which backend will be used in allreduce_fusion() + + Raises: + BackendSupportedError: If no suitable backend available for the configuration + ValueError: If problem size not supported for the specified backend + + Examples: + >>> # Auto-select best backend based on topology + >>> workspace = create_allreduce_fusion_workspace( + ... backend="auto", + ... world_size=8, + ... rank=0, + ... max_token_num=2048, + ... hidden_dim=4096, + ... dtype=torch.bfloat16, + ... topology="single_node" + ... ) + >>> print(workspace.backend) # "trtllm" + + >>> # Explicit backend selection + >>> workspace = create_allreduce_fusion_workspace( + ... backend="mnnvl", + ... world_size=16, + ... rank=0, + ... max_token_num=2048, + ... hidden_dim=4096, + ... dtype=torch.bfloat16, + ... topology="multi_node" + ... ) + >>> print(workspace.backend) # "mnnvl" + """ + if device is None: + device = torch.device(f"cuda:{torch.cuda.current_device()}") + + # Decorator has validated backend - now create workspace + # If backend="auto", decorator has selected the best one and stored it + + # Get actual backend (decorator resolved "auto" to concrete backend) + if backend == "auto": + # Decorator stored the selected backend in suitable_auto_backends + actual_backend = create_allreduce_fusion_workspace.suitable_auto_backends[0] + else: + actual_backend = backend + + # Create workspace for selected backend + if actual_backend == "trtllm": + from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion + + workspace = trtllm_create_ipc_workspace_for_all_reduce_fusion( + tp_size=world_size, + tp_rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + device=device, + process_group=process_group, + **backend_kwargs, + ) + # Ensure workspace has required attributes for our API + if not hasattr(workspace, "world_size"): + workspace.world_size = world_size + if not hasattr(workspace, "rank"): + workspace.rank = rank + return workspace + + elif actual_backend == "mnnvl": + # TODO: Implement create_mnnvl_allreduce_fusion_workspace + # For now, raise NotImplementedError with instructions + raise NotImplementedError( + "MNNVL workspace creation needs to be implemented. 
" + "Expected function: trtllm_mnnvl_ar.create_mnnvl_allreduce_fusion_workspace" + ) + # from .trtllm_mnnvl_ar import create_mnnvl_allreduce_fusion_workspace + # return create_mnnvl_allreduce_fusion_workspace( + # world_size=world_size, + # rank=rank, + # max_token_num=max_token_num, + # hidden_dim=hidden_dim, + # dtype=dtype, + # device=device, + # **backend_kwargs + # ) + else: + raise RuntimeError(f"Unknown backend: {actual_backend}") + + +# ============================================================================ +# WORKSPACE DESTRUCTION +# ============================================================================ + + +def destroy_allreduce_fusion_workspace(workspace: AllReduceFusionWorkspace) -> None: + """ + Destroy workspace and free resources. + + Automatically detects workspace type from the object and calls + appropriate cleanup function. + + Args: + workspace: Workspace object to destroy + + Example: + >>> workspace = create_allreduce_fusion_workspace(...) + >>> # ... use workspace ... + >>> destroy_allreduce_fusion_workspace(workspace) + """ + if isinstance(workspace, TRTLLMAllReduceFusionWorkspace): + from .trtllm_ar import trtllm_destroy_ipc_workspace_for_all_reduce_fusion + + trtllm_destroy_ipc_workspace_for_all_reduce_fusion(workspace) + elif isinstance(workspace, MNNVLAllReduceFusionWorkspace): + # TODO: Implement MNNVL workspace destruction + raise NotImplementedError("MNNVL workspace destruction not yet implemented") + # from .trtllm_mnnvl_ar import destroy_mnnvl_allreduce_fusion_workspace + # destroy_mnnvl_allreduce_fusion_workspace(workspace) + else: + raise TypeError(f"Unknown workspace type: {type(workspace)}") + + +# ============================================================================ +# MAIN API - NO backend parameter, infers from workspace type +# ============================================================================ + + +def allreduce_fusion( + input: torch.Tensor, + workspace: AllReduceFusionWorkspace, + launch_with_pdl: bool = False, + # ===== OUTPUT tensors (pre-allocated, will be filled) ===== + output: Optional[torch.Tensor] = None, + residual_out: Optional[torch.Tensor] = None, + norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, + scale_out: Optional[torch.Tensor] = None, + # ===== INPUT parameters ===== + residual_in: Optional[torch.Tensor] = None, + rms_gamma: Optional[torch.Tensor] = None, + rms_eps: float = 1e-6, + scale_factor: Optional[Union[torch.Tensor, float]] = None, + layout_code: Optional[int] = None, + # ===== Control parameters ===== + pattern: Optional[int] = None, + use_oneshot: Optional[bool] = None, + fp32_acc: bool = False, + metadata: Optional[dict] = None, +) -> torch.Tensor: + """ + AllReduce + RMSNorm fusion operation. + + Backend is automatically determined from workspace type. + No backend parameter needed! 
+ + Supports multiple fusion patterns: + - AllReduce only + - AllReduce + Residual + RMSNorm + - AllReduce + Residual + RMSNorm + Quantization (FP8/FP4) + + Args: + input: Input tensor [token_num, hidden_dim] + workspace: Workspace object (type determines backend) + launch_with_pdl: Use Persistent Device Launch + + # ===== OUTPUT tensors (pre-allocated, filled by function) ===== + output: AllReduce output [token_num, hidden_dim] + residual_out: Prenorm output (after residual add, before norm) [token_num, hidden_dim] + norm_out: Normalized output [token_num, hidden_dim] + quant_out: Quantized output [token_num, hidden_dim] [trtllm only] + scale_out: Quantization scale factors [trtllm only] + + # ===== INPUT parameters ===== + residual_in: Residual tensor to ADD [token_num, hidden_dim] + rms_gamma: RMSNorm weight [hidden_dim] + rms_eps: RMSNorm epsilon for numerical stability + scale_factor: Input scale factor for quantization [trtllm only] + layout_code: Scale factor layout (QuantizationSFLayout) [trtllm only] + + # ===== Control parameters ===== + pattern: Fusion pattern (AllReduceFusionPattern) + If None, auto-detected based on provided output tensors + use_oneshot: [trtllm only] Use oneshot strategy vs twoshot + If None, uses internal heuristics + fp32_acc: [trtllm only] Use FP32 accumulation for AllReduce + metadata: [trtllm only] Workspace metadata for validation + + Returns: + Output tensor (typically norm_out for fusion cases, output otherwise) + + Examples: + >>> # Basic AllReduce + Residual + RMSNorm + >>> workspace = create_allreduce_fusion_workspace( + ... backend="auto", + ... world_size=8, + ... rank=0, + ... max_token_num=2048, + ... hidden_dim=4096, + ... dtype=torch.bfloat16, + ... topology="single_node" + ... ) + >>> + >>> # Pre-allocate output tensors + >>> prenorm = torch.empty_like(hidden_states) + >>> normed = torch.empty_like(hidden_states) + >>> + >>> # Call fusion - backend inferred from workspace type + >>> output = allreduce_fusion( + ... input=hidden_states, + ... workspace=workspace, + ... launch_with_pdl=True, + ... residual_out=prenorm, + ... norm_out=normed, + ... residual_in=residual, + ... rms_gamma=norm_weight + ... ) + >>> # output == normed (final result) + + >>> # With FP8 quantization + >>> quant = torch.empty_like(hidden_states, dtype=torch.float8_e4m3fn) + >>> scales = torch.empty(token_num * hidden_dim // 16, dtype=torch.float16) + >>> + >>> output = allreduce_fusion( + ... input=hidden_states, + ... workspace=workspace, + ... norm_out=normed, + ... quant_out=quant, + ... scale_out=scales, + ... residual_in=residual, + ... rms_gamma=norm_weight, + ... scale_factor=scale_tensor + ... 
) + """ + # Auto-detect pattern if not provided + if pattern is None: + pattern = _infer_fusion_pattern( + output, residual_in, residual_out, norm_out, quant_out, scale_out + ) + + # Infer backend from workspace type and dispatch + if isinstance(workspace, TRTLLMAllReduceFusionWorkspace): + return _allreduce_fusion_trtllm( + input=input, + workspace=workspace, + launch_with_pdl=launch_with_pdl, + output=output, + residual_in=residual_in, + residual_out=residual_out, + norm_out=norm_out, + quant_out=quant_out, + scale_out=scale_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_factor, + layout_code=layout_code, + pattern=pattern, + use_oneshot=use_oneshot, + fp32_acc=fp32_acc, + metadata=metadata, + ) + elif isinstance(workspace, MNNVLAllReduceFusionWorkspace): + return _allreduce_fusion_mnnvl( + input=input, + workspace=workspace, + launch_with_pdl=launch_with_pdl, + residual_in=residual_in, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + ) + else: + raise TypeError( + f"Unknown workspace type: {type(workspace)}. " + f"Expected TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace" + ) + + +# ============================================================================ +# HELPER FUNCTIONS +# ============================================================================ + + +def _infer_fusion_pattern( + output, residual_in, residual_out, norm_out, quant_out, scale_out +) -> int: + """ + Automatically infer fusion pattern from provided tensors. + + Returns AllReduceFusionPattern value based on which output tensors are provided. + """ + from .trtllm_ar import AllReduceFusionPattern + + if quant_out is not None: + # Quantization patterns + if norm_out is not None and residual_out is not None: + # Has separate norm output and residual output + return AllReduceFusionPattern.kARResidualRMSNormOutFP8Quant # 4 + else: + # Quant without separate outputs + return AllReduceFusionPattern.kARResidualRMSNormFP8Quant # 2 + elif norm_out is not None: + # RMS Norm without quantization + return AllReduceFusionPattern.kARResidualRMSNorm # 1 + else: + # Just AllReduce + return AllReduceFusionPattern.kAllReduce # 0 + + +def _allreduce_fusion_trtllm( + input: torch.Tensor, + workspace: TRTLLMAllReduceFusionWorkspace, + launch_with_pdl: bool, + output: Optional[torch.Tensor], + residual_in: Optional[torch.Tensor], + residual_out: Optional[torch.Tensor], + norm_out: Optional[torch.Tensor], + quant_out: Optional[torch.Tensor], + scale_out: Optional[torch.Tensor], + rms_gamma: Optional[torch.Tensor], + rms_eps: float, + scale_factor: Optional[Union[torch.Tensor, float]], + layout_code: Optional[int], + pattern: int, + use_oneshot: Optional[bool], + fp32_acc: bool, + metadata: Optional[dict], +) -> torch.Tensor: + """TensorRT-LLM backend implementation.""" + from .trtllm_ar import trtllm_allreduce_fusion + + token_num, hidden_dim = input.shape + + if output is None: + output = torch.empty_like(input) + + trtllm_allreduce_fusion( + allreduce_in=input, + world_size=workspace.world_size, + world_rank=workspace.rank, + token_num=token_num, + hidden_dim=hidden_dim, + workspace_ptrs=workspace.workspace_ptrs, + launch_with_pdl=launch_with_pdl, + trigger_completion_at_end=launch_with_pdl, # Same meaning + fp32_acc=fp32_acc, + pattern_code=pattern, + use_oneshot=use_oneshot, + allreduce_out=output, + residual_in=residual_in, + residual_out=residual_out, + norm_out=norm_out, + quant_out=quant_out, + scale_out=scale_out, + rms_gamma=rms_gamma, + 
rms_eps=rms_eps, + scale_factor=scale_factor, + layout_code=layout_code, + metadata=metadata, + ) + + # Return the most downstream output + if norm_out is not None: + return norm_out + elif quant_out is not None: + return quant_out + else: + return output + + +def _allreduce_fusion_mnnvl( + input: torch.Tensor, + workspace: MNNVLAllReduceFusionWorkspace, + launch_with_pdl: bool, + residual_in: Optional[torch.Tensor], + residual_out: Optional[torch.Tensor], + norm_out: Optional[torch.Tensor], + rms_gamma: Optional[torch.Tensor], + rms_eps: float, +) -> torch.Tensor: + """ + MNNVL backend implementation. + + Calls trtllm_mnnvl_fused_allreduce_rmsnorm which performs: + 1. AllReduce on input + 2. Add residual + 3. RMSNorm + """ + from .trtllm_mnnvl_ar import trtllm_mnnvl_fused_allreduce_rmsnorm + + # Validate required parameters for RMS fusion + if residual_in is None: + raise ValueError("MNNVL AllReduce+RMS fusion requires residual_in") + if residual_out is None: + raise ValueError( + "MNNVL AllReduce+RMS fusion requires residual_out (prenorm_output)" + ) + if norm_out is None: + raise ValueError("MNNVL AllReduce+RMS fusion requires norm_out (normed_output)") + if rms_gamma is None: + raise ValueError("MNNVL AllReduce+RMS fusion requires rms_gamma") + + # Call the MNNVL fusion function + trtllm_mnnvl_fused_allreduce_rmsnorm( + prenorm_output=residual_out, + normed_output=norm_out, + shard_input=input, + multicast_buffer_ptr=workspace.multicast_buffer_ptr, + buffer_ptrs_dev=workspace.buffer_ptrs_dev, + unicast_ptr=workspace.unicast_ptr, + buffer_M=workspace.buffer_M, + buffer_flags_mnnvl=workspace.buffer_flags, + nranks=workspace.world_size, + rank=workspace.rank, + gamma=rms_gamma, + epsilon=rms_eps, + residual=residual_in, + launch_with_pdl=launch_with_pdl, + ) + + return norm_out + + +# ============================================================================ +# CONTEXT MANAGER +# ============================================================================ + + +class AllReduceFusionContext: + """ + Context manager with automatic workspace management. + + This provides a convenient high-level API that handles workspace + creation and cleanup automatically. + + Example: + >>> with AllReduceFusionContext( + ... backend="auto", + ... world_size=8, + ... rank=0, + ... max_token_num=2048, + ... hidden_dim=4096, + ... dtype=torch.bfloat16, + ... topology="single_node" + ... ) as ctx: + ... for batch in training_loop: + ... prenorm = torch.empty_like(batch.hidden_states) + ... normed = torch.empty_like(batch.hidden_states) + ... + ... output = ctx.allreduce_fusion( + ... input=batch.hidden_states, + ... residual_out=prenorm, + ... norm_out=normed, + ... residual_in=batch.residual, + ... rms_gamma=model.norm_weight, + ... launch_with_pdl=True + ... ) + >>> # Workspace automatically cleaned up + """ + + def __init__( + self, + backend: Literal["trtllm", "mnnvl", "auto"] = "auto", + world_size: int = None, + rank: int = None, + max_token_num: int = None, + hidden_dim: int = None, + dtype: torch.dtype = None, + device: Optional[torch.device] = None, + topology: str = "single_node", + **kwargs, + ): + """ + Initialize context manager. 
+ + Args: + backend: Backend to use ("trtllm", "mnnvl", or "auto") + world_size: Number of ranks + rank: Current rank + max_token_num: Maximum tokens to support + hidden_dim: Hidden dimension + dtype: Data type + device: CUDA device + topology: Network topology ("single_node" or "multi_node") + **kwargs: Additional backend-specific arguments + """ + # Workspace creation does all the selection logic via decorator + self.workspace = create_allreduce_fusion_workspace( + backend=backend, + world_size=world_size, + rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + device=device, + topology=topology, + **kwargs, + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + destroy_allreduce_fusion_workspace(self.workspace) + + def allreduce_fusion(self, input: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Call allreduce_fusion with the managed workspace. + + Args: + input: Input tensor + **kwargs: Additional arguments passed to allreduce_fusion() + + Returns: + Output tensor + """ + return allreduce_fusion(input=input, workspace=self.workspace, **kwargs) From 1917c765ff11d3132f004d537655a70fc9c6dad5 Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Fri, 21 Nov 2025 11:33:21 -0800 Subject: [PATCH 02/11] Polished the interface --- flashinfer/comm/allreduce.py | 349 ++++++++++++++++++----------------- 1 file changed, 181 insertions(+), 168 deletions(-) diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index a8f6498c0c..e30ca03606 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -54,10 +54,24 @@ import torch from ..utils import backend_requirement, supported_compute_capability +from .trtllm_ar import trtllm_allreduce_fusion +from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion +from .trtllm_ar import trtllm_destroy_ipc_workspace_for_all_reduce_fusion # ============================================================================ -# WORKSPACE BASE CLASS +# WORKSPACE BASE CLASS AND IMPLEMENTATIONS +# ============================================================================ +# +# Workspace classes wrap the underlying backend workspace implementations: +# - TRTLLMAllReduceFusionWorkspace: Wraps trtllm_create_ipc_workspace_for_all_reduce_fusion +# - MNNVLAllReduceFusionWorkspace: Wraps MNNVL workspace (to be implemented) +# +# Each workspace: +# 1. Calls the backend-specific workspace creation function in __init__ +# 2. Stores the internal workspace as _internal_workspace +# 3. Exposes essential attributes for the unified API +# 4. Can be destroyed using destroy_allreduce_fusion_workspace() # ============================================================================ @@ -67,6 +81,7 @@ class AllReduceFusionWorkspace(ABC): def __init__(self, world_size: int, rank: int): self.world_size = world_size self.rank = rank + self._destroyed = False @property @abstractmethod @@ -74,14 +89,105 @@ def backend(self) -> str: """Return backend name.""" pass + @abstractmethod + def destroy(self) -> None: + """ + Destroy workspace and free resources. + + This should be called explicitly when done using the workspace. + Prefer using AllReduceFusionContext context manager for automatic cleanup. + """ + pass + + def __del__(self): + """ + Destructor - safety net if destroy() wasn't called explicitly. + + Warns if cleanup wasn't done properly. Not recommended to rely on this + as __del__ timing is non-deterministic and can cause issues with + distributed/CUDA resources. 
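+        Subclass destroy() implementations check _destroyed and return early,
+        so an extra destroy() call after explicit cleanup is harmless.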
+ """ + if not self._destroyed: + import warnings + + warnings.warn( + f"{self.__class__.__name__} was not explicitly destroyed. " + f"Call workspace.destroy() or use AllReduceFusionContext to ensure " + f"proper cleanup of distributed/CUDA resources.", + ResourceWarning, + stacklevel=2, + ) + try: + self.destroy() + except Exception as e: + # Can't raise in __del__, just warn + warnings.warn( + f"Error during automatic cleanup of {self.__class__.__name__}: {e}", + ResourceWarning, + stacklevel=2, + ) + class TRTLLMAllReduceFusionWorkspace(AllReduceFusionWorkspace): """TensorRT-LLM workspace for AllReduce fusion.""" - def __init__(self, world_size: int, rank: int, workspace_ptrs, metadata): - super().__init__(world_size, rank) - self.workspace_ptrs = workspace_ptrs - self.metadata = metadata + def __init__( + self, + tp_size: int, + tp_rank: int, + max_token_num: int, + hidden_dim: int, + dtype: torch.dtype, + device: torch.device, + process_group: Optional["torch.distributed.ProcessGroup"] = None, + **kwargs, + ): + """ + Create TensorRT-LLM AllReduce fusion workspace. + + Args: + tp_size: Tensor parallel size (world size) + tp_rank: Tensor parallel rank + max_token_num: Maximum number of tokens + hidden_dim: Hidden dimension size + dtype: Data type + device: CUDA device + process_group: PyTorch distributed process group + **kwargs: Additional arguments for workspace creation + """ + super().__init__(tp_size, tp_rank) + + # Call the actual workspace creation function + self._internal_workspace = trtllm_create_ipc_workspace_for_all_reduce_fusion( + tp_size=tp_size, + tp_rank=tp_rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + device=device, + process_group=process_group, + **kwargs, + ) + + # Store essential attributes for easy access + self.workspace_ptrs = self._internal_workspace.workspace_ptrs + self.metadata = self._internal_workspace.metadata + + def __getattr__(self, name): + """Delegate attribute access to internal workspace if not found.""" + if name.startswith("_"): + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + return getattr(self._internal_workspace, name) + + def destroy(self) -> None: + """Destroy workspace and free resources.""" + if self._destroyed: + return # Already destroyed, nothing to do + + trtllm_destroy_ipc_workspace_for_all_reduce_fusion(self._internal_workspace) + self._destroyed = True @property def backend(self) -> str: @@ -95,18 +201,64 @@ def __init__( self, world_size: int, rank: int, - multicast_buffer_ptr: int, - buffer_ptrs_dev: int, - unicast_ptr: int, - buffer_M: int, - buffer_flags, + max_token_num: int, + hidden_dim: int, + dtype: torch.dtype, + device: torch.device, + **kwargs, ): + """ + Create MNNVL AllReduce fusion workspace. + + Args: + world_size: Number of ranks + rank: Current rank + max_token_num: Maximum number of tokens + hidden_dim: Hidden dimension size + dtype: Data type + device: CUDA device + **kwargs: Additional arguments for workspace creation + """ super().__init__(world_size, rank) - self.multicast_buffer_ptr = multicast_buffer_ptr - self.buffer_ptrs_dev = buffer_ptrs_dev - self.unicast_ptr = unicast_ptr - self.buffer_M = buffer_M - self.buffer_flags = buffer_flags + + # TODO: Import and call the actual MNNVL workspace creation function + # For now, raise NotImplementedError + raise NotImplementedError( + "MNNVL workspace creation needs to be implemented in trtllm_mnnvl_ar.py. 
" + "Expected function: create_mnnvl_allreduce_fusion_workspace" + ) + + # When implemented, should look like: + # from .trtllm_mnnvl_ar import create_mnnvl_allreduce_fusion_workspace + # + # self._internal_workspace = create_mnnvl_allreduce_fusion_workspace( + # world_size=world_size, + # rank=rank, + # max_token_num=max_token_num, + # hidden_dim=hidden_dim, + # dtype=dtype, + # device=device, + # **kwargs, + # ) + # + # # Store essential attributes for easy access + # self.multicast_buffer_ptr = self._internal_workspace.multicast_buffer_ptr + # self.buffer_ptrs_dev = self._internal_workspace.buffer_ptrs_dev + # self.unicast_ptr = self._internal_workspace.unicast_ptr + # self.buffer_M = self._internal_workspace.buffer_M + # self.buffer_flags = self._internal_workspace.buffer_flags + + def destroy(self) -> None: + """Destroy workspace and free resources.""" + if self._destroyed: + return # Already destroyed, nothing to do + + # TODO: Implement MNNVL workspace destruction + self._destroyed = True + raise NotImplementedError("MNNVL workspace destruction not yet implemented") + # from .trtllm_mnnvl_ar import destroy_mnnvl_allreduce_fusion_workspace + # destroy_mnnvl_allreduce_fusion_workspace(self._internal_workspace) + # self._destroyed = True @property def backend(self) -> str: @@ -337,11 +489,9 @@ def create_allreduce_fusion_workspace( else: actual_backend = backend - # Create workspace for selected backend + # Create workspace for selected backend using workspace constructors if actual_backend == "trtllm": - from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion - - workspace = trtllm_create_ipc_workspace_for_all_reduce_fusion( + return TRTLLMAllReduceFusionWorkspace( tp_size=world_size, tp_rank=rank, max_token_num=max_token_num, @@ -351,30 +501,17 @@ def create_allreduce_fusion_workspace( process_group=process_group, **backend_kwargs, ) - # Ensure workspace has required attributes for our API - if not hasattr(workspace, "world_size"): - workspace.world_size = world_size - if not hasattr(workspace, "rank"): - workspace.rank = rank - return workspace elif actual_backend == "mnnvl": - # TODO: Implement create_mnnvl_allreduce_fusion_workspace - # For now, raise NotImplementedError with instructions - raise NotImplementedError( - "MNNVL workspace creation needs to be implemented. " - "Expected function: trtllm_mnnvl_ar.create_mnnvl_allreduce_fusion_workspace" + return MNNVLAllReduceFusionWorkspace( + world_size=world_size, + rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + device=device, + **backend_kwargs, ) - # from .trtllm_mnnvl_ar import create_mnnvl_allreduce_fusion_workspace - # return create_mnnvl_allreduce_fusion_workspace( - # world_size=world_size, - # rank=rank, - # max_token_num=max_token_num, - # hidden_dim=hidden_dim, - # dtype=dtype, - # device=device, - # **backend_kwargs - # ) else: raise RuntimeError(f"Unknown backend: {actual_backend}") @@ -388,8 +525,7 @@ def destroy_allreduce_fusion_workspace(workspace: AllReduceFusionWorkspace) -> N """ Destroy workspace and free resources. - Automatically detects workspace type from the object and calls - appropriate cleanup function. + This is a convenience function that calls the workspace's destroy() method. Args: workspace: Workspace object to destroy @@ -398,18 +534,9 @@ def destroy_allreduce_fusion_workspace(workspace: AllReduceFusionWorkspace) -> N >>> workspace = create_allreduce_fusion_workspace(...) >>> # ... use workspace ... 
>>> destroy_allreduce_fusion_workspace(workspace) + >>> # Or call directly: workspace.destroy() """ - if isinstance(workspace, TRTLLMAllReduceFusionWorkspace): - from .trtllm_ar import trtllm_destroy_ipc_workspace_for_all_reduce_fusion - - trtllm_destroy_ipc_workspace_for_all_reduce_fusion(workspace) - elif isinstance(workspace, MNNVLAllReduceFusionWorkspace): - # TODO: Implement MNNVL workspace destruction - raise NotImplementedError("MNNVL workspace destruction not yet implemented") - # from .trtllm_mnnvl_ar import destroy_mnnvl_allreduce_fusion_workspace - # destroy_mnnvl_allreduce_fusion_workspace(workspace) - else: - raise TypeError(f"Unknown workspace type: {type(workspace)}") + workspace.destroy() # ============================================================================ @@ -619,7 +746,6 @@ def _allreduce_fusion_trtllm( metadata: Optional[dict], ) -> torch.Tensor: """TensorRT-LLM backend implementation.""" - from .trtllm_ar import trtllm_allreduce_fusion token_num, hidden_dim = input.shape @@ -678,8 +804,6 @@ def _allreduce_fusion_mnnvl( 2. Add residual 3. RMSNorm """ - from .trtllm_mnnvl_ar import trtllm_mnnvl_fused_allreduce_rmsnorm - # Validate required parameters for RMS fusion if residual_in is None: raise ValueError("MNNVL AllReduce+RMS fusion requires residual_in") @@ -693,117 +817,6 @@ def _allreduce_fusion_mnnvl( raise ValueError("MNNVL AllReduce+RMS fusion requires rms_gamma") # Call the MNNVL fusion function - trtllm_mnnvl_fused_allreduce_rmsnorm( - prenorm_output=residual_out, - normed_output=norm_out, - shard_input=input, - multicast_buffer_ptr=workspace.multicast_buffer_ptr, - buffer_ptrs_dev=workspace.buffer_ptrs_dev, - unicast_ptr=workspace.unicast_ptr, - buffer_M=workspace.buffer_M, - buffer_flags_mnnvl=workspace.buffer_flags, - nranks=workspace.world_size, - rank=workspace.rank, - gamma=rms_gamma, - epsilon=rms_eps, - residual=residual_in, - launch_with_pdl=launch_with_pdl, - ) + raise NotImplementedError("MNNVL AllReduce+RMS fusion is not implemented") return norm_out - - -# ============================================================================ -# CONTEXT MANAGER -# ============================================================================ - - -class AllReduceFusionContext: - """ - Context manager with automatic workspace management. - - This provides a convenient high-level API that handles workspace - creation and cleanup automatically. - - Example: - >>> with AllReduceFusionContext( - ... backend="auto", - ... world_size=8, - ... rank=0, - ... max_token_num=2048, - ... hidden_dim=4096, - ... dtype=torch.bfloat16, - ... topology="single_node" - ... ) as ctx: - ... for batch in training_loop: - ... prenorm = torch.empty_like(batch.hidden_states) - ... normed = torch.empty_like(batch.hidden_states) - ... - ... output = ctx.allreduce_fusion( - ... input=batch.hidden_states, - ... residual_out=prenorm, - ... norm_out=normed, - ... residual_in=batch.residual, - ... rms_gamma=model.norm_weight, - ... launch_with_pdl=True - ... ) - >>> # Workspace automatically cleaned up - """ - - def __init__( - self, - backend: Literal["trtllm", "mnnvl", "auto"] = "auto", - world_size: int = None, - rank: int = None, - max_token_num: int = None, - hidden_dim: int = None, - dtype: torch.dtype = None, - device: Optional[torch.device] = None, - topology: str = "single_node", - **kwargs, - ): - """ - Initialize context manager. 
- - Args: - backend: Backend to use ("trtllm", "mnnvl", or "auto") - world_size: Number of ranks - rank: Current rank - max_token_num: Maximum tokens to support - hidden_dim: Hidden dimension - dtype: Data type - device: CUDA device - topology: Network topology ("single_node" or "multi_node") - **kwargs: Additional backend-specific arguments - """ - # Workspace creation does all the selection logic via decorator - self.workspace = create_allreduce_fusion_workspace( - backend=backend, - world_size=world_size, - rank=rank, - max_token_num=max_token_num, - hidden_dim=hidden_dim, - dtype=dtype, - device=device, - topology=topology, - **kwargs, - ) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - destroy_allreduce_fusion_workspace(self.workspace) - - def allreduce_fusion(self, input: torch.Tensor, **kwargs) -> torch.Tensor: - """ - Call allreduce_fusion with the managed workspace. - - Args: - input: Input tensor - **kwargs: Additional arguments passed to allreduce_fusion() - - Returns: - Output tensor - """ - return allreduce_fusion(input=input, workspace=self.workspace, **kwargs) From 9225e5222425f4964a70c104b82c00d62c0af47a Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Fri, 21 Nov 2025 14:07:33 -0800 Subject: [PATCH 03/11] Removed device param --- flashinfer/comm/allreduce.py | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index e30ca03606..e7699c8d0b 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -138,7 +138,6 @@ def __init__( max_token_num: int, hidden_dim: int, dtype: torch.dtype, - device: torch.device, process_group: Optional["torch.distributed.ProcessGroup"] = None, **kwargs, ): @@ -151,7 +150,6 @@ def __init__( max_token_num: Maximum number of tokens hidden_dim: Hidden dimension size dtype: Data type - device: CUDA device process_group: PyTorch distributed process group **kwargs: Additional arguments for workspace creation """ @@ -159,12 +157,10 @@ def __init__( # Call the actual workspace creation function self._internal_workspace = trtllm_create_ipc_workspace_for_all_reduce_fusion( - tp_size=tp_size, tp_rank=tp_rank, + tp_size=tp_size, max_token_num=max_token_num, hidden_dim=hidden_dim, - dtype=dtype, - device=device, process_group=process_group, **kwargs, ) @@ -204,7 +200,6 @@ def __init__( max_token_num: int, hidden_dim: int, dtype: torch.dtype, - device: torch.device, **kwargs, ): """ @@ -216,7 +211,6 @@ def __init__( max_token_num: Maximum number of tokens hidden_dim: Hidden dimension size dtype: Data type - device: CUDA device **kwargs: Additional arguments for workspace creation """ super().__init__(world_size, rank) @@ -229,15 +223,12 @@ def __init__( ) # When implemented, should look like: - # from .trtllm_mnnvl_ar import create_mnnvl_allreduce_fusion_workspace - # # self._internal_workspace = create_mnnvl_allreduce_fusion_workspace( # world_size=world_size, # rank=rank, # max_token_num=max_token_num, # hidden_dim=hidden_dim, # dtype=dtype, - # device=device, # **kwargs, # ) # @@ -278,7 +269,6 @@ def _trtllm_workspace_check( max_token_num: int, hidden_dim: int, dtype: torch.dtype, - device: Optional[torch.device], topology: str, **kwargs, ) -> bool: @@ -305,7 +295,6 @@ def _mnnvl_workspace_check( max_token_num: int, hidden_dim: int, dtype: torch.dtype, - device: Optional[torch.device], topology: str, **kwargs, ) -> bool: @@ -337,7 +326,6 @@ def _workspace_creation_heuristic( 
max_token_num: int, hidden_dim: int, dtype: torch.dtype, - device: Optional[torch.device], topology: str, **kwargs, ) -> list[str]: @@ -355,7 +343,6 @@ def _workspace_creation_heuristic( max_token_num: Maximum number of tokens hidden_dim: Hidden dimension size dtype: Data type - device: CUDA device topology: Network topology ("single_node" or "multi_node") **kwargs: Additional arguments @@ -417,7 +404,6 @@ def create_allreduce_fusion_workspace( max_token_num: int = None, hidden_dim: int = None, dtype: torch.dtype = None, - device: Optional[torch.device] = None, topology: str = "single_node", process_group: Optional["torch.distributed.ProcessGroup"] = None, **backend_kwargs, @@ -436,7 +422,6 @@ def create_allreduce_fusion_workspace( max_token_num: Maximum number of tokens to support hidden_dim: Hidden dimension size dtype: Data type for communication tensors - device: CUDA device (defaults to current CUDA device) topology: Network topology hint for backend selection "single_node" - All ranks on one node (default) "multi_node" - Ranks span multiple nodes @@ -476,9 +461,6 @@ def create_allreduce_fusion_workspace( ... ) >>> print(workspace.backend) # "mnnvl" """ - if device is None: - device = torch.device(f"cuda:{torch.cuda.current_device()}") - # Decorator has validated backend - now create workspace # If backend="auto", decorator has selected the best one and stored it @@ -496,8 +478,6 @@ def create_allreduce_fusion_workspace( tp_rank=rank, max_token_num=max_token_num, hidden_dim=hidden_dim, - dtype=dtype, - device=device, process_group=process_group, **backend_kwargs, ) @@ -509,7 +489,6 @@ def create_allreduce_fusion_workspace( max_token_num=max_token_num, hidden_dim=hidden_dim, dtype=dtype, - device=device, **backend_kwargs, ) else: @@ -580,7 +559,7 @@ def allreduce_fusion( Args: input: Input tensor [token_num, hidden_dim] workspace: Workspace object (type determines backend) - launch_with_pdl: Use Persistent Device Launch + launch_with_pdl: Use Persistent Dependency Launch # ===== OUTPUT tensors (pre-allocated, filled by function) ===== output: AllReduce output [token_num, hidden_dim] From 84e75e4a4bfbfd5134e7acbde8218c7ee548d0b0 Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Fri, 21 Nov 2025 15:06:53 -0800 Subject: [PATCH 04/11] Updated test with legacy vs unified API --- flashinfer/comm/__init__.py | 1 - flashinfer/comm/allreduce.py | 40 ++-- tests/comm/test_trtllm_allreduce_fusion.py | 236 +++++++++++++++------ 3 files changed, 195 insertions(+), 82 deletions(-) diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py index b0e9dfd0a4..3ffad2505d 100644 --- a/flashinfer/comm/__init__.py +++ b/flashinfer/comm/__init__.py @@ -40,7 +40,6 @@ from .vllm_ar import register_graph_buffers as vllm_register_graph_buffers # Unified AllReduce Fusion API -from .allreduce import AllReduceFusionContext as AllReduceFusionContext from .allreduce import AllReduceFusionWorkspace as AllReduceFusionWorkspace from .allreduce import MNNVLAllReduceFusionWorkspace as MNNVLAllReduceFusionWorkspace from .allreduce import TRTLLMAllReduceFusionWorkspace as TRTLLMAllReduceFusionWorkspace diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index e7699c8d0b..693d052723 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -137,7 +137,6 @@ def __init__( tp_rank: int, max_token_num: int, hidden_dim: int, - dtype: torch.dtype, process_group: Optional["torch.distributed.ProcessGroup"] = None, **kwargs, ): @@ -161,13 +160,14 @@ def __init__( 
tp_size=tp_size, max_token_num=max_token_num, hidden_dim=hidden_dim, - process_group=process_group, + group=process_group, **kwargs, ) # Store essential attributes for easy access - self.workspace_ptrs = self._internal_workspace.workspace_ptrs - self.metadata = self._internal_workspace.metadata + self.ipc_handles = self._internal_workspace[0] + self.workspace_tensor = self._internal_workspace[1] + self.metadata = self._internal_workspace[2] def __getattr__(self, name): """Delegate attribute access to internal workspace if not found.""" @@ -726,37 +726,49 @@ def _allreduce_fusion_trtllm( ) -> torch.Tensor: """TensorRT-LLM backend implementation.""" + # Extract shape from 2D input token_num, hidden_dim = input.shape + # Allocate output if needed (keep 2D shape) if output is None: output = torch.empty_like(input) + # Flatten all tensors to 1D for legacy trtllm_allreduce_fusion API + # The legacy API expects flattened tensors and explicit token_num/hidden_dim + input_flat = input.flatten() + output_flat = output.flatten() + residual_in_flat = residual_in.flatten() if residual_in is not None else None + residual_out_flat = residual_out.flatten() if residual_out is not None else None + norm_out_flat = norm_out.flatten() if norm_out is not None else None + quant_out_flat = quant_out.flatten() if quant_out is not None else None + + # Call legacy API with flattened tensors trtllm_allreduce_fusion( - allreduce_in=input, + allreduce_in=input_flat, world_size=workspace.world_size, world_rank=workspace.rank, token_num=token_num, hidden_dim=hidden_dim, - workspace_ptrs=workspace.workspace_ptrs, + workspace_ptrs=workspace.workspace_tensor, launch_with_pdl=launch_with_pdl, trigger_completion_at_end=launch_with_pdl, # Same meaning fp32_acc=fp32_acc, pattern_code=pattern, use_oneshot=use_oneshot, - allreduce_out=output, - residual_in=residual_in, - residual_out=residual_out, - norm_out=norm_out, - quant_out=quant_out, - scale_out=scale_out, - rms_gamma=rms_gamma, + allreduce_out=output_flat, + residual_in=residual_in_flat, + residual_out=residual_out_flat, + norm_out=norm_out_flat, + quant_out=quant_out_flat, + scale_out=scale_out, # scale_out is not reshaped + rms_gamma=rms_gamma, # 1D tensor, no reshape needed rms_eps=rms_eps, scale_factor=scale_factor, layout_code=layout_code, metadata=metadata, ) - # Return the most downstream output + # Return the most downstream output (already in 2D shape from input views) if norm_out is not None: return norm_out elif quant_out is not None: diff --git a/tests/comm/test_trtllm_allreduce_fusion.py b/tests/comm/test_trtllm_allreduce_fusion.py index c3aa8c8252..17d9d7c2d4 100644 --- a/tests/comm/test_trtllm_allreduce_fusion.py +++ b/tests/comm/test_trtllm_allreduce_fusion.py @@ -22,7 +22,9 @@ SCALE_FACTOR_RANGE = (-1, 1) -def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_init_port): +def _run_correctness_worker( + world_size, rank, dtype, hidden_dim, distributed_init_port, legacy_api=True +): device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) distributed_init_method = f"tcp://localhost:{distributed_init_port}" @@ -57,18 +59,37 @@ def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_ini lamport_use_fp32 = dtype == torch.float32 - # create workspace for allreduce fusion with metadata - ipc_handles, workspace_tensor, workspace_metadata = ( - comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( - rank, - world_size, - MAX_TOKEN_NUM, - hidden_dim, - group=group, + # Create workspace - choose between legacy 
and new API + if legacy_api: + # Legacy API: create workspace for allreduce fusion with metadata + ipc_handles, workspace_tensor, workspace_metadata = ( + comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( + rank, + world_size, + MAX_TOKEN_NUM, + hidden_dim, + group=group, + use_fp32_lamport=lamport_use_fp32, + create_metadata=True, # Get metadata for validation + ) + ) + else: + workspace = None + # New unified API: create workspace + workspace = comm.create_allreduce_fusion_workspace( + backend="trtllm", + world_size=world_size, + rank=rank, + max_token_num=MAX_TOKEN_NUM, + hidden_dim=hidden_dim, + dtype=dtype, + topology="single_node", + process_group=group, use_fp32_lamport=lamport_use_fp32, - create_metadata=True, # Get metadata for validation + create_metadata=True, ) - ) + # Extract metadata for compatibility with tests + workspace_metadata = workspace.metadata test_loop = 5 @@ -163,60 +184,130 @@ def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_ini s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): for _ in range(test_loop): - comm.trtllm_allreduce_fusion( - allreduce_in=allreduce_in, - world_size=world_size, - world_rank=rank, - token_num=token_num, - hidden_dim=hidden_dim, - workspace_ptrs=workspace_tensor, - launch_with_pdl=launch_with_pdl, - use_oneshot=use_oneshot, - trigger_completion_at_end=trigger_completion_at_end, - fp32_acc=fp32_acc, - pattern_code=pattern_code, - allreduce_out=all_reduce_out, - residual_in=residual_in, - residual_out=residual_out, - norm_out=norm_out, - quant_out=quant_out, - scale_out=scale_out, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - scale_factor=scale_factor, - layout_code=swizzled_layout_code, - metadata=workspace_metadata, - ) + if legacy_api: + # Legacy API - uses flattened tensors + comm.trtllm_allreduce_fusion( + allreduce_in=allreduce_in, + world_size=world_size, + world_rank=rank, + token_num=token_num, + hidden_dim=hidden_dim, + workspace_ptrs=workspace_tensor, + launch_with_pdl=launch_with_pdl, + use_oneshot=use_oneshot, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=pattern_code, + allreduce_out=all_reduce_out, + residual_in=residual_in, + residual_out=residual_out, + norm_out=norm_out, + quant_out=quant_out, + scale_out=scale_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_factor, + layout_code=swizzled_layout_code, + metadata=workspace_metadata, + ) + else: + # New unified API - expects 2D tensors [token_num, hidden_dim] + comm.allreduce_fusion( + input=allreduce_in.view( + token_num, hidden_dim + ), + workspace=workspace, + launch_with_pdl=launch_with_pdl, + output=all_reduce_out.view( + token_num, hidden_dim + ), + residual_in=residual_in.view( + token_num, hidden_dim + ), + residual_out=residual_out.view( + token_num, hidden_dim + ), + norm_out=norm_out.view( + token_num, hidden_dim + ), + quant_out=quant_out.view( + token_num, hidden_dim + ), + scale_out=scale_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_factor, + layout_code=swizzled_layout_code, + pattern=pattern_code, + use_oneshot=use_oneshot, + fp32_acc=fp32_acc, + metadata=workspace_metadata, + ) # NOTE: in real case, you dont have to set all optional params. You could set those required by fusion pattern. 
# capture g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): for _ in range(test_loop): - comm.trtllm_allreduce_fusion( - allreduce_in=allreduce_in, - world_size=world_size, - world_rank=rank, - token_num=token_num, - hidden_dim=hidden_dim, - workspace_ptrs=workspace_tensor, - launch_with_pdl=launch_with_pdl, - use_oneshot=use_oneshot, - trigger_completion_at_end=trigger_completion_at_end, - fp32_acc=fp32_acc, - pattern_code=pattern_code, - allreduce_out=all_reduce_out, - residual_in=residual_in, - residual_out=residual_out, - norm_out=norm_out, - quant_out=quant_out, - scale_out=scale_out, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - scale_factor=scale_factor, - layout_code=swizzled_layout_code, - metadata=workspace_metadata, - ) + if legacy_api: + # Legacy API - uses flattened tensors + comm.trtllm_allreduce_fusion( + allreduce_in=allreduce_in, + world_size=world_size, + world_rank=rank, + token_num=token_num, + hidden_dim=hidden_dim, + workspace_ptrs=workspace_tensor, + launch_with_pdl=launch_with_pdl, + use_oneshot=use_oneshot, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=pattern_code, + allreduce_out=all_reduce_out, + residual_in=residual_in, + residual_out=residual_out, + norm_out=norm_out, + quant_out=quant_out, + scale_out=scale_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_factor, + layout_code=swizzled_layout_code, + metadata=workspace_metadata, + ) + else: + # New unified API - expects 2D tensors [token_num, hidden_dim] + comm.allreduce_fusion( + input=allreduce_in.view( + token_num, hidden_dim + ), + workspace=workspace, + launch_with_pdl=launch_with_pdl, + output=all_reduce_out.view( + token_num, hidden_dim + ), + residual_in=residual_in.view( + token_num, hidden_dim + ), + residual_out=residual_out.view( + token_num, hidden_dim + ), + norm_out=norm_out.view( + token_num, hidden_dim + ), + quant_out=quant_out.view( + token_num, hidden_dim + ), + scale_out=scale_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_factor, + layout_code=swizzled_layout_code, + pattern=pattern_code, + use_oneshot=use_oneshot, + fp32_acc=fp32_acc, + metadata=workspace_metadata, + ) # replay g.replay() torch.cuda.synchronize() @@ -307,9 +398,14 @@ def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_ini finally: dist.barrier(group=group) - comm.trtllm_destroy_ipc_workspace_for_all_reduce_fusion( - ipc_handles, group=group - ) + # Destroy workspace - choose between legacy and new API + if legacy_api: + comm.trtllm_destroy_ipc_workspace_for_all_reduce_fusion( + ipc_handles, group=group + ) + elif workspace is not None: + # New unified API + workspace.destroy() dist.destroy_process_group(group=group) @@ -358,7 +454,8 @@ def multi_process_parallel( @pytest.mark.parametrize("world_size", [2, 4, 8]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_dim", [1024, 2048, 4096, 7168, 8192]) -def test_trtllm_allreduce_fusion(world_size, dtype, hidden_dim): +@pytest.mark.parametrize("legacy_api", [True, False]) +def test_trtllm_allreduce_fusion(world_size, dtype, hidden_dim, legacy_api): np.random.seed(42) torch.manual_seed(42) torch.cuda.manual_seed_all(42) @@ -367,17 +464,22 @@ def test_trtllm_allreduce_fusion(world_size, dtype, hidden_dim): pytest.skip( f"world_size {world_size} is greater than available_gpus {available_gpus}" ) - print(f"Running test for world_size={world_size}") + api_str = "legacy" if legacy_api else "unified" + print(f"Running test 
for world_size={world_size} with {api_str} API")
 
     multi_process_parallel(
         world_size,
         dtype,
         hidden_dim,
         _run_correctness_worker,
-        target_args=(),
+        target_args=(legacy_api,),
     )
-    print(f"allreduce fusion tp = {world_size}: OK")
+    print(f"allreduce fusion tp = {world_size} ({api_str} API): OK")
 
 
 if __name__ == "__main__":
-    test_trtllm_allreduce_fusion(2, torch.float16, 1024)
+    # Test both legacy and unified APIs
+    print("Testing legacy API...")
+    test_trtllm_allreduce_fusion(2, torch.float16, 1024, legacy_api=True)
+    print("\nTesting unified API...")
+    test_trtllm_allreduce_fusion(2, torch.float16, 1024, legacy_api=False)

From 3bb586bb9a3b6596df99074d060ac74a501e9fab Mon Sep 17 00:00:00 2001
From: Maximilien Breughe
Date: Mon, 24 Nov 2025 09:18:46 -0800
Subject: [PATCH 05/11] Fixed unit test

---
 flashinfer/comm/allreduce.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py
index 693d052723..89e5edfebb 100644
--- a/flashinfer/comm/allreduce.py
+++ b/flashinfer/comm/allreduce.py
@@ -182,7 +182,7 @@ def destroy(self) -> None:
         if self._destroyed:
             return  # Already destroyed, nothing to do
 
-        trtllm_destroy_ipc_workspace_for_all_reduce_fusion(self._internal_workspace)
+        trtllm_destroy_ipc_workspace_for_all_reduce_fusion(self.ipc_handles)
         self._destroyed = True
 
     @property

From 0bcc8da5958b951aaf1a8d2bae1349733285ef3a Mon Sep 17 00:00:00 2001
From: Maximilien Breughe
Date: Mon, 24 Nov 2025 12:51:38 -0800
Subject: [PATCH 06/11] Relaxed check on trtllm_ar

---
 flashinfer/comm/trtllm_ar.py | 92 +++++++++++++++++++-----------------
 1 file changed, 48 insertions(+), 44 deletions(-)

diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py
index 33bb7ac97b..e0b4369f12 100644
--- a/flashinfer/comm/trtllm_ar.py
+++ b/flashinfer/comm/trtllm_ar.py
@@ -804,6 +804,51 @@ def _should_use_oneshot(
     return comm_size_mb <= _use_oneshot_heuristics[world_size]
 
 
+def _check_workspace_metadata(
+    token_num: int,
+    hidden_dim: int,
+    world_size: int,
+    dtype: torch.dtype,
+    metadata: dict,
+) -> None:
+    errors = []
+    required_keys = ["max_token_num", "tp_size", "hidden_dim", "use_fp32_lamport"]
+    for key in required_keys:
+        if key not in metadata:
+            errors.append(f"Workspace metadata is missing required key: {key}")
+    if errors:
+        error_msg = "Workspace metadata validation failed:\n" + "\n".join(
+            f"  - {e}" for e in errors
+        )
+        raise ValueError(error_msg)
+
+    # world_size must match tp_size (flag size depends on it)
+    if world_size != metadata["tp_size"]:
+        errors.append(
+            f"world_size ({world_size}) does not match workspace tp_size ({metadata['tp_size']}). "
+            f"Workspace was created for tp_size={metadata['tp_size']}."
+        )
+
+    # token_num * hidden_dim must not exceed max_token_num * hidden_dim
+    if token_num * hidden_dim > metadata["max_token_num"] * metadata["hidden_dim"]:
+        errors.append(
+            f"token_num ({token_num}) * hidden_dim ({hidden_dim}) exceeds workspace max_token_num ({metadata['max_token_num']}) * hidden_dim ({metadata['hidden_dim']}). "
+            f"This may cause Illegal Memory Access."
+        )
+
+    # use_fp32_lamport must match
+    if metadata["use_fp32_lamport"] != (dtype == torch.float32):
+        errors.append(
+            f"use_fp32_lamport ({metadata['use_fp32_lamport']}) does not match allreduce_in.dtype ({dtype}). "
+            f"Workspace was created for use_fp32_lamport={metadata['use_fp32_lamport']}."
+        )
+    if errors:
+        error_msg = "Workspace validation failed:\n" + "\n".join(
+            f"  - {e}" for e in errors
+        )
+        raise ValueError(error_msg)
+
+
 def trtllm_allreduce_fusion(
     allreduce_in: torch.Tensor,
     world_size: int,
@@ -858,50 +903,9 @@ def trtllm_allreduce_fusion(
 
     # Validate against workspace metadata if provided
     if metadata is not None:
-        errors = []
-        required_keys = ["max_token_num", "tp_size", "hidden_dim", "use_fp32_lamport"]
-        for key in required_keys:
-            if key not in metadata:
-                errors.append(f"Workspace metadata is missing required key: {key}")
-        if errors:
-            error_msg = "Workspace metadata validation failed:\n" + "\n".join(
-                f"  - {e}" for e in errors
-            )
-            raise ValueError(error_msg)
-
-        # Check 1: token_num must not exceed max_token_num
-        if token_num > metadata["max_token_num"]:
-            errors.append(
-                f"token_num ({token_num}) exceeds workspace max_token_num ({metadata['max_token_num']}). "
-                f"This may cause Illegal Memory Access."
-            )
-
-        # Check 2: world_size must match tp_size
-        if world_size != metadata["tp_size"]:
-            errors.append(
-                f"world_size ({world_size}) does not match workspace tp_size ({metadata['tp_size']}). "
-                f"Workspace was created for tp_size={metadata['tp_size']}."
-            )
-
-        # Check 3: hidden_dim must match
-        if hidden_dim != metadata["hidden_dim"]:
-            errors.append(
-                f"hidden_dim ({hidden_dim}) does not match workspace hidden_dim ({metadata['hidden_dim']}). "
-                f"Workspace was created for hidden_dim={metadata['hidden_dim']}."
-            )
-
-        # Check 4: use_fp32_lamport must match
-        if metadata["use_fp32_lamport"] != (allreduce_in.dtype == torch.float32):
-            errors.append(
-                f"use_fp32_lamport ({metadata['use_fp32_lamport']}) does not match allreduce_in.dtype ({allreduce_in.dtype}). "
-                f"Workspace was created for use_fp32_lamport={metadata['use_fp32_lamport']}."
-            )
-
-        if errors:
-            error_msg = "Workspace validation failed:\n" + "\n".join(
-                f"  - {e}" for e in errors
-            )
-            raise ValueError(error_msg)
+        _check_workspace_metadata(
+            token_num, hidden_dim, world_size, allreduce_in.dtype, metadata
+        )
 
     if use_oneshot is None:
         use_oneshot = _should_use_oneshot(

From 0c1391db6e3d4c84ce88270931d5dc8cf945d4d6 Mon Sep 17 00:00:00 2001
From: Maximilien Breughe
Date: Mon, 24 Nov 2025 15:01:21 -0800
Subject: [PATCH 07/11] Made metadata mandatory in unified API, added workspace check functions

---
 flashinfer/comm/allreduce.py               | 77 +++++++++++++++++-----
 flashinfer/comm/trtllm_ar.py               |  4 +-
 tests/comm/test_trtllm_allreduce_fusion.py |  1 -
 3 files changed, 62 insertions(+), 20 deletions(-)

diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py
index 89e5edfebb..1227dc4b06 100644
--- a/flashinfer/comm/allreduce.py
+++ b/flashinfer/comm/allreduce.py
@@ -48,7 +48,7 @@
-from typing import Union, Literal, Optional
+from typing import Union, Literal, Optional, Tuple, List, cast
 from abc import ABC, abstractmethod
 
 import torch
@@ -57,6 +57,11 @@
 from .trtllm_ar import trtllm_allreduce_fusion
 from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion
 from .trtllm_ar import trtllm_destroy_ipc_workspace_for_all_reduce_fusion
+from .trtllm_ar import check_trtllm_allreduce_fusion_workspace_metadata
+
+# Note: AllReduceFusionPattern and QuantizationSFLayout are pseudo-types (classes with int constants)
+# Import them for runtime use but type hint as int for mypy compatibility
+from .trtllm_ar import AllReduceFusionPattern
 
 
 # ============================================================================
@@ -161,13 +166,22 @@ def __init__(
             max_token_num=max_token_num,
             hidden_dim=hidden_dim,
             group=process_group,
+            create_metadata=True,
             **kwargs,
         )
 
         # Store essential attributes for easy access
-        self.ipc_handles = self._internal_workspace[0]
-        self.workspace_tensor = self._internal_workspace[1]
-        self.metadata = self._internal_workspace[2]
+        # Cast to 3-tuple to make linter happy, since we always call with create_metadata=True
+        workspace_tuple = cast(
+            Tuple[List[List[int]], torch.Tensor, dict], self._internal_workspace
+        )
+        self.ipc_handles = workspace_tuple[0]
+        self.workspace_tensor = workspace_tuple[1]
+        self.metadata = workspace_tuple[2]
+
+    @property
+    def backend(self) -> str:
+        return "trtllm"
 
     def __getattr__(self, name):
         """Delegate attribute access to internal workspace if not found."""
@@ -177,6 +191,18 @@ def __getattr__(self, name):
             )
         return getattr(self._internal_workspace, name)
 
+    def is_sufficient_for(
+        self, token_num: int, hidden_dim: int, tp_size: int, dtype: torch.dtype
+    ) -> bool:
+        try:
+            check_trtllm_allreduce_fusion_workspace_metadata(
+                token_num, hidden_dim, tp_size, dtype, self.metadata
+            )
+            return True
+        except ValueError as e:
+            print(f"Workspace is insufficient for problem size. 
{e}") + return False + def destroy(self) -> None: """Destroy workspace and free resources.""" if self._destroyed: @@ -185,10 +211,6 @@ def destroy(self) -> None: trtllm_destroy_ipc_workspace_for_all_reduce_fusion(self.ipc_handles) self._destroyed = True - @property - def backend(self) -> str: - return "trtllm" - class MNNVLAllReduceFusionWorkspace(AllReduceFusionWorkspace): """MNNVL workspace for AllReduce fusion.""" @@ -214,7 +236,6 @@ def __init__( **kwargs: Additional arguments for workspace creation """ super().__init__(world_size, rank) - # TODO: Import and call the actual MNNVL workspace creation function # For now, raise NotImplementedError raise NotImplementedError( @@ -239,6 +260,10 @@ def __init__( # self.buffer_M = self._internal_workspace.buffer_M # self.buffer_flags = self._internal_workspace.buffer_flags + @property + def backend(self) -> str: + return "mnnvl" + def destroy(self) -> None: """Destroy workspace and free resources.""" if self._destroyed: @@ -251,10 +276,6 @@ def destroy(self) -> None: # destroy_mnnvl_allreduce_fusion_workspace(self._internal_workspace) # self._destroyed = True - @property - def backend(self) -> str: - return "mnnvl" - # ============================================================================ # BACKEND CHECKS - Hard requirements for decorator @@ -413,6 +434,19 @@ def create_allreduce_fusion_workspace( Backend selection (checks + heuristics) handled by @backend_requirement decorator. + **Important: Workspace Reusability** + The workspace is allocated based on the total size (max_token_num * hidden_dim * dtype_size). + You can reuse the same workspace with different shapes as long as the total size fits: + + - Workspace(max_token_num=2048, hidden_dim=4096) can handle: + - (token_num=2048, hidden_dim=4096) ✓ + - (token_num=1024, hidden_dim=4096) ✓ + - (token_num=4096, hidden_dim=2048) ✓ (same total size) + - (token_num=1024, hidden_dim=8192) ✓ (same total size) + - (token_num=4096, hidden_dim=4096) ✗ (too large) + + Use `workspace.is_sufficient_for(token_num, hidden_dim, dtype)` to check before use. + Args: backend: Backend to use ("trtllm", "mnnvl", or "auto") "auto" uses heuristic to select best backend based on topology @@ -448,6 +482,11 @@ def create_allreduce_fusion_workspace( ... topology="single_node" ... ) >>> print(workspace.backend) # "trtllm" + >>> print(workspace.get_workspace_capacity()) # 8388608 elements + + >>> # Check if workspace can handle different problem sizes + >>> workspace.is_sufficient_for(1024, 4096, 8, torch.bfloat16) # True + >>> workspace.is_sufficient_for(4096, 2048, 8, torch.bfloat16) # True (same total) >>> # Explicit backend selection >>> workspace = create_allreduce_fusion_workspace( @@ -556,6 +595,10 @@ def allreduce_fusion( - AllReduce + Residual + RMSNorm - AllReduce + Residual + RMSNorm + Quantization (FP8/FP4) + **Note on Workspace Reusability:** + You can reuse the same workspace with different (token_num, hidden_dim) combinations + as long as `workspace.is_sufficient_for(token_num, hidden_dim, tp_size, dtype)` returns True. + Args: input: Input tensor [token_num, hidden_dim] workspace: Workspace object (type determines backend) @@ -685,9 +728,8 @@ def _infer_fusion_pattern( """ Automatically infer fusion pattern from provided tensors. - Returns AllReduceFusionPattern value based on which output tensors are provided. + Returns AllReduceFusionPattern value (as int) based on which output tensors are provided. 
""" - from .trtllm_ar import AllReduceFusionPattern if quant_out is not None: # Quantization patterns @@ -743,6 +785,7 @@ def _allreduce_fusion_trtllm( quant_out_flat = quant_out.flatten() if quant_out is not None else None # Call legacy API with flattened tensors + # Note: pattern and layout_code are ints but legacy API uses pseudo-type hints trtllm_allreduce_fusion( allreduce_in=input_flat, world_size=workspace.world_size, @@ -753,7 +796,7 @@ def _allreduce_fusion_trtllm( launch_with_pdl=launch_with_pdl, trigger_completion_at_end=launch_with_pdl, # Same meaning fp32_acc=fp32_acc, - pattern_code=pattern, + pattern_code=pattern, # type: ignore[arg-type] use_oneshot=use_oneshot, allreduce_out=output_flat, residual_in=residual_in_flat, @@ -764,7 +807,7 @@ def _allreduce_fusion_trtllm( rms_gamma=rms_gamma, # 1D tensor, no reshape needed rms_eps=rms_eps, scale_factor=scale_factor, - layout_code=layout_code, + layout_code=layout_code, # type: ignore[arg-type] metadata=metadata, ) diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py index e0b4369f12..87246f739a 100644 --- a/flashinfer/comm/trtllm_ar.py +++ b/flashinfer/comm/trtllm_ar.py @@ -804,7 +804,7 @@ def _should_use_oneshot( return comm_size_mb <= _use_oneshot_heuristics[world_size] -def _check_workspace_metadata( +def check_trtllm_allreduce_fusion_workspace_metadata( token_num: int, hidden_dim: int, world_size: int, @@ -903,7 +903,7 @@ def trtllm_allreduce_fusion( # Validate against workspace metadata if provided if metadata is not None: - _check_workspace_metadata( + check_trtllm_allreduce_fusion_workspace_metadata( token_num, hidden_dim, world_size, allreduce_in.dtype, metadata ) diff --git a/tests/comm/test_trtllm_allreduce_fusion.py b/tests/comm/test_trtllm_allreduce_fusion.py index 17d9d7c2d4..31ddc1518b 100644 --- a/tests/comm/test_trtllm_allreduce_fusion.py +++ b/tests/comm/test_trtllm_allreduce_fusion.py @@ -86,7 +86,6 @@ def _run_correctness_worker( topology="single_node", process_group=group, use_fp32_lamport=lamport_use_fp32, - create_metadata=True, ) # Extract metadata for compatibility with tests workspace_metadata = workspace.metadata From b47ade43eaf89749661842c8258b2ad945c7c95d Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Mon, 24 Nov 2025 15:12:22 -0800 Subject: [PATCH 08/11] Merged dtype and use_fp32_lamport params --- flashinfer/comm/allreduce.py | 8 ++++---- tests/comm/test_trtllm_allreduce_fusion.py | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index 1227dc4b06..70a307187e 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -142,8 +142,8 @@ def __init__( tp_rank: int, max_token_num: int, hidden_dim: int, + dtype: torch.dtype = torch.float16, process_group: Optional["torch.distributed.ProcessGroup"] = None, - **kwargs, ): """ Create TensorRT-LLM AllReduce fusion workspace. @@ -167,7 +167,7 @@ def __init__( hidden_dim=hidden_dim, group=process_group, create_metadata=True, - **kwargs, + use_fp32_lamport=dtype == torch.float32, ) # Store essential attributes for easy access @@ -427,7 +427,7 @@ def create_allreduce_fusion_workspace( dtype: torch.dtype = None, topology: str = "single_node", process_group: Optional["torch.distributed.ProcessGroup"] = None, - **backend_kwargs, + **backend_kwargs, # TODO(nvmbreughe): remove this ) -> AllReduceFusionWorkspace: """ Create workspace for AllReduce fusion operations. 
@@ -517,8 +517,8 @@ def create_allreduce_fusion_workspace( tp_rank=rank, max_token_num=max_token_num, hidden_dim=hidden_dim, + dtype=dtype, process_group=process_group, - **backend_kwargs, ) elif actual_backend == "mnnvl": diff --git a/tests/comm/test_trtllm_allreduce_fusion.py b/tests/comm/test_trtllm_allreduce_fusion.py index 31ddc1518b..601bddbb91 100644 --- a/tests/comm/test_trtllm_allreduce_fusion.py +++ b/tests/comm/test_trtllm_allreduce_fusion.py @@ -85,7 +85,6 @@ def _run_correctness_worker( dtype=dtype, topology="single_node", process_group=group, - use_fp32_lamport=lamport_use_fp32, ) # Extract metadata for compatibility with tests workspace_metadata = workspace.metadata From 2d002673264eb3d3c554ca216e0a4f1c427d26c9 Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Mon, 1 Dec 2025 13:29:02 -0800 Subject: [PATCH 09/11] removed useless function --- flashinfer/comm/__init__.py | 3 -- flashinfer/comm/allreduce.py | 53 ++++++++++-------------------------- 2 files changed, 15 insertions(+), 41 deletions(-) diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py index 3ffad2505d..8a1d6f9f0c 100644 --- a/flashinfer/comm/__init__.py +++ b/flashinfer/comm/__init__.py @@ -47,8 +47,5 @@ from .allreduce import ( create_allreduce_fusion_workspace as create_allreduce_fusion_workspace, ) -from .allreduce import ( - destroy_allreduce_fusion_workspace as destroy_allreduce_fusion_workspace, -) # from .mnnvl import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index 70a307187e..54f1b17c8b 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -32,20 +32,20 @@ ... topology="single_node" ... ) >>> - >>> # Perform AllReduce + RMSNorm fusion - >>> prenorm = torch.empty_like(hidden_states) - >>> normed = torch.empty_like(hidden_states) - >>> output = allreduce_fusion( - ... input=hidden_states, - ... workspace=workspace, - ... launch_with_pdl=True, - ... residual_out=prenorm, - ... norm_out=normed, - ... residual_in=residual, - ... rms_gamma=norm_weight - ... ) - >>> - >>> destroy_allreduce_fusion_workspace(workspace) + >>> # Perform AllReduce + RMSNorm fusion + >>> prenorm = torch.empty_like(hidden_states) + >>> normed = torch.empty_like(hidden_states) + >>> output = allreduce_fusion( + ... input=hidden_states, + ... workspace=workspace, + ... launch_with_pdl=True, + ... residual_out=prenorm, + ... norm_out=normed, + ... residual_in=residual, + ... rms_gamma=norm_weight + ... ) + >>> + >>> workspace.destroy() """ from typing import Union, Literal, Optional, Tuple, List, cast @@ -76,7 +76,7 @@ # 1. Calls the backend-specific workspace creation function in __init__ # 2. Stores the internal workspace as _internal_workspace # 3. Exposes essential attributes for the unified API -# 4. Can be destroyed using destroy_allreduce_fusion_workspace() +# 4. Can be destroyed using workspace.destroy() # ============================================================================ @@ -534,29 +534,6 @@ def create_allreduce_fusion_workspace( raise RuntimeError(f"Unknown backend: {actual_backend}") -# ============================================================================ -# WORKSPACE DESTRUCTION -# ============================================================================ - - -def destroy_allreduce_fusion_workspace(workspace: AllReduceFusionWorkspace) -> None: - """ - Destroy workspace and free resources. - - This is a convenience function that calls the workspace's destroy() method. 
- - Args: - workspace: Workspace object to destroy - - Example: - >>> workspace = create_allreduce_fusion_workspace(...) - >>> # ... use workspace ... - >>> destroy_allreduce_fusion_workspace(workspace) - >>> # Or call directly: workspace.destroy() - """ - workspace.destroy() - - # ============================================================================ # MAIN API - NO backend parameter, infers from workspace type # ============================================================================ From 7001c9219a0cae127ca10fce14c1aa29a37e6c42 Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Mon, 1 Dec 2025 13:41:53 -0800 Subject: [PATCH 10/11] Moved in the helper functions, rejected some patterns for mnnvl --- flashinfer/comm/allreduce.py | 250 ++++++++++++----------------------- 1 file changed, 84 insertions(+), 166 deletions(-) diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index 54f1b17c8b..af8ae3433a 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -651,183 +651,101 @@ def allreduce_fusion( """ # Auto-detect pattern if not provided if pattern is None: - pattern = _infer_fusion_pattern( - output, residual_in, residual_out, norm_out, quant_out, scale_out - ) + if quant_out is not None: + # Quantization patterns + if norm_out is not None and residual_out is not None: + pattern = AllReduceFusionPattern.kARResidualRMSNormOutFP8Quant # 4 + else: + pattern = AllReduceFusionPattern.kARResidualRMSNormFP8Quant # 2 + elif norm_out is not None: + pattern = AllReduceFusionPattern.kARResidualRMSNorm # 1 + else: + pattern = AllReduceFusionPattern.kAllReduce # 0 - # Infer backend from workspace type and dispatch + # Dispatch based on workspace type if isinstance(workspace, TRTLLMAllReduceFusionWorkspace): - return _allreduce_fusion_trtllm( - input=input, - workspace=workspace, + # TensorRT-LLM backend implementation + # Extract shape from 2D input + token_num, hidden_dim = input.shape + + # Allocate output if needed (keep 2D shape) + if output is None: + output = torch.empty_like(input) + + # Flatten all tensors to 1D for legacy trtllm_allreduce_fusion API + # The legacy API expects flattened tensors and explicit token_num/hidden_dim + input_flat = input.flatten() + output_flat = output.flatten() + residual_in_flat = residual_in.flatten() if residual_in is not None else None + residual_out_flat = residual_out.flatten() if residual_out is not None else None + norm_out_flat = norm_out.flatten() if norm_out is not None else None + quant_out_flat = quant_out.flatten() if quant_out is not None else None + + # Call legacy API with flattened tensors + # Note: pattern and layout_code are ints but legacy API uses pseudo-type hints + trtllm_allreduce_fusion( + allreduce_in=input_flat, + world_size=workspace.world_size, + world_rank=workspace.rank, + token_num=token_num, + hidden_dim=hidden_dim, + workspace_ptrs=workspace.workspace_tensor, launch_with_pdl=launch_with_pdl, - output=output, - residual_in=residual_in, - residual_out=residual_out, - norm_out=norm_out, - quant_out=quant_out, - scale_out=scale_out, - rms_gamma=rms_gamma, + trigger_completion_at_end=launch_with_pdl, # Same meaning + fp32_acc=fp32_acc, + pattern_code=pattern, # type: ignore[arg-type] + use_oneshot=use_oneshot, + allreduce_out=output_flat, + residual_in=residual_in_flat, + residual_out=residual_out_flat, + norm_out=norm_out_flat, + quant_out=quant_out_flat, + scale_out=scale_out, # scale_out is not reshaped + rms_gamma=rms_gamma, # 1D tensor, no reshape needed rms_eps=rms_eps, 
scale_factor=scale_factor, - layout_code=layout_code, - pattern=pattern, - use_oneshot=use_oneshot, - fp32_acc=fp32_acc, + layout_code=layout_code, # type: ignore[arg-type] metadata=metadata, ) - elif isinstance(workspace, MNNVLAllReduceFusionWorkspace): - return _allreduce_fusion_mnnvl( - input=input, - workspace=workspace, - launch_with_pdl=launch_with_pdl, - residual_in=residual_in, - residual_out=residual_out, - norm_out=norm_out, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - ) - else: - raise TypeError( - f"Unknown workspace type: {type(workspace)}. " - f"Expected TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace" - ) - - -# ============================================================================ -# HELPER FUNCTIONS -# ============================================================================ - - -def _infer_fusion_pattern( - output, residual_in, residual_out, norm_out, quant_out, scale_out -) -> int: - """ - Automatically infer fusion pattern from provided tensors. - Returns AllReduceFusionPattern value (as int) based on which output tensors are provided. - """ - - if quant_out is not None: - # Quantization patterns - if norm_out is not None and residual_out is not None: - # Has separate norm output and residual output - return AllReduceFusionPattern.kARResidualRMSNormOutFP8Quant # 4 + # Return the most downstream output (already in 2D shape from input views) + if norm_out is not None: + return norm_out + elif quant_out is not None: + return quant_out else: - # Quant without separate outputs - return AllReduceFusionPattern.kARResidualRMSNormFP8Quant # 2 - elif norm_out is not None: - # RMS Norm without quantization - return AllReduceFusionPattern.kARResidualRMSNorm # 1 - else: - # Just AllReduce - return AllReduceFusionPattern.kAllReduce # 0 + return output + elif isinstance(workspace, MNNVLAllReduceFusionWorkspace): + if ( + pattern != AllReduceFusionPattern.kARResidualRMSNorm + and pattern != AllReduceFusionPattern.kAllReduce + ): + raise ValueError( + f"MNNVL AllReduce+RMS fusion does not support pattern {pattern}" + ) + # MNNVL backend implementation + # Validate required parameters for RMS fusion + if residual_in is None: + raise ValueError("MNNVL AllReduce+RMS fusion requires residual_in") + if residual_out is None: + raise ValueError( + "MNNVL AllReduce+RMS fusion requires residual_out (prenorm_output)" + ) + if norm_out is None: + raise ValueError( + "MNNVL AllReduce+RMS fusion requires norm_out (normed_output)" + ) + if rms_gamma is None: + raise ValueError("MNNVL AllReduce+RMS fusion requires rms_gamma") -def _allreduce_fusion_trtllm( - input: torch.Tensor, - workspace: TRTLLMAllReduceFusionWorkspace, - launch_with_pdl: bool, - output: Optional[torch.Tensor], - residual_in: Optional[torch.Tensor], - residual_out: Optional[torch.Tensor], - norm_out: Optional[torch.Tensor], - quant_out: Optional[torch.Tensor], - scale_out: Optional[torch.Tensor], - rms_gamma: Optional[torch.Tensor], - rms_eps: float, - scale_factor: Optional[Union[torch.Tensor, float]], - layout_code: Optional[int], - pattern: int, - use_oneshot: Optional[bool], - fp32_acc: bool, - metadata: Optional[dict], -) -> torch.Tensor: - """TensorRT-LLM backend implementation.""" - - # Extract shape from 2D input - token_num, hidden_dim = input.shape - - # Allocate output if needed (keep 2D shape) - if output is None: - output = torch.empty_like(input) - - # Flatten all tensors to 1D for legacy trtllm_allreduce_fusion API - # The legacy API expects flattened tensors and explicit token_num/hidden_dim 
- input_flat = input.flatten() - output_flat = output.flatten() - residual_in_flat = residual_in.flatten() if residual_in is not None else None - residual_out_flat = residual_out.flatten() if residual_out is not None else None - norm_out_flat = norm_out.flatten() if norm_out is not None else None - quant_out_flat = quant_out.flatten() if quant_out is not None else None - - # Call legacy API with flattened tensors - # Note: pattern and layout_code are ints but legacy API uses pseudo-type hints - trtllm_allreduce_fusion( - allreduce_in=input_flat, - world_size=workspace.world_size, - world_rank=workspace.rank, - token_num=token_num, - hidden_dim=hidden_dim, - workspace_ptrs=workspace.workspace_tensor, - launch_with_pdl=launch_with_pdl, - trigger_completion_at_end=launch_with_pdl, # Same meaning - fp32_acc=fp32_acc, - pattern_code=pattern, # type: ignore[arg-type] - use_oneshot=use_oneshot, - allreduce_out=output_flat, - residual_in=residual_in_flat, - residual_out=residual_out_flat, - norm_out=norm_out_flat, - quant_out=quant_out_flat, - scale_out=scale_out, # scale_out is not reshaped - rms_gamma=rms_gamma, # 1D tensor, no reshape needed - rms_eps=rms_eps, - scale_factor=scale_factor, - layout_code=layout_code, # type: ignore[arg-type] - metadata=metadata, - ) - - # Return the most downstream output (already in 2D shape from input views) - if norm_out is not None: - return norm_out - elif quant_out is not None: - return quant_out - else: - return output - + # Call the MNNVL fusion function + raise NotImplementedError("MNNVL AllReduce+RMS fusion is not implemented") -def _allreduce_fusion_mnnvl( - input: torch.Tensor, - workspace: MNNVLAllReduceFusionWorkspace, - launch_with_pdl: bool, - residual_in: Optional[torch.Tensor], - residual_out: Optional[torch.Tensor], - norm_out: Optional[torch.Tensor], - rms_gamma: Optional[torch.Tensor], - rms_eps: float, -) -> torch.Tensor: - """ - MNNVL backend implementation. + return norm_out - Calls trtllm_mnnvl_fused_allreduce_rmsnorm which performs: - 1. AllReduce on input - 2. Add residual - 3. RMSNorm - """ - # Validate required parameters for RMS fusion - if residual_in is None: - raise ValueError("MNNVL AllReduce+RMS fusion requires residual_in") - if residual_out is None: - raise ValueError( - "MNNVL AllReduce+RMS fusion requires residual_out (prenorm_output)" + else: + raise TypeError( + f"Unknown workspace type: {type(workspace)}. 
" + f"Expected TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace" ) - if norm_out is None: - raise ValueError("MNNVL AllReduce+RMS fusion requires norm_out (normed_output)") - if rms_gamma is None: - raise ValueError("MNNVL AllReduce+RMS fusion requires rms_gamma") - - # Call the MNNVL fusion function - raise NotImplementedError("MNNVL AllReduce+RMS fusion is not implemented") - - return norm_out From e2fdea22ca617c297f1301658c2d8e5c2ea8fd1e Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Mon, 1 Dec 2025 13:46:53 -0800 Subject: [PATCH 11/11] Made fusion pattern param mandatory --- flashinfer/comm/allreduce.py | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index af8ae3433a..5fb28753bf 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -542,6 +542,7 @@ def create_allreduce_fusion_workspace( def allreduce_fusion( input: torch.Tensor, workspace: AllReduceFusionWorkspace, + pattern: int, launch_with_pdl: bool = False, # ===== OUTPUT tensors (pre-allocated, will be filled) ===== output: Optional[torch.Tensor] = None, @@ -556,7 +557,6 @@ def allreduce_fusion( scale_factor: Optional[Union[torch.Tensor, float]] = None, layout_code: Optional[int] = None, # ===== Control parameters ===== - pattern: Optional[int] = None, use_oneshot: Optional[bool] = None, fp32_acc: bool = False, metadata: Optional[dict] = None, @@ -579,6 +579,14 @@ def allreduce_fusion( Args: input: Input tensor [token_num, hidden_dim] workspace: Workspace object (type determines backend) + pattern: Fusion pattern (AllReduceFusionPattern constant, 0-5) + - kAllReduce = 0 + - kARResidualRMSNorm = 1 + - kARResidualRMSNormFP8Quant = 2 + - kARResidualRMSNormFP4Quant = 3 + - kARResidualRMSNormOutFP8Quant = 4 + - kARResidualRMSNormOutFP4Quant = 5 + Note: MNNVL only supports patterns 0 and 1 launch_with_pdl: Use Persistent Dependency Launch # ===== OUTPUT tensors (pre-allocated, filled by function) ===== @@ -596,8 +604,6 @@ def allreduce_fusion( layout_code: Scale factor layout (QuantizationSFLayout) [trtllm only] # ===== Control parameters ===== - pattern: Fusion pattern (AllReduceFusionPattern) - If None, auto-detected based on provided output tensors use_oneshot: [trtllm only] Use oneshot strategy vs twoshot If None, uses internal heuristics fp32_acc: [trtllm only] Use FP32 accumulation for AllReduce @@ -626,6 +632,7 @@ def allreduce_fusion( >>> output = allreduce_fusion( ... input=hidden_states, ... workspace=workspace, + ... pattern=AllReduceFusionPattern.kARResidualRMSNorm, ... launch_with_pdl=True, ... residual_out=prenorm, ... norm_out=normed, @@ -641,6 +648,7 @@ def allreduce_fusion( >>> output = allreduce_fusion( ... input=hidden_states, ... workspace=workspace, + ... pattern=AllReduceFusionPattern.kARResidualRMSNormFP8Quant, ... norm_out=normed, ... quant_out=quant, ... scale_out=scales, @@ -649,19 +657,6 @@ def allreduce_fusion( ... scale_factor=scale_tensor ... 
) """ - # Auto-detect pattern if not provided - if pattern is None: - if quant_out is not None: - # Quantization patterns - if norm_out is not None and residual_out is not None: - pattern = AllReduceFusionPattern.kARResidualRMSNormOutFP8Quant # 4 - else: - pattern = AllReduceFusionPattern.kARResidualRMSNormFP8Quant # 2 - elif norm_out is not None: - pattern = AllReduceFusionPattern.kARResidualRMSNorm # 1 - else: - pattern = AllReduceFusionPattern.kAllReduce # 0 - # Dispatch based on workspace type if isinstance(workspace, TRTLLMAllReduceFusionWorkspace): # TensorRT-LLM backend implementation