A Distributed Attention Towards Linear Scalability for Ultra-Long Context, Heterogeneous Mask Training
- [2025/5] We support overlapped q_ranges when all mask types are FULL (see the v1.0.1 release note for more details), and release example code for integrating Megatron with MagiAttention, along with several training convergence experiments (see here for more details).
- [2025/4] We release MagiAttention-v1.0.0 with its blog: a distributed attention towards linear scalability for ultra-long context, heterogeneous mask training.
MagiAttention is a distributed attention mechanism, or context-parallel (CP) strategy, that aims to support a wide variety of attention mask types with kernel-level flexibility while achieving linear scalability with respect to CP size across a broad range of scenarios. It is particularly suited to training with ultra-long contexts and heterogeneous masks, such as video generation for Magi-1.
Additionally, it can be easily integrated into prevalent training frameworks such as Megatron-LM and PyTorch's native FSDP, as illustrated in QuickStart.
We are committed to continually improving the performance and generality of MagiAttention for the broader research community. Stay tuned for exciting enhancements and new features on the horizon!
To realize linear scalability for distributed attention, we introduce the following key designs.
For implementation details, more experimental results, and future work, please visit our blog.
- Flexible Flash Attention Kernel. We introduce a generalized formulation for irregular attention mask patterns and implement a flexible flash attention kernel (FFA). It is natively designed for distributed scenarios and provides great flexibility in handling diverse attention mask types (a short sketch follows this list), with performance comparable to Flash-Attention 3 on Hopper GPUs.
- Computation Load-Balance. With a fine-grained sharding strategy, we design an efficient dispatch solver that ensures balanced attention computational loads across all CP ranks in every training iteration (a toy illustration also follows this list).
- Zero-Redundant Communication. Instead of adopting the common ring-style P2P communication pattern in CP, we propose two novel communication primitives, GroupCast and GroupReduce, built upon All-to-All-v as a prototype implementation, enabling zero-redundant communication volume for both forward and backward passes.
- Adaptive Multi-Stage Overlap. Leveraging the above enhancements, we further implement a multi-stage compute-communication overlap strategy that effectively hides communication latency and adaptively optimizes overlap through manual or automatic tuning.
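To make the generalized mask formulation in the first bullet concrete, here is a minimal, illustrative sketch (not the library's internal representation) of how a packed batch of causal documents can be described as (q_range, k_range, mask_type) slices, mirroring the arguments the FFA kernel consumes in the Basic Usage examples below; the document lengths are made up for illustration.

```python
# Illustrative only: describe a packed batch of 3 causal documents
# (lengths 128, 256, 128) as per-slice ranges, the way the FFA kernel
# consumes them (see the Basic Usage examples below for the real API).
doc_lens = [128, 256, 128]

q_ranges, k_ranges, attn_type_map = [], [], []
offset = 0
for length in doc_lens:
    q_ranges.append([offset, offset + length])  # queries of this document
    k_ranges.append([offset, offset + length])  # attend only to its own keys
    attn_type_map.append(1)                     # 1 = causal (bottom-right aligned)
    offset += length

# q_ranges      -> [[0, 128], [128, 384], [384, 512]]
# k_ranges      -> [[0, 128], [128, 384], [384, 512]]
# attn_type_map -> [1, 1, 1]
```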
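Likewise, as a rough intuition for the dispatch solver in the second bullet, the toy sketch below greedily assigns per-chunk attention workloads (e.g. mask areas) to CP ranks with a min-heap; the actual MinHeapDispatchAlg used in the examples below balances real mask areas under additional constraints, so treat this purely as an illustration.

```python
import heapq

def toy_dispatch(chunk_areas: list[int], cp_size: int) -> list[list[int]]:
    """Greedily place each chunk on the currently least-loaded CP rank."""
    heap = [(0, rank) for rank in range(cp_size)]  # (accumulated area, rank)
    heapq.heapify(heap)
    assignment = [[] for _ in range(cp_size)]
    # Placing the largest chunks first gives a tighter balance for this greedy scheme.
    for chunk_id in sorted(range(len(chunk_areas)), key=lambda i: -chunk_areas[i]):
        load, rank = heapq.heappop(heap)
        assignment[rank].append(chunk_id)
        heapq.heappush(heap, (load + chunk_areas[chunk_id], rank))
    return assignment

# e.g. 8 chunks of a causal mask with very uneven areas, balanced over 4 CP ranks:
print(toy_dispatch([1, 3, 5, 7, 9, 11, 13, 15], cp_size=4))
```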
Planned future work includes:

- Optimize `Flex-Flash-Attention` kernels to improve performance and better support sparse attention (such as NSA)
- Support native `GroupCast` and `GroupReduce` kernels and hierarchical communication optimization (similar to DeepEP)
- Refactor the `Distributed Attention Solver` as well as the `Flex-Flash-Attention` kernel arguments to support all mask types with all kinds of overlap, and reduce CPU overhead for meta-info calculation
- Improve the `Dispatch Solver` to reduce the necessary communication volume while keeping computation balanced (especially for varlen mask patterns)
- Build a comprehensive `CP Benchmark` to better compare the performance of different context-parallel strategies under various mask patterns and other training configurations
- release note: here
- docker image version: `nvcr.io/nvidia/pytorch:25.02-py3`
- docker run command:

  ```bash
  docker run --name {container_name} -v {host_mnt_root}:{container_mnt_root} -it -d --privileged --gpus all --network host --ipc host --ulimit memlock=-1 --ulimit stack=67108864 nvcr.io/nvidia/pytorch:25.02-py3 /bin/bash
  ```

- docker exec command:

  ```bash
  docker exec -it {container_name} /bin/bash
  ```
- install requirements:

  ```bash
  pip install -r requirements.txt
  ```

- install MagiAttention from source:

  ```bash
  git clone https://github.com/SandAI-org/MagiAttention.git
  cd MagiAttention
  git submodule update --init --recursive
  pip install --no-build-isolation .
  ```
Warning
MagiAttention currently only supports Hopper GPUs. We intend to broaden this support in upcoming updates.
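If you want your script to fail fast on unsupported hardware, a simple check (our own suggestion, not part of the MagiAttention API) looks like this:

```python
import torch

# Hopper GPUs (e.g. H100 / H800) report CUDA compute capability 9.x.
major, minor = torch.cuda.get_device_capability()
if major != 9:
    raise RuntimeError(
        f"MagiAttention currently requires a Hopper (sm90) GPU, "
        f"but this device reports compute capability {major}.{minor}."
    )
```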
We provide examples (pseudo-code) of how to use flex_flash_attention (the kernel) and magi_attention (context parallel only) to accelerate local/distributed attention computation.
You can refer to magi_attention/api/magi_attn_interface.py for more information.
Basic Usage
flex_flash_attention (kernel):

```python
import torch

from magi_attention.api import flex_flash_attn_func

# --- Prepare Inputs ---
# q, k, v are packed along the sequence dimension: (total_seqlen, num_heads, head_dim).
# The head configuration below is illustrative.
device = 'cuda'
dtype = torch.bfloat16
num_heads, head_dim = 8, 128
total_seqlen = 250  # must cover the ranges defined below
q = torch.randn(total_seqlen, num_heads, head_dim, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(total_seqlen, num_heads, head_dim, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(total_seqlen, num_heads, head_dim, device=device, dtype=dtype, requires_grad=True)

# --- Define Attention Structure ---
# Shape: [num_ranges, 2]
q_ranges_tensor = torch.tensor([[0, 100], [100, 250]], device=device, dtype=torch.int32)
k_ranges_tensor = torch.tensor([[0, 100], [0, 250]], device=device, dtype=torch.int32)
max_seqlen_q = 150 # Max length of any q_range (250-100 = 150)
max_seqlen_k = 250 # Max length of any k_range (250-0 = 250)
# attn_type_map values:
# 0: full attention
# 1: causal attention (bottom-right aligned)
# 2: inverse causal attention (top-left aligned)
# 3: bidirectional causal attention (diagonal)
# for more information about attn mask type, please refer to our blog:
# https://sandai-org.github.io/MagiAttention/
attn_type_map_tensor = torch.tensor([1, 0], device=device, dtype=torch.int32) # Causal for 1st, Full for 2nd
# --- Forward Pass ---
# disable_fwd_atomic_reduction=True can be used if q_ranges are guaranteed to be non-overlapping for performance.
# If q_ranges might overlap (e.g. for specific sparse patterns not representable as disjoint blocks), set it to False.
out_ffa, lse_ffa = flex_flash_attn_func(
    q, k, v,
    q_ranges=q_ranges_tensor,
    k_ranges=k_ranges_tensor,
    max_seqlen_q=max_seqlen_q,
    max_seqlen_k=max_seqlen_k,
    attn_type_map=attn_type_map_tensor,
    softmax_scale=None,  # defaults to 1/sqrt(head_dim)
    disable_fwd_atomic_reduction=True,  # assuming q_ranges here are disjoint after any potential processing
)
```
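Continuing the snippet above: like other flash-attention-style APIs, the returned output is expected to participate in autograd, so (assuming q, k and v were created with requires_grad=True, as above) the backward pass is just the usual call:

```python
# Backward pass: gradients w.r.t. q, k, v are produced by autograd as usual.
loss = out_ffa.float().sum()
loss.backward()
print(q.grad.shape, k.grad.shape, v.grad.shape)
```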
flash_attn_varlen-like interface (magi_attn_varlen_dispatch):

```python
import torch

from magi_attention.api import (  # func tools and interface
    magi_attn_varlen_dispatch,
    undispatch,
    calc_attn,
    squash_batch_dim,
    full_attention_to_varlen_attention,
    compute_pad_size,
)
# DistAttnConfig, DispatchConfig, OverlapConfig and the related enums used below are
# also provided by the magi_attention package; see magi_attention/api/magi_attn_interface.py
# for the exact import paths.
# --- prepare data and args for magi_attention --- #
# batch_size, seqlen, h, head_dim, cp_size, cp_group, device and dtype are assumed
# to be defined by the surrounding training code.
# create input data with shape (batch_size, seqlen, h)
x_with_batch = torch.randn(
    batch_size,
    seqlen,
    h,
    device=device,
    dtype=dtype,
    requires_grad=True,
)
# squash the batch dim: magi_attention does not support inputs with a batch dim.
x = squash_batch_dim(x_with_batch)  # shape: (batch_size * seqlen, h)
# get cu_seqlens_q and cu_seqlens_k after squashing.
cu_seqlens_q, cu_seqlens_k = full_attention_to_varlen_attention(
    batch_size, seqlen
)
# pad the input seqlen for better performance
pad_size, _ = compute_pad_size(x, cp_size, head_dim)
total_seqlen_q: int = batch_size * seqlen
total_seqlen_k: int = batch_size * seqlen
# --- magi_attention dispatch --- #
# dispatch global input tensor to each rank and get the runtime_key
# local_x has shape ((total_seqlen_q + pad_size) / cp_size, h)
local_x, magi_attn_runtime_key = magi_attn_varlen_dispatch(
    x,
    cu_seqlens_q,
    cu_seqlens_k,
    head_dim=head_dim,
    pad_size=pad_size,
    cp_group=cp_group,
    causal=False,
    dist_attn_config=DistAttnConfig(
        dispatch_config=DispatchConfig(alg=MinHeapDispatchAlg()),
        overlap_config=OverlapConfig(
            enable=True,
            mode=AttnOverlapMode.STATIC,
            degree=2,
            min_chunk_size=512,
            max_num_chunks=64,
            alg=OverlapAlgType.UNIFORM,
        ),
    ),
)

# ...
# --- magi_attention calculation and undispatch --- #
# do q k v projection
local_q, local_k, local_v = q_project(local_x), k_project(local_x), v_project(local_x) # q, k, v with shape (bs * seqlen / cp_size, nh, hd)
# Do local attention computation with runtime key
local_out, _ = calc_attn(local_q, local_k, local_v, magi_attn_runtime_key) # local out with shape (bs * seqlen / cp_size, h)
# Gather local attention results to global result with runtime key
total_out = undispatch(local_out, magi_attn_runtime_key)  # total_out with shape (bs * seqlen, h)
```

magi_attn_flex_dispatch (more flexible):
```python
import torch

from magi_attention.api import (  # func tools and interface
    magi_attn_flex_dispatch,
    undispatch,
    calc_attn,
    compute_pad_size,
)
# AttnRanges, AttnMaskType and DistAttnConfig used below are also provided by the
# magi_attention package; see magi_attention/api/magi_attn_interface.py for the
# exact import paths.
# device, dtype, h, cp_size, cp_group, head_dim and dist_attn_config are assumed
# to be defined as in the previous example.
total_seqlen_q = 960
total_seqlen_k = 960

# input is already batch-squashed: shape (total_seqlen_q, h)
x = torch.randn(
    total_seqlen_q,
    h,
    device=device,
    dtype=dtype,
    requires_grad=True,
)
# block mask
q_ranges = AttnRanges.from_ranges(
    [
        [0, 128],
        [128, 256],
        [256, 384],
        [384, 512],
        [512, 640],
        [640, 768],
        [768, 960],
    ]
)
k_ranges = AttnRanges.from_ranges(
    [
        [0, 128],
        [0, 256],
        [0, 384],
        [0, 512],
        [512, 640],
        [512, 768],
        [768, 960],
    ]
)
attn_mask_type = [AttnMaskType.FULL] * 7
pad_size, _ = compute_pad_size(total_seqlen_q, cp_size, head_dim)
# local_x has shape ((total_seqlen_q + pad_size) / cp_size, h)
local_x, magi_attn_runtime_key = magi_attn_flex_dispatch(
    x,
    q_ranges=q_ranges,
    k_ranges=k_ranges,
    attn_mask_type=attn_mask_type,
    total_seqlen_q=total_seqlen_q,
    total_seqlen_k=total_seqlen_k,
    head_dim=head_dim,
    pad_size=pad_size,
    cp_group=cp_group,
    is_same_source=True,
    is_q_permutable=True,
    is_k_permutable=True,
    dist_attn_config=dist_attn_config,
)
# ...
# --- magi_attention calculation and undispatch --- #
# do q k v projection
local_q, local_k, local_v = q_project(local_x), k_project(local_x), v_project(local_x) # q, k, v with shape (s, nh, hd)
# Do local attention computation with runtime key
local_out, _ = calc_attn(local_q, local_k, local_v, magi_attn_runtime_key) # local out with shape (s, h)
# Gather local attention results and unpad to global result with runtime key
total_out = undispatch(local_out, magi_attn_runtime_key)  # total_out with shape (total_seqlen_q, h)
```

We provide an example of how to integrate MagiAttention with FSDP2 in example/torch_native. You can use `bash run.sh` to run the example.
In this example, we build a Llama-1B model and apply FSDP2 with MagiAttention as the parallelism strategy.
- example/torch_native/modeling_llama.py: builds the Llama model and integrates it with MagiAttention.
- example/torch_native/main.py: the main training loop.
We created a new repository, Megatron-LM-MagiAttention, forked from Megatron-LM v0.11.0, to provide an example of training a Llama-1B model with Megatron-LM + MagiAttention. In addition, we trained a Llama-3-1B model from scratch to verify training convergence.
For more information, you can refer to example/megatron/README.md.
Coming soon ...
To demonstrate the FFA kernels' state-of-the-art performance and flexibility in handling ultra-long, heterogeneous mask training, we measure their computing power (in TFLOPs/s) across a variety of mask patterns, under the common settings listed below.
| settings | value |
|---|---|
| batch size (b) | 1 |
| number of heads (nh) | nhq:nhk:nhv = 64:8:8 (GQA) |
| head dimension (hd) | 128 |
| dtype | torch.bfloat16 |
| dropout probability | 0.0 |
| window size | 1024 (for sliding window masks only) |
Benchmark settings: for each mask pattern, we sweep the sequence length seqlen (with seqlen_q = seqlen_k = seqlen) and measure the achieved computing power (in TFLOPs/s) at each seqlen.
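As a reference for how such a number can be derived (the benchmark's exact accounting may differ), a common convention counts two matmuls (QK^T and PV) per attended query-key pair, i.e. 4 * head_dim FLOPs per pair per head, scaled by the unmasked fraction of the mask:

```python
def attn_fwd_flops(seqlen_q: int, seqlen_k: int, num_heads_q: int,
                   head_dim: int, unmasked_fraction: float = 1.0) -> float:
    """Forward-pass attention FLOPs under the usual two-matmul accounting.

    unmasked_fraction is ~1.0 for a full mask and ~0.5 for a causal mask.
    """
    return 4.0 * seqlen_q * seqlen_k * head_dim * num_heads_q * unmasked_fraction

# Illustrative numbers only: 64k tokens, 64 query heads, head_dim 128, causal mask,
# and a hypothetical measured forward time of 120 ms.
flops = attn_fwd_flops(64 * 1024, 64 * 1024, 64, 128, unmasked_fraction=0.5)
print(f"{flops / 0.120 / 1e12:.1f} TFLOPs/s")
```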
Some results are reported in the following figures; see more in our blog.
To validate the scalability of MagiAttention, we assess the per-GPU computing power (in TFLOPs/s) of the attention module as the context length and CP size scale up, comparing against other context-parallel strategies.
The experiments are conducted on a large-scale production GPU cluster (due to business and confidentiality reasons, specific details about the cluster, such as the number and type of GPUs, are withheld). We scale the total sequence length seqlen, the context-parallel size cp_size, and the number of nodes nnodes together, from seqlen:64k, cp_size:1, nnodes:1 through seqlen:128k, cp_size:2, nnodes:2, ..., up to seqlen:3072k (3M), cp_size:48, nnodes:48.
The tensor-parallel size tp_size is fixed at 8, with sequence-parallel enabled. Other data and model configurations for different mask types are the same as in the table in Kernel-Level Experiments.
Therefore, in every training setting, each rank is consistently assigned seqlen=64k, num_heads_q = 8 and num_heads_k = 1 for attention propagation, while the remaining activations keep seqlen=8k, num_heads_q = 64 and num_heads_k = 8 with SP enabled. This setup simulates a common training configuration.
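As a quick sanity check of these numbers (illustrative arithmetic only), for the largest configuration:

```python
# seqlen = 3072k tokens, cp_size = 48, tp_size = 8, GQA heads 64:8 (q:k).
seqlen, cp_size, tp_size = 3072 * 1024, 48, 8

per_rank_attn_seqlen = seqlen // cp_size       # 65536 tokens = 64k for attention (CP sharding)
per_rank_heads_q = 64 // tp_size               # 8 query heads (TP sharding of heads)
per_rank_heads_k = 8 // tp_size                # 1 key head

per_rank_other_seqlen = per_rank_attn_seqlen // tp_size  # 8k for the remaining activations (SP sharding)
print(per_rank_attn_seqlen, per_rank_heads_q, per_rank_heads_k, per_rank_other_seqlen)
```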
Some of the results are presented in the following figures; see more in our blog.
As demonstrated, MagiAttention exhibits linear scalability as the context length and CP size increase, in both full mask and varlen full mask configurations, for both forward and backward passes. In contrast, baseline methods either face strict limitations in scaling up or experience performance degradation with ultra-long contexts, which worsens with varlen mask patterns.
We welcome and value any contributions and collaborations. Please check out CONTRIBUTING.md for how to get involved.
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
If you use MagiAttention in your research, please cite:
```bibtex
@misc{magiattention2025,
  title={MagiAttention: A Distributed Attention Towards Linear Scalability for Ultra-Long Context, Heterogeneous Mask Training},
  author={Tao, Zewei and Huang, Yunpeng},
  year={2025},
  howpublished={\url{https://github.com/SandAI-org/MagiAttention/}},
}
```

We are grateful to the contributors listed below for their valuable contributions during the early stages of MagiAttention.
| Member | Affiliations | Email | GitHub Account |
|---|---|---|---|
| Zewei Tao | SandAI | zeweitao@sand.ai | littsk |
| Yunpeng Huang | SandAI, Nanjing University | yunpenghuang@sand.ai,hyp@smail.nju.edu.cn | Strivin0311 |
| Qiangang Wang | Nanjing University | 522024330081@smail.nju.edu.cn | WT1W |
| Hanwen Sun | SandAI, Peking University | sunhanwen@stu.pku.edu.cn | hanwen-sun |
| Tao Bu | Nanjing University | 502024330002@smail.nju.edu.cn | Big-TRex |
| WenYang Fang | Nanjing University | fwy@smail.nju.edu.cn | kagami4243 |
| Siyuan Yan | Nanjing University | siyuanyan@smail.nju.edu.cn | FibonaccciYan |
| Zixu Jiang | Nanjing University | 522023330040@smail.nju.edu.cn | 191220042 |
| Dingkun Xu | Nanjing University | 211220090@smail.nju.edu.cn | PureDimension |
| Mingyu Liang | Nanjing University | mingyuliang518@gmail.com | gaomusiki |
| Jingwei Xu | Nanjing University | jingweix@nju.edu.cn | paragonlight |
