Merge 35f45ffd36 into edd65a84dd (commit ddaab46e27)
@@ -406,13 +406,18 @@ def _rescale_out_lse(out, block_out, lse, block_lse):
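For context, `_rescale_out_lse` (named in the hunk header) folds each newly received KV block's partial attention output into the running output using the blocks' log-sum-exp (LSE) statistics. A minimal sketch of that merge, with assumed tensor shapes and a hypothetical name, not necessarily matching the real function's in-place behaviour:

import torch

def rescale_out_lse_sketch(out, block_out, lse, block_lse):
    # Assumed shapes: out/block_out (B, H, Sq, D), lse/block_lse (B, H, Sq, 1),
    # where lse is the per-query log-sum-exp of the attention logits seen so far.
    new_lse = torch.logaddexp(lse, block_lse)  # log(exp(lse) + exp(block_lse)), computed stably
    # Reweight each partial output by its share of the combined softmax mass.
    out = out * torch.exp(lse - new_lse) + block_out * torch.exp(block_lse - new_lse)
    return out, new_lse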
class RingAttention(torch.autograd.Function):
    """Implements Ring Attention from `Ring Attention with Blockwise Transformers for Near-Infinite Context`
    (https://arxiv.org/abs/2310.01889).
    For load-balancing, we adopt the "zigzag" dataloading scheme from ring-flash-attention
    (https://github.com/zhuzilin/ring-flash-attention/tree/main).
    For portable integration with more models, we don't follow the spirit of "blockwise FFN" in the original paper,
    which requires fusing the FFN with the Flash Attention kernel/function (see https://arxiv.org/pdf/2305.19370;
    implemented in JAX and not optimized).
    We also adopt the double ring topology from LoongTrain (https://arxiv.org/pdf/2406.18485) to fully utilize the
    NICs on each node, by computing attention within an inner ring first and then sending all KVs to the next
    ring at once.

    Our implementation references code from:
    - ring-flash-attention: https://github.com/zhuzilin/ring-flash-attention/tree/main
    - Megatron Context Parallel: https://github.com/NVIDIA/TransformerEngine/pull/726

    References:
    - Ring Attention with Blockwise Transformers for Near-Infinite Context
      https://arxiv.org/abs/2310.01889
    - LoongTrain: Efficient Training of Long-Sequence LLMs with Head-Context Parallelism
      https://arxiv.org/abs/2406.18485
    """

    # Global cache to avoid recomputation for sequences of the same length
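The double ring schedule described in the docstring rotates KV shards inside an inner (intra-node) ring first and only then shifts all of them to the next inner ring in a single cross-node send, so that transfer can spread over every NIC on the node. A runnable sketch of the resulting per-rank schedule; the ring orientation and rank grouping here are assumptions, not the exact ColossalAI layout:

def double_ring_schedule(rank: int, world_size: int, inner_ring_size: int) -> list[int]:
    # Ranks are grouped into contiguous inner rings of `inner_ring_size`,
    # assumed to correspond to one node each.
    ring_id, local = divmod(rank, inner_ring_size)
    num_rings = world_size // inner_ring_size
    order = []
    for outer in range(num_rings):
        src_ring = (ring_id - outer) % num_rings           # which node's KVs we currently hold
        for inner in range(inner_ring_size):
            src_local = (local - inner) % inner_ring_size  # neighbour-to-neighbour within the ring
            order.append(src_ring * inner_ring_size + src_local)
    return order

# world_size=8, inner_ring_size=4: rank 5 attends to KV from [5, 4, 7, 6, 1, 0, 3, 2],
# i.e. four intra-node steps, one bulk cross-node shift, then four more intra-node steps.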