Skip to content
9 changes: 9 additions & 0 deletions flashinfer/comm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,13 @@
from .vllm_ar import register_buffer as vllm_register_buffer
from .vllm_ar import register_graph_buffers as vllm_register_graph_buffers

# Unified AllReduce Fusion API
from .allreduce import AllReduceFusionWorkspace as AllReduceFusionWorkspace
from .allreduce import MNNVLAllReduceFusionWorkspace as MNNVLAllReduceFusionWorkspace
from .allreduce import TRTLLMAllReduceFusionWorkspace as TRTLLMAllReduceFusionWorkspace
from .allreduce import allreduce_fusion as allreduce_fusion
from .allreduce import (
create_allreduce_fusion_workspace as create_allreduce_fusion_workspace,
)

# from .mnnvl import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
Loading