Mirror of https://github.com/hpcaitech/ColossalAI.git
Merge 35f45ffd36 into edd65a84dd

commit ddaab46e27
@@ -406,13 +406,18 @@ def _rescale_out_lse(out, block_out, lse, block_lse):
 class RingAttention(torch.autograd.Function):
     """Implements the Ring Attention from `Ring Attention with Blockwise Transformers for Near-Infinite Context`
     (https://arxiv.org/abs/2310.01889).
-    For load-balancing we adopted the "zigzag" attention scheme from https://github.com/zhuzilin/ring-flash-attention/tree/main
-    For portable integration with more models, we don't follow the spirit of "block-wise FNN" in the original paper,
-    which requires fusing FFN with the Flash Attention kernel/function (see https://arxiv.org/pdf/2305.19370;
-    implemented in Jax and not optimized).
-    We adopt the double ring topology from LoongTrain (https://arxiv.org/pdf/2406.18485) to fully utilize available
+    For load-balancing, we adopted the "zigzag" dataloading scheme from ring-flash-attention.
+    We also adopt the double ring topology from LoongTrain to fully utilize available
     NICs on each node, by computing attention within an inner ring first and then sending all KVs to the next
     ring at once.
+    Our implementation references code from
+    - ring-flash-attention: https://github.com/zhuzilin/ring-flash-attention/tree/main
+    - Megatron Context Parallel: https://github.com/NVIDIA/TransformerEngine/pull/726
+    References:
+    - Ring Attention with Blockwise Transformers for Near-Infinite Context
+      https://arxiv.org/abs/2310.01889
+    - LoongTrain: Efficient Training of Long-Sequence LLMs with Head-Context Parallelism
+      https://arxiv.org/abs/2406.18485
     """

     # Global cache to avoid recomputation for same-length sequences
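For readers unfamiliar with the "zigzag" scheme the new docstring text refers to, here is a minimal sketch of the idea, assuming the convention used by ring-flash-attention: the sequence is cut into 2 * world_size chunks and rank i keeps chunks i and (2 * world_size - 1 - i), so each rank pairs a cheap early chunk with an expensive late chunk under the causal mask. The helper name zigzag_split is hypothetical, not part of the repository's API.

import torch

# Hypothetical helper illustrating zigzag sharding for causal ring attention.
# Each rank keeps one early chunk and its mirror-image late chunk.
def zigzag_split(x: torch.Tensor, rank: int, world_size: int, seq_dim: int = 1) -> torch.Tensor:
    chunks = x.chunk(2 * world_size, dim=seq_dim)
    return torch.cat([chunks[rank], chunks[2 * world_size - 1 - rank]], dim=seq_dim)

# With world_size = 4, rank 0 holds chunks (0, 7), rank 1 holds (1, 6), and so on.
x = torch.arange(16).view(1, 16, 1)                                 # (batch, seq_len, hidden)
print(zigzag_split(x, rank=0, world_size=4).flatten().tolist())     # [0, 1, 14, 15]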
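The double ring topology mentioned in the docstring can be pictured with a small peer-layout sketch. Everything below is an assumption for illustration (ranks on the same node grouped into one inner ring, and the helper name double_ring_peers): KV blocks rotate to the next rank inside an inner ring on every inner step, and a rank ships its full KV set to the same position in the next inner ring once per outer step, so cross-node NICs carry fewer, larger transfers.

# Hypothetical sketch of the double-ring peer layout; not the repository's actual code.
def double_ring_peers(rank: int, world_size: int, inner_ring_size: int):
    ring_id = rank // inner_ring_size                       # which inner ring this rank is in
    local = rank % inner_ring_size                          # position inside the inner ring
    inner_next = ring_id * inner_ring_size + (local + 1) % inner_ring_size
    outer_next = (rank + inner_ring_size) % world_size      # same position, next inner ring
    return inner_next, outer_next

# Example: with world_size = 8 and inner_ring_size = 4, rank 1 rotates KV to rank 2
# on every inner step and sends its full KV set to rank 5 once per outer step.
print(double_ring_peers(rank=1, world_size=8, inner_ring_size=4))   # (2, 5)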
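The hunk header also names _rescale_out_lse(out, block_out, lse, block_lse), the merge step that folds each ring step's partial attention output into a running accumulator. The sketch below shows the standard log-sum-exp form of that merge; the tensor shapes and the exact in-place behaviour of the real function are assumptions, not taken from the repository.

import torch

# Sketch of an out/LSE merge in the spirit of _rescale_out_lse.
# Assumed shapes: out, block_out are (B, H, S, D); lse, block_lse are (B, H, S, 1),
# holding the log-sum-exp of the softmax denominator accumulated so far.
def rescale_out_lse(out, block_out, lse, block_lse):
    new_lse = torch.logaddexp(lse, block_lse)               # combine softmax normalizers
    out = out * (lse - new_lse).exp() + block_out * (block_lse - new_lse).exp()
    return out, new_lse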