commit ddaab46e27
Author: Wenxuan Tan, committed by GitHub
Date:   2025-07-15 14:29:39 +08:00


@@ -406,13 +406,18 @@ def _rescale_out_lse(out, block_out, lse, block_lse):
 class RingAttention(torch.autograd.Function):
     """Implements the Ring Attention from `Ring Attention with Blockwise Transformers for Near-Infinite Context`
     (https://arxiv.org/abs/2310.01889).
-    For load-balancing we adopted the "zigzag" attention scheme from https://github.com/zhuzilin/ring-flash-attention/tree/main
-    For portable integration with more models, we don't follow the spirit of "block-wise FNN" in the original paper,
-    which requires fusing FFN with the Flash Attention kernel/function (see https://arxiv.org/pdf/2305.19370;
-    implemented in Jax and not optimized).
-    We adopt the double ring topology from LoongTrain (https://arxiv.org/pdf/2406.18485) to fully utilize available
+    For load-balancing, we adopted the "zigzag" dataloading scheme from ring-flash-attention.
+    We also adopt the double ring topology from LoongTrain to fully utilize available
     NICs on each node, by computing attention within an inner ring first and then sending all KVs to the next
     ring at once.
+    Our implementation references code from
+    - ring-flash-attention: https://github.com/zhuzilin/ring-flash-attention/tree/main
+    - Megatron Context Parallel: https://github.com/NVIDIA/TransformerEngine/pull/726
+    References:
+    - Ring Attention with Blockwise Transformers for Near-Infinite Context
+      https://arxiv.org/abs/2310.01889
+    - LoongTrain: Efficient Training of Long-Sequence LLMs with Head-Context Parallelism
+      https://arxiv.org/abs/2406.18485
     """
     # Global cache to avoid recomputation for same-length sequences
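
The hunk sits just below `_rescale_out_lse(out, block_out, lse, block_lse)`, the helper that folds each incoming KV block's partial attention result into the running output. Below is a minimal, illustrative sketch of that log-sum-exp merge; the function body and tensor shapes are assumptions for clarity, not ColossalAI's actual implementation.

import torch

def rescale_out_lse(out, block_out, lse, block_lse):
    # Illustrative sketch, not the repository's code: `out`/`block_out` are
    # partial softmax-attention outputs and `lse`/`block_lse` their per-query
    # log-sum-exp normalizers (broadcastable against the head dimension).
    new_lse = torch.logaddexp(lse, block_lse)  # log(exp(lse) + exp(block_lse)), computed stably
    out = out * torch.exp(lse - new_lse) + block_out * torch.exp(block_lse - new_lse)
    return out, new_lse

# Accumulating two blocks; starting from lse = -inf makes the first merge copy the first block.
out = torch.zeros(1, 2, 4, 8)                  # (batch, heads, seq, head_dim)
lse = torch.full((1, 2, 4, 1), float("-inf"))
for _ in range(2):
    block_out, block_lse = torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 1)
    out, lse = rescale_out_lse(out, block_out, lse, block_lse)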
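
The "zigzag" dataloading mentioned in the docstring balances causal-attention work: the sequence is cut into 2 * world_size chunks, and each rank keeps one chunk from the front plus its mirror chunk from the back, so cheap early positions and expensive late positions are spread evenly across ranks. A sketch of the idea under assumed names and shapes, not the library's API:

import torch

def zigzag_shard(seq: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Illustrative sketch: rank r keeps chunks r and (2 * world_size - 1 - r)
    # along the sequence dimension (dim 1).
    chunks = seq.chunk(2 * world_size, dim=1)
    return torch.cat([chunks[rank], chunks[2 * world_size - 1 - rank]], dim=1)

# 4 ranks, 16 tokens: rank 0 holds positions [0, 1, 14, 15], rank 3 holds [6, 7, 8, 9].
tokens = torch.arange(16).reshape(1, 16, 1)
print(zigzag_shard(tokens, rank=0, world_size=4).flatten().tolist())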
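
The double-ring topology keeps most KV passing inside an inner ring (for example, the GPUs of one node) and forwards KVs between inner rings only once per outer step, which is what "computing attention within an inner ring first and then sending all KVs to the next ring at once" refers to. The schedule below is a hand-rolled sketch of that ordering; the function name and the flat-list representation are assumptions, not ColossalAI's or LoongTrain's code:

def double_ring_schedule(world_size: int, inner_ring_size: int) -> dict:
    # Illustrative sketch: for each rank, list the source ranks whose KV blocks
    # it consumes, exhausting its own inner ring before moving to the next one.
    assert world_size % inner_ring_size == 0
    num_rings = world_size // inner_ring_size
    schedule = {}
    for rank in range(world_size):
        ring, pos = divmod(rank, inner_ring_size)
        order = []
        for outer in range(num_rings):                # which inner ring the KVs originated from
            src_ring = (ring - outer) % num_rings
            for inner in range(inner_ring_size):      # rotation inside that inner ring
                src_pos = (pos - inner) % inner_ring_size
                order.append(src_ring * inner_ring_size + src_pos)
        schedule[rank] = order
    return schedule

# 8 GPUs in two inner rings of 4: rank 0 finishes its own ring (0, 3, 2, 1)
# before touching the other ring's KVs (4, 7, 6, 5).
print(double_ring_schedule(8, 4)[0])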