Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-16 22:52:25 +00:00
[Feature] Zigzag Ring attention (#5905)
* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* precision tests passed
* fix typos and remove misc files
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* add sp_mode to benchmark; fix varlen interface
* update softmax_lse shape by new interface
* change tester name
* remove buffer clone; support packed seq layout
* add varlen tests
* fix typo
* all tests passed
* add dkv_group; fix mask
* remove debug statements

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -2,7 +2,10 @@ from enum import Enum
from typing import Callable, Dict, Optional, Tuple

import torch
import torch.distributed
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange

from colossalai.kernel.kernel_loader import (
    FlashAttentionForFloatAndCustomMaskLoader,
@@ -10,12 +13,18 @@ from colossalai.kernel.kernel_loader import (
    FlashAttentionWithCustomMaskLoader,
    KernelLoader,
)
from colossalai.logging import get_dist_logger

from .utils import RingComm, get_half_index, split_varlen_zigzag

__all__ = [
    "AttnMaskType",
    "ColoAttention",
]

_flash_attn_forward = _flash_attn_backward = None
_unpad_input = _pad_input = None


class AttnMaskType(Enum):
    CUSTOM = 0
@@ -38,20 +47,32 @@ def invert_mask(mask: torch.Tensor) -> torch.Tensor:
# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.Tensor]:
def get_pad_info(
    padding_mask: torch.Tensor, invert: Optional[bool] = False, return_indices: Optional[bool] = True
) -> Tuple[int, torch.Tensor, torch.Tensor]:
    """Get padding information from padding mask.

    Args:
        padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, S]
        padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, Skv]
        invert (Optional[bool], optional): Whether to reverse the padding mask.
        return_indices (Optional[bool], optional): Whether to return the indices of non-masked tokens.

    Returns:
        Tuple[int, torch.Tensor, torch.Tensor]: Tuple of (max_seq_len, cu_seqlens, indices)
        max_seqlen_in_batch (int): Maximum sequence length in the batch.
        cu_seqlens (torch.Tensor): Shape [B+1]. Cumulative sequence lengths of the sequences in the batch.
        indices (torch.Tensor): Shape [total_nonzero]. The indices of non-masked tokens from the flattened input sequence.
    """
    if invert:
        padding_mask = padding_mask.logical_not()
    seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
    if return_indices:
        indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()

    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return max_seqlen_in_batch, cu_seqlens, indices
    if return_indices:
        return max_seqlen_in_batch, cu_seqlens, indices
    return max_seqlen_in_batch, cu_seqlens

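For reference, a minimal sketch of what the updated `get_pad_info` computes for a toy mask (the tensor values are illustrative, not from the PR's tests; the steps mirror the function body above):

```python
import torch
import torch.nn.functional as F

# Toy padding mask for a batch of 2 sequences padded to length 4:
# 1 = valid token, 0 = padding.
padding_mask = torch.tensor([[1, 1, 1, 0],
                             [1, 1, 0, 0]])

seqlens = padding_mask.sum(dim=-1, dtype=torch.int32)                     # tensor([3, 2])
max_seqlen = seqlens.max().item()                                         # 3
cu_seqlens = F.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))   # tensor([0, 3, 5])
indices = torch.nonzero(padding_mask.flatten()).flatten()                 # tensor([0, 1, 2, 4, 5])
```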
class ColoAttention:
@@ -107,6 +128,7 @@ class ColoAttention:
        q_padding_mask: Optional[torch.Tensor] = None,
        kv_padding_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
        invert: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Return a dictionary of keyword arguments for attention function. It supports 4 mask types.
        1. custom mask: no padding mask and is_causal=False, return {}, users should handle attention mask by themselves.
@@ -124,7 +146,7 @@ class ColoAttention:
                The shape should be [B, Skv]. ``1`` means valid token, and ``0`` means padding token.
                If it's None and ``q_padding_mask`` is not None, it will be set to ``q_padding_mask``. Defaults to None.
            is_causal (bool, optional): Whether to use causal attention mask. Defaults to False.
            invert_mask (bool, optional): Whether to invert the mask. Defaults to True.
        Returns:
            Dict[str, torch.Tensor]: Dictionary of keyword arguments for attention function.
        """
@@ -154,7 +176,7 @@ class ColoAttention:
            assert kv_padding_mask.shape == (
                b,
                s_kv,
            ), f"q_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
            ), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
            attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
            outputs.update(
                {
@@ -172,7 +194,8 @@ class ColoAttention:
                attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
            else:
                outputs["attention_mask_type"] = AttnMaskType.PADDED
        attention_mask = invert_mask(attention_mask).unsqueeze(1)
        if invert:
            attention_mask = invert_mask(attention_mask).unsqueeze(1)
        outputs["attention_mask"] = attention_mask
        return outputs

@@ -191,6 +214,7 @@ class ColoAttention:
        kv_indices: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        scale: Optional[float] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Flash Attention function. It supports 4 mask types.
        1. custom mask: recv attention_mask
@@ -199,9 +223,9 @@ class ColoAttention:
        4. padded causal mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, indices

        Args:
            q (torch.Tensor): Query tensor. Shape should be [B, N, Sq, D]
            k (torch.Tensor): Key tensor. Shape should be [B, N, Skv, D]
            v (torch.Tensor): Value tensor. Shape should be [B, N, Skv, D]
            q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D]
            k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Skv, D]
            v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Skv, D]
            attention_mask (Optional[torch.Tensor], optional): Attention mask tensor. Shape should be [B, 1, Sq, Skv]. Defaults to None.
            attention_mask_type (AttnMaskType, optional): Attention mask type. Defaults to AttnMaskType.CUSTOM.
            cu_seqlens_q (Optional[torch.Tensor], optional): The cumulative sequence lengths
@@ -218,7 +242,7 @@ class ColoAttention:
            scale (Optional[float], optional): Scaling factor applied prior to softmax. Defaults to None.

        Returns:
            torch.Tensor: Output tensor. Shape should be [B, N, Sq, D]
            torch.Tensor: Output tensor. Shape should be [B, nHeads, Sq, D]
        """
        # known issue: sdpa does not support attention mask which contains whole row of masked tokens, which leads to nan
        # this case is usual when padding mask is used and self attention is performed
@@ -252,6 +276,7 @@ class ColoAttention:
        else:
            # if attention_mask is None, attention_mask_type should be the default value
            assert attention_mask_type == AttnMaskType.CUSTOM

        # kernel dispatch
        mask_type = attention_mask_type if attention_mask is not None else None
        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type)
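For orientation, the new `invert` flag only controls whether the returned mask is pre-inverted into the additive form that the kernels expect. A rough sketch of the two-step API around these hunks, assuming the leading positional arguments of `ColoAttention.prepare_attn_kwargs` (4D output shape, dtype, device) and a CUDA build with a flash-attention kernel available; this is not a test from the PR:

```python
import torch
from colossalai.shardformer.layer.attn import ColoAttention  # module shown in this diff

B, H, Sq, D = 2, 4, 8, 16
q = k = v = torch.randn(B, H, Sq, D, dtype=torch.float16, device="cuda")
padding_mask = torch.ones(B, Sq, dtype=torch.int32, device="cuda")  # 1 = valid token

attn_kwargs = ColoAttention.prepare_attn_kwargs(
    (B, H, Sq, Sq), q.dtype, q.device,
    q_padding_mask=padding_mask,
    is_causal=True,
    invert=True,   # new flag: return the mask already inverted (additive form)
)
out = ColoAttention.attention(q, k, v, **attn_kwargs)  # (B, H, Sq, D)
```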
@@ -274,3 +299,858 @@ class ColoAttention:
            q_indices=q_indices,
            kv_indices=kv_indices,
        )


def _load_varlen_helpers():
    """Helper to load functions for padding and unpadding packed sequences.
    Use only when flash attn is installed
    """
    global _pad_input, _unpad_input
    # Flash attn claims this is more efficient than torch's bool indexing due to avoiding
    # broadcast
    if _pad_input is None or _unpad_input is None:
        try:
            from flash_attn.bert_padding import index_first_axis, pad_input

            def unpad_input(hidden_states: torch.Tensor, indices: torch.Tensor):
                return index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices)

            _pad_input = pad_input
            _unpad_input = unpad_input
        except ImportError as e:
            raise RuntimeError(
                f"Flash Attention is not installed. You can install it via 'pip install flash-attn --no-build-isolation'"
            ) from e


def _load_flash_attn():
    """A light-weight loader to check whether flash-attn is installed.
    Can't use ColoAttention._dispatch_kernel because we mutate the backward pass
    """
    global _flash_attn_forward, _flash_attn_backward
    if _flash_attn_forward is None or _flash_attn_backward is None:
        try:
            from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward
            from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward
        except ImportError as e:
            raise RuntimeError(
                f"Flash Attention is not installed. You can install it via 'pip install flash-attn --no-build-isolation'"
            ) from e

    _load_varlen_helpers()
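As a sanity check on what the loaded helpers do, here is a small sketch of the unpad/pad round trip (it requires flash-attn; `index_first_axis` and `pad_input` are the `flash_attn.bert_padding` helpers wrapped above, and the indices follow `get_pad_info`):

```python
import torch
from flash_attn.bert_padding import index_first_axis, pad_input

B, S, H, D = 2, 4, 2, 8
hidden = torch.randn(B, S, H, D)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

indices = torch.nonzero(mask.flatten()).flatten()              # [total_nonzero]
packed = index_first_axis(hidden.view(B * S, H, D), indices)   # (T, H, D) with T = 5
restored = pad_input(packed, indices, B, S)                    # back to (B, S, H, D), zeros at padding
assert torch.equal(restored[mask.bool()], hidden[mask.bool()])
```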
# NOTE: This can cause spawned processes to hang on exit
# with python 3.9
@torch.compile()
def _rescale_out_lse(out, block_out, lse, block_lse):
    """
    Compute the new attention denominator:
        exp(lse) + exp(block_lse) = exp(max_scale) * (exp(min_scale - max_scale) + 1)
    Args:
        out: (T, H, D)
        block_out: (T, H, D)
        lse: (H, T, 1)
        block_lse: (H, T, 1)
    """

    # min_scale = torch.min(lse, block_lse)
    # max_scale = torch.max(lse, block_lse)
    # new_lse = max_scale + torch.log(1 + torch.exp(min_scale - max_scale))

    # NOTE: directly assigning to .data here is buggy
    # probably due to casting dtypes/strides
    new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))

    new_block_lse = torch.exp(block_lse - new_lse)
    out = (torch.exp(lse - new_lse) * out + new_block_lse * block_out).to(out)
    lse = new_lse

    # Equivalent to the above
    # See https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
    # out = (out - F.sigmoid(block_lse - lse) * (out - block_out))
    # lse = (lse - F.logsigmoid(lse - block_lse))
    return out, lse

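The rescaling above is the standard log-sum-exp merge used by ring/flash attention: each block contributes a partial softmax-weighted output plus its logsumexp, and two blocks can be combined exactly. A small self-contained check (not part of the diff) that the merge reproduces a full softmax over the concatenated scores:

```python
import torch

torch.manual_seed(0)
scores1, scores2 = torch.randn(6), torch.randn(6)        # scores against two KV blocks
values1, values2 = torch.randn(6, 4), torch.randn(6, 4)

# Per-block partial results, as a flash-attn kernel would return them.
lse1, lse2 = scores1.logsumexp(0), scores2.logsumexp(0)
out1 = scores1.softmax(0) @ values1
out2 = scores2.softmax(0) @ values2

# Merge: new_lse = lse1 + log(1 + exp(lse2 - lse1)); rescale both partial outputs.
new_lse = lse1 + torch.log1p(torch.exp(lse2 - lse1))
merged = torch.exp(lse1 - new_lse) * out1 + torch.exp(lse2 - new_lse) * out2

# Reference: softmax over the concatenated scores.
ref = torch.cat([scores1, scores2]).softmax(0) @ torch.cat([values1, values2])
assert torch.allclose(merged, ref, atol=1e-5)
```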
class RingAttention(torch.autograd.Function):
    """Implements the Ring Attention from `Ring Attention with Blockwise Transformers for Near-Infinite Context`
    (https://arxiv.org/abs/2310.01889).
    For load-balancing we adopted the "zigzag" attention scheme from https://github.com/zhuzilin/ring-flash-attention/tree/main
    For portable integration with more models, we don't follow the spirit of "block-wise FFN" in the original paper,
    which requires fusing FFN with the Flash Attention kernel/function (see https://arxiv.org/pdf/2305.19370;
    implemented in Jax and not optimized).
    We adopt the double ring topology from LoongTrain (https://arxiv.org/pdf/2406.18485) to fully utilize available
    NICs on each node, by computing attention within an inner ring first and then sending all KVs to the next
    ring at once.
    """

    # Global cache to avoid recomputation for same-length sequences
    CU_SEQLENS: torch.Tensor = None  # [B+1]
    TOTAL_SEQLEN: int = None
    HALF_INDICES: Tuple = None
    SUPPORTED_MASK_TYPES = (AttnMaskType.CAUSAL, AttnMaskType.PADDED_CAUSAL)
    ATTN_DONE: torch.cuda.Event = None
    SP_STREAM: torch.cuda.Stream = None
    SP_GROUP: dist.ProcessGroup = None
    # duplicate process group for concurrent NCCL streams
    # while both PyTorch and NCCL warn (https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7)
    # against this, in practice it seems to work fine.
    INNER_RING_GROUP: dist.ProcessGroup = None
    INNER_RING_GROUP_COPY: dist.ProcessGroup = None
    INTER_RING_GROUP: dist.ProcessGroup = None
    INTER_RING_GROUP_COPY: dist.ProcessGroup = None

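For context, the zigzag layout referenced in the docstring cuts each sequence into `2 * sp_size` chunks and gives rank `r` one chunk from the front and its mirror chunk from the back, so every rank does a comparable amount of causal work. A minimal illustration (chunk indices only; not part of the diff):

```python
# Zigzag split of a sequence cut into 2 * sp_size chunks:
# rank r keeps chunks r and (2 * sp_size - 1 - r).
sp_size = 4
chunks = list(range(2 * sp_size))  # [0, 1, 2, 3, 4, 5, 6, 7]
assignment = {r: [chunks[r], chunks[2 * sp_size - 1 - r]] for r in range(sp_size)}
# {0: [0, 7], 1: [1, 6], 2: [2, 5], 3: [3, 4]}
```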
    @staticmethod
    def get_double_ring_groups(sp_group, inner_ring_size=None):
        """
        Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
        shouldn't be larger than the number of NICs on each node.
        Args:
            sp_group (dist.ProcessGroup): Process group for sequence parallelism
            inner_ring_size (Optional[int], optional): Inner ring size. Defaults to None.
        Returns:
            Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
        """
        sp_size = dist.get_world_size(sp_group)
        sp_rank = dist.get_rank(sp_group)

        if inner_ring_size is None:
            if torch.cuda.device_count() >= dist.get_world_size():
                # single node, no need to consider NICs
                return sp_group, sp_group
            if sp_size <= 4:
                inner_ring_size = min(2, sp_size)
            else:
                inner_ring_size = min(4, sp_size)
        else:
            assert (
                inner_ring_size <= sp_size and sp_size % inner_ring_size == 0
            ), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"

        if inner_ring_size == sp_size:
            return sp_group, sp_group
        assert (
            sp_size % inner_ring_size == 0
        ), f"sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"

        logger = get_dist_logger()
        logger.info(
            f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximize NIC util for inter-node comm. Cross your fingers for speed-ups!",
            ranks=[0],
        )
        num_rings = sp_size // inner_ring_size
        inner_ring_group = None
        inter_ring_group = None

        # Create inner ring groups
        for i in range(inner_ring_size):
            ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
            group = dist.new_group(ranks)
            if sp_rank in ranks:
                inner_ring_group = group

        # Create inter ring groups
        for i in range(num_rings):
            ranks = list(range(i, sp_size, num_rings))
            group = dist.new_group(ranks)
            if sp_rank in ranks:
                inter_ring_group = group

        return inner_ring_group, inter_ring_group

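To make the 2D topology concrete, here is the rank layout implied by the range expressions in the two loops above for sp_size = 8 and inner_ring_size = 4 (an illustration of the grouping only; it does not reproduce the `dist.new_group` calls):

```python
# Layout produced by the range expressions above for sp_size = 8, inner_ring_size = 4.
sp_size, inner_ring_size = 8, 4
num_rings = sp_size // inner_ring_size  # 2 rings
inner_rings = [list(range(i * inner_ring_size, (i + 1) * inner_ring_size)) for i in range(num_rings)]
inter_rings = [list(range(i, sp_size, num_rings)) for i in range(num_rings)]
# inner_rings == [[0, 1, 2, 3], [4, 5, 6, 7]]  -> KV first circulates inside each inner ring
# inter_rings == [[0, 2, 4, 6], [1, 3, 5, 7]]  -> then whole KV blocks hop between rings at once
```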
    @staticmethod
    def attention(
        q,  # (B, H, Sq, D)
        k,
        v,
        sp_group,
        attention_mask_type,
        cu_seqlens=None,
        max_seqlen=None,
        valid_indices=None,
        dropout_p=0.0,
        softmax_scale=None,
        deterministic=False,
        return_softmax=False,
        inner_ring_size=None,
        **kwargs,
    ):
        """
        Ring Attention forward pass supporting variable-length sequences. When using varlen mode,
        each sequence in the batch should have length divisible by sp_size * 2.

        Args:
            q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D]
            k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Skv, D]
            v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Skv, D]
            sp_group (Optional[dist.ProcessGroup]): Process group for sequence parallelism
            sp_stream (torch.cuda.Stream): A different stream for output correction.
            cu_seqlens (Optional[torch.Tensor], optional): The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
                Shape should be [B+1].
            max_seqlen (Optional[int], optional): Maximum query sequence length in the batch.
            valid_indices (Optional[torch.Tensor], optional): The indices of non-masked tokens from get_pad_info.
                Shape should be [t].
            dropout_p (float, optional): Dropout probability. Defaults to 0.0.
            softmax_scale (Optional[float], optional): Scaling factor applied prior to softmax.
            deterministic (bool, optional): Whether to force deterministic backward pass. See https://github.com/Dao-AILab/flash-attention/issues/349
            return_softmax (bool, optional): Whether to return the softmax denominator (logsumexp).
            inner_ring_size (Optional[int], optional): Inner ring size of the 2D ring. By default use a heuristic to decide.

        Returns:
            out: Output tensor of shape [B, nHeads, Sq, D] or [T, nHeads, D] if pad_output is False.
            softmax_lse: (if return_softmax is True) Softmax denominator (logsumexp).
                Shape should be [total_q_seqlen, nHeads]
        """
        # Check input args
        _load_flash_attn()
        if RingAttention.ATTN_DONE is None:
            RingAttention.ATTN_DONE = torch.cuda.Event()
        if RingAttention.SP_STREAM is None:
            RingAttention.SP_STREAM = torch.cuda.Stream()

        assert (
            q.shape[2] == k.shape[2]
        ), "Q, K and V having different sequence lengths (inference or cross-attn)\
            is not supported yet in training."
        assert (
            attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES
        ), f"Mask type {attention_mask_type} is not supported yet."

        clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg))

        if RingAttention.SP_GROUP is not sp_group:
            RingAttention.SP_GROUP = sp_group
            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size)
            RingAttention.INNER_RING_GROUP = inner_ring_group
            RingAttention.INTER_RING_GROUP = inter_ring_group
        else:
            inner_ring_group = RingAttention.INNER_RING_GROUP
            inter_ring_group = RingAttention.INTER_RING_GROUP

        # (B, H, Sq, D) -> (B, Sq, H, D)
        q, k, v = [x.transpose(1, 2).contiguous() for x in (q, k, v)]
        pad_output = q.dim() == 4

        # Get sequence length info for varlen forward
        if attention_mask_type == AttnMaskType.CAUSAL:
            # All sequences share the same length
            b, sq, h, d = q.shape
            max_seqlen = sq
            # Cache to avoid recreation for a single sequence
            if sq * b == RingAttention.TOTAL_SEQLEN:
                cu_seqlens = RingAttention.CU_SEQLENS
            else:
                cu_seqlens = torch.arange(0, b * sq + 1, sq, device=q.device, dtype=torch.int32)
                RingAttention.TOTAL_SEQLEN = b * sq

        # "Packed" mode where sequences of different lengths are packed into [total_q_seqlen, H, D]
        elif attention_mask_type == AttnMaskType.PADDED_CAUSAL:
            assert (
                cu_seqlens is not None and max_seqlen is not None and valid_indices is not None
            ), "Packed mode requires pre-computed cu_seqlens and max_seq_len."
            if pad_output:
                b, sq, h, d = q.shape
                q, k, v = [_unpad_input(x, valid_indices) for x in (q, k, v)]

        out, softmax_lse = RingAttention.apply(
            q,
            k,
            v,
            sp_group,
            RingAttention.SP_STREAM,
            cu_seqlens,
            max_seqlen,
            dropout_p,
            softmax_scale,
            deterministic,
            return_softmax,
            attention_mask_type == AttnMaskType.PADDED_CAUSAL,
            inner_ring_group,
            inter_ring_group,
        )

        if attention_mask_type == AttnMaskType.PADDED_CAUSAL:
            if pad_output:
                out = _pad_input(out, valid_indices, b, sq)  # (T, ...) -> (B, Sq, ...)
                out = out.transpose(1, 2)  # (B, Sq, H, D) -> (B, H, Sq, D)
        else:
            out = out.transpose(1, 2)

        if return_softmax:
            return out, softmax_lse
        return out

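Taken together, a typical call looks roughly like the sketch below. This is hedged: it assumes an initialized `torch.distributed` setup with `sp_group` holding the sequence-parallel ranks and q/k/v already split sequence-wise per rank, and it is not a test from this PR:

```python
# Illustrative only; requires a multi-GPU distributed launch.
from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention

def ring_attn_step(q, k, v, sp_group):
    # q, k, v: (B, nHeads, Sq_local, D) fp16/bf16 CUDA tensors holding this rank's shard.
    return RingAttention.attention(
        q, k, v,
        sp_group,
        AttnMaskType.CAUSAL,  # or PADDED_CAUSAL with cu_seqlens / max_seqlen / valid_indices
        return_softmax=False,
    )
```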
    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        sp_group: dist.ProcessGroup,
        sp_stream: torch.cuda.Stream,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
        dropout_p: float = 0.0,
        softmax_scale: Optional[float] = None,
        deterministic: Optional[bool] = False,
        return_softmax: Optional[bool] = False,
        is_packed: Optional[bool] = False,
        inner_ring_group: Optional[dist.ProcessGroup] = None,
        inter_ring_group: Optional[dist.ProcessGroup] = None,
    ):

        cu_seqlens_q = cu_seqlens_kv = cu_seqlens
        max_seqlen_q = max_seqlen_kv = max_seqlen
        cu_seqlens_half = cu_seqlens // 2
        max_seqlen_half = max_seqlen // 2

        misc_kwargs = {
            "window_size": (-1, -1),
            "alibi_slopes": None,
            "softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale,
            "dropout_p": dropout_p,
            "block_table": None,
            "softcap": 0.0,
            "return_softmax": False,
        }

        if (
            RingAttention.HALF_INDICES is not None
            and cu_seqlens.shape == RingAttention.CU_SEQLENS.shape
            and (cu_seqlens == RingAttention.CU_SEQLENS).all()
        ):
            half_idx_front, half_idx_back = RingAttention.HALF_INDICES
        else:
            half_idx_front = get_half_index(cu_seqlens, front=True)
            half_idx_back = get_half_index(cu_seqlens, front=False)
            RingAttention.HALF_INDICES = (half_idx_front, half_idx_back)
            RingAttention.CU_SEQLENS = cu_seqlens

        if is_packed:
            t, h, d = q.shape
        else:
            b, sq, h, d = q.shape
            t = b * sq
            # Be careful about GQA/MQA in reshape
            q, k, v = [x.view(t, *x.shape[-2:]) for x in (q, k, v)]

        if inner_ring_group is None or inter_ring_group is None:
            # Use one ring if not specified
            inner_ring_group = inter_ring_group = sp_group

        sp_size = dist.get_world_size(sp_group)
        sp_rank = dist.get_rank(sp_group)
        # Attempt to achieve concurrent comm in the two-stream forward
        local_kv_comms = [RingComm(inner_ring_group) for _ in range(2)]
        inter_ring_comm = RingComm(inter_ring_group)
        local_sp_size = dist.get_world_size(inner_ring_group)
        local_sp_rank = dist.get_rank(inner_ring_group)
        inter_ring_rank = dist.get_rank(inter_ring_group) if inter_ring_group is not sp_group else 0
        num_rings = dist.get_world_size(inter_ring_group) if inter_ring_group is not sp_group else 1

        # Non-contiguous indexing copies to a new contiguous tensor,
        # so only do it once
        if sp_rank != sp_size - 1:
            q1 = q[half_idx_back]

        # Pre-allocate double buffer for overlapping and receiving next step's inputs
        kv_buffers = [torch.stack((k, v))]  # (2, B, Sq, H, D)
        kv_buffers.append(torch.empty_like(kv_buffers[0]))

        # outputs
        out = None
        block_out = [None, None]
        softmax_lse = [None, None]
        block_softmax_lse = [None, None]  # log sum exp, the denominator of softmax in attention
        rng_states = [None for _ in range(sp_size)]
        sp_streams = [torch.cuda.current_stream(), sp_stream]

        def _forward(q, k, v, causal):
            (
                _,
                _,
                _,
                _,
                out,
                softmax_lse,
                _,
                rng_state,
            ) = _flash_attn_forward(
                q,
                k,
                v,
                cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
                cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
                max_seqlen_q if q.shape[0] == t else max_seqlen_half,
                max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
                causal=causal,
                **misc_kwargs,
            )
            return out, softmax_lse, rng_state

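The loops that follow all use the same double-buffer pattern: compute on `kv_buffers[i % 2]` while the next step's KV lands in `kv_buffers[(i + 1) % 2]`. A runnable toy skeleton of that ping-pong, with plain Python stand-ins for `RingComm` and the attention call (illustrative only):

```python
# Toy double-buffer ("ping-pong") pattern: attend with one buffer while the
# other receives the next rank's KV. The real code uses RingComm + CUDA streams.
local_sp_size = 4
kv_buffers = ["kv_from_rank_0", None]  # buffer 0: current KV, buffer 1: scratch
received = [f"kv_from_rank_{r}" for r in range(1, local_sp_size)]

for i in range(local_sp_size):
    if i < local_sp_size - 1:
        kv_buffers[(i + 1) % 2] = received[i]  # stands in for comm.send_recv(...)
    print(f"step {i}: attend with {kv_buffers[i % 2]}")  # overlaps with the transfer in the real code
```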
        def _local_ring_forward():
            # (Hopefully) overlap output correction with next flash attn
            for i in range(local_sp_size):
                with torch.cuda.stream(sp_streams[i % 2]):
                    # Wait for current kv from prev rank
                    # NOTE: waiting outside the current stream will NOT correctly synchronize.
                    if i > 0:
                        local_kv_comms[(i + 1) % 2].wait()

                    # Avoid overwriting attn input when it shares mem with buffer
                    if not RingAttention.ATTN_DONE.query():
                        kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
                    if i < local_sp_size - 1:
                        local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])

                    if i == 0:
                        # Compute with local KV; no mask
                        kv_block = kv_buffers[0]
                        q_block = q
                        (block_out[i % 2], block_softmax_lse[i % 2], rng_states[i]) = _forward(  # (T, H, D)  # (H, T)
                            q_block, kv_block[0], kv_block[1], causal=True
                        )
                    elif i <= local_sp_rank:
                        # Received the "surrounding" kv chunks
                        # Drop the second half of received kv
                        # (2, t // 2, H, D)
                        kv_block = kv_buffers[i % 2][:, half_idx_front]
                        q_block = q
                        (
                            block_out[i % 2],  # (T, H, D)
                            block_softmax_lse[i % 2],  # (H, T)
                            rng_states[i],
                        ) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
                    else:
                        # Received the inner kv chunks
                        # Drop the first half of q
                        kv_block = kv_buffers[i % 2]
                        q_block = q1
                        (
                            block_out[i % 2],  # (T, H, D)
                            block_softmax_lse[i % 2],  # (H, T)
                            rng_states[i],
                        ) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
                    RingAttention.ATTN_DONE.record()

                    block_softmax_lse[i % 2] = (
                        block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                    )  # (H, T) -> (T, H, 1)
                    assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1]
                    # Output and log sum exp correction. Ideally overlap this with the next flash attn kernel.
                    # In reality this always finishes before next flash attn; no need for extra sync.
                    if i == 0:
                        out = block_out[0]
                        softmax_lse = block_softmax_lse[0]
                    elif i <= local_sp_rank:
                        out, softmax_lse = _rescale_out_lse(
                            out, block_out[i % 2], softmax_lse, block_softmax_lse[i % 2]
                        )
                    else:
                        out[half_idx_back], softmax_lse[half_idx_back] = _rescale_out_lse(
                            out[half_idx_back], block_out[i % 2], softmax_lse[half_idx_back], block_softmax_lse[i % 2]
                        )

            torch.cuda.current_stream().wait_stream(sp_stream)
            return out, softmax_lse

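To summarize the branchy schedule above: with the zigzag layout, step 0 is a full causal pass on the local KV, steps 1..local_sp_rank attend the full local Q against the first half of the received KV, and later steps attend only the second half of Q against the full received KV. A small runnable enumeration of that schedule (illustrative, for local_sp_size = 4):

```python
# Which Q/KV halves each step of _local_ring_forward uses (mirrors the branches above).
local_sp_size = 4
for local_sp_rank in range(local_sp_size):
    schedule = []
    for i in range(local_sp_size):
        if i == 0:
            schedule.append("full Q vs full local KV (causal)")
        elif i <= local_sp_rank:
            schedule.append("full Q vs first half of recv KV")
        else:
            schedule.append("second half of Q vs full recv KV")
    print(f"rank {local_sp_rank}: {schedule}")
```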
        def _other_ring_forward(ring_num_idx, out, softmax_lse):
            # Loop through the inner ring after receiving
            # all new KVs from the previous inner ring
            for i in range(local_sp_size):
                with torch.cuda.stream(sp_streams[i % 2]):
                    if not RingAttention.ATTN_DONE.query():
                        kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
                    if i < local_sp_size - 1:
                        local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])

                    # Send & recv KV
                    if i > 0:
                        local_kv_comms[(i + 1) % 2].wait()

                    if ring_num_idx > inter_ring_rank:
                        kv_block = kv_buffers[i % 2]
                        (
                            block_out[i % 2],
                            block_softmax_lse[i % 2],
                            rng_states[i + local_sp_size * ring_num_idx],
                        ) = _forward(q1, kv_block[0], kv_block[1], causal=False)
                        RingAttention.ATTN_DONE.record()
                        block_softmax_lse[i % 2] = (
                            block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                        )
                        out[half_idx_back], softmax_lse[half_idx_back] = _rescale_out_lse(
                            out[half_idx_back], block_out[i % 2], softmax_lse[half_idx_back], block_softmax_lse[i % 2]
                        )
                    else:
                        kv_block = kv_buffers[i % 2][:, half_idx_front]
                        (
                            block_out[i % 2],
                            block_softmax_lse[i % 2],
                            rng_states[i + local_sp_size * ring_num_idx],
                        ) = _forward(q, kv_block[0], kv_block[1], causal=False)
                        RingAttention.ATTN_DONE.record()
                        block_softmax_lse[i % 2] = (
                            block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                        )
                        out, softmax_lse = _rescale_out_lse(
                            out, block_out[i % 2], softmax_lse, block_softmax_lse[i % 2]
                        )

            torch.cuda.current_stream().wait_stream(sp_stream)
            return out, softmax_lse

        # Send and recv KV between rings at once to maximize NIC util.
        inter_ring_kv = None
        for ring_num_idx in range(num_rings):
            if ring_num_idx > 0:
                inter_ring_comm.wait()
                # Reset indices
                kv_buffers[0] = inter_ring_kv

            if ring_num_idx < num_rings - 1:
                if ring_num_idx == 0:
                    to_send = kv_buffers[0]
                else:
                    # The last received KV
                    to_send = kv_buffers[(local_sp_size - 1) % 2]
                inter_ring_kv = inter_ring_comm.send_recv(to_send)

            if ring_num_idx == 0:
                out, softmax_lse = _local_ring_forward()
            else:
                out, softmax_lse = _other_ring_forward(ring_num_idx, out, softmax_lse)

        out = out.to(q.dtype)
        if not is_packed:
            out = out.view(b, sq, h, d)
            q, k, v = [x.view(b, sq, *x.shape[-2:]) for x in (q, k, v)]  # (T, H, D) -> (B, Sq, H, D)
        softmax_lse = softmax_lse.squeeze(-1)

        ctx.sp_group = sp_group
        ctx.max_seqlen_q = ctx.max_seqlen_kv = max_seqlen
        misc_kwargs["deterministic"] = deterministic
        del misc_kwargs["return_softmax"]
        ctx.misc_kwargs = misc_kwargs
        ctx.is_packed = is_packed

        ctx.kv_group = inner_ring_group
        ctx.inter_kv_group = inter_ring_group

        ctx.save_for_backward(
            q,
            k,
            v,
            out,
            softmax_lse.transpose(0, 1).contiguous(),  # (T, H) -> (H, T)
            cu_seqlens_q,
            cu_seqlens_kv,
            half_idx_front,
            half_idx_back,
            *rng_states,
        )

        if return_softmax:
            return out, softmax_lse
        return out, None

    def backward(ctx, dout, _):
        """
        During backward, we accumulate q grads on each rank locally, but iterate kv and their grads
        over all ranks for accumulation.
        """
        (q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_kv, half_idx_front, half_idx_back) = ctx.saved_tensors[:9]
        rng_states = ctx.saved_tensors[9:]

        is_packed = ctx.is_packed
        max_seqlen_q = ctx.max_seqlen_q
        max_seqlen_kv = ctx.max_seqlen_kv
        cu_seqlens_half = cu_seqlens_q // 2
        max_seqlen_half = max_seqlen_q // 2
        misc_kwargs = ctx.misc_kwargs
        del misc_kwargs["block_table"]

        assert (
            out.shape == dout.shape == q.shape
        ), f"out {out.shape} and dout {dout.shape} should have the same shape ({q.shape})."

        if is_packed:
            t, h, d = q.shape
        else:
            b, sq, h, d = q.shape
            t = b * sq
        q, k, v, out, dout = [x.view(t, *x.shape[-2:]) for x in (q, k, v, out, dout)]

        # Sequence parallel args
        sp_group = ctx.sp_group
        local_kv_group = ctx.kv_group
        inter_kv_group = ctx.inter_kv_group

        local_sp_rank = dist.get_rank(sp_group)
        sp_size = dist.get_world_size(sp_group)
        # Using separate streams (pg) for concurrent kv and dkv comm may
        # cause NCCL "software caused connection abort" here...
        local_kv_comm = RingComm(local_kv_group)
        local_dkv_comm = RingComm(local_kv_group)
        inter_kv_comm = RingComm(inter_kv_group)
        inter_dkv_comm = RingComm(inter_kv_group)
        local_sp_size = dist.get_world_size(local_kv_group)
        local_sp_rank = dist.get_rank(local_kv_group)

        if dist.get_world_size(inter_kv_group) != sp_size:
            num_rings = dist.get_world_size(inter_kv_group)
            inter_ring_rank = dist.get_rank(inter_kv_group)
        else:
            num_rings = 1
            inter_ring_rank = 0

        if local_sp_rank != sp_size - 1:
            softmax_lse1 = softmax_lse[:, half_idx_back]
        dout = dout.contiguous()

        # Double comm buffers for sending and receiving kv
        kv_buffers = [torch.stack((k, v))]  # (2, T, H, D)
        kv_buffers.append(torch.empty_like(kv_buffers[0]))

        dq = None  # (T, H, D)
        # Intermediate outputs
        dq_block = torch.empty_like(q)  # (T, H, D)
        dk_block = torch.empty_like(k)  # (T, H, D)
        dv_block = torch.empty_like(v)  # (T, H, D)
        dkv_buffers = [torch.empty_like(kv, dtype=torch.float32) for kv in kv_buffers]  # (T, H, D)
        del k, v

        def _backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, rng_state, causal):
            _flash_attn_backward(
                dout,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                cu_seqlens_q if dq.shape[0] == t else cu_seqlens_half,
                cu_seqlens_kv if dk.shape[0] == t else cu_seqlens_half,
                max_seqlen_q if dq.shape[0] == t else max_seqlen_half,
                max_seqlen_kv if dk.shape[0] == t else max_seqlen_half,
                causal=causal,
                rng_state=rng_state,
                **misc_kwargs,
            )

        # NOTE: We avoid using two streams due to doubled buffers
        # and that backward is more communication intensive.
        def _local_ring_backward():
            for i in range(local_sp_size):
                if i > 0:
                    local_kv_comm.wait()

                if i < local_sp_size - 1:
                    # Send kv to next rank for backward
                    local_kv_comm.send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])

                if i == 0:
                    # Backward with local kv
                    k_, v_ = kv_buffers[i % 2]
                    q_, dout_, out_ = q, dout, out
                    dq_, dk_, dv_ = dq_block, dk_block, dv_block
                    _backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_states[i], causal=True)

                elif i <= local_sp_rank:
                    # Drop the second half of kv
                    # (T, H, D) -> (T // 2, H, D)
                    k_, v_ = [x[half_idx_front] for x in kv_buffers[i % 2]]
                    dk_, dv_ = [x[: t // 2] for x in (dk_block, dv_block)]
                    dq_, q_, out_, dout_ = (dq_block, q, out, dout)
                    _backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_states[i], causal=False)

                else:
                    # Drop the first half of q
                    k_, v_ = kv_buffers[i % 2]
                    dk_, dv_ = dk_block, dv_block
                    q_, out_, dout_ = [x[half_idx_back] for x in (q, out, dout)]
                    dq_ = dq_block[: t // 2]
                    _backward(dout_, q_, k_, v_, out_, softmax_lse1, dq_, dk_, dv_, rng_states[i], causal=False)

                # Accumulate grads
                if i == 0:
                    dq = dq_block.float()
                    dkv_buffers[i % 2][0] = dk_block.float()
                    dkv_buffers[i % 2][1] = dv_block.float()
                else:
                    # Accumulate local dq
                    if i <= local_sp_rank:
                        dq += dq_  # (T, H, D)
                    else:
                        dq[half_idx_back] += dq_

                    # Wait for mobile kv grad accumulators
                    local_dkv_comm.wait()

                    if i <= local_sp_rank:
                        # q blocks "surrounded" by kv blocks
                        dkv_buffers[i % 2][0][half_idx_front] += dk_
                        dkv_buffers[i % 2][1][half_idx_front] += dv_
                    else:
                        # q blocks "surrounding" kv blocks
                        dkv_buffers[i % 2][0] += dk_
                        dkv_buffers[i % 2][1] += dv_
                local_dkv_comm.send_recv(send_tensor=dkv_buffers[i % 2], recv_tensor=dkv_buffers[(i + 1) % 2])

            local_dkv_comm.wait()
            dkv_recv = dkv_buffers[local_sp_size % 2]
            dkv_send = dkv_buffers[(local_sp_size - 1) % 2]
            return dq, dkv_recv, dkv_send

        def _other_ring_backward(ring_num_idx, dq):
            if ring_num_idx > inter_ring_rank:
                # Indexing is expensive
                q_, out_, dout_ = [x[half_idx_back] for x in (q, out, dout)]
            else:
                q_, out_, dout_ = (q, out, dout)

            for i in range(local_sp_size):
                if i > 0:
                    local_kv_comm.wait()

                if i < local_sp_size - 1:
                    local_kv_comm.send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])

                rng_state = rng_states[i + local_sp_size * ring_num_idx]
                if ring_num_idx > inter_ring_rank:
                    k_, v_ = kv_buffers[i % 2]
                    dk_, dv_ = dk_block, dv_block
                    dq_ = dq_block[: t // 2]
                    _backward(dout_, q_, k_, v_, out_, softmax_lse1, dq_, dk_, dv_, rng_state, causal=False)

                    dq[half_idx_back] += dq_
                    if i > 0:
                        local_dkv_comm.wait()
                    else:
                        inter_dkv_comm.wait()

                    dkv_buffers[i % 2][0] += dk_
                    dkv_buffers[i % 2][1] += dv_
                else:
                    k_, v_ = [x[half_idx_front] for x in kv_buffers[i % 2]]
                    dk_, dv_ = [x[: t // 2] for x in (dk_block, dv_block)]
                    dq_ = dq_block
                    _backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_state, causal=False)

                    dq += dq_
                    if i > 0:
                        local_dkv_comm.wait()
                    else:
                        inter_dkv_comm.wait()

                    dkv_buffers[i % 2][0][half_idx_front] += dk_
                    dkv_buffers[i % 2][1][half_idx_front] += dv_

                local_dkv_comm.send_recv(send_tensor=dkv_buffers[i % 2], recv_tensor=dkv_buffers[(i + 1) % 2])

            local_dkv_comm.wait()
            dkv_recv = dkv_buffers[local_sp_size % 2]
            dkv_send = dkv_buffers[(local_sp_size - 1) % 2]
            return dq, dkv_recv, dkv_send

        inter_ring_kv = None
        for ring_num_idx in range(num_rings):
            if ring_num_idx > 0:
                inter_kv_comm.wait()
                kv_buffers[0] = inter_ring_kv

            if ring_num_idx < num_rings - 1:
                # Re-allocate a buffer in each inter-ring step
                inter_ring_kv = inter_kv_comm.send_recv(kv_buffers[0])

            if ring_num_idx == 0:
                dq, dkv_recv, dkv_send = _local_ring_backward()
            else:
                dq, dkv_recv, dkv_send = _other_ring_backward(ring_num_idx, dq)

            if num_rings > 1:
                # Reuse the local buffers
                inter_dkv_comm.send_recv(send_tensor=dkv_recv, recv_tensor=dkv_send)
                # Reset indices
                dkv_buffers[0] = dkv_send
                dkv_buffers[1] = dkv_recv
                if ring_num_idx == num_rings - 1:
                    inter_dkv_comm.wait()
                    dkv_recv = dkv_buffers[0]

        dq, dk, dv = [x.to(q.dtype) for x in (dq, *dkv_recv)]
        if not is_packed:
            dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)]

        return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None)

    @staticmethod
    def prepare_varlen_batch(
        attention_mask: torch.Tensor,
        sp_group: dist.ProcessGroup,
        inputs_embeds: torch.Tensor = None,
        position_ids: Optional[torch.Tensor] = None,
        is_label: bool = False,
        is_2d: bool = True,
    ):
        """
        Preprocess a batch of padded sequences by splitting each input sequence sequence-wise
        across sp_size ranks and packing the local chunks into one sequence. Updates the mask info accordingly.
        Args:
            attention_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked.
            sp_group (dist.ProcessGroup): Process group for sequence parallelism
            inputs_embeds (torch.Tensor): Input embeddings. Shape should be [B, Sq, ...]
            position_ids (Optional[torch.Tensor], optional): Position ids of shape [Sq] or [1, Sq]. Defaults to None.
            is_label (bool, optional): Whether inputs_embeds is instead a label tensor. If True, mask out the first
                token of each sequence.
            is_2d (bool, optional): Whether to return 2D outputs padded to max_seqlen // sp_size or flatten
                the batch dim to a packed 1d sequence. Contingent on model forward shape definitions.

        Returns:
            inputs_embeds: Packed input embeddings of shape [B, Sq // sp_size, ...].
            mask_info: A dictionary of mask info.
            position_ids: Packed position ids of shape [..., Sq // sp_size].

        """
        _load_varlen_helpers()
        sp_size = dist.get_world_size(group=sp_group)
        sp_rank = dist.get_rank(group=sp_group)
        mask_info = {}
        mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(attention_mask, return_indices=False)

        # Unpad, split seq-wise, then pad back to (B, max_seqlen // sp_size)
        # Split mask to compute local nonzero position indices
        # (B, Sq) -> (B, max_seqlen // sp_size)
        attention_mask = attention_mask[:, : mask_info["max_seqlen"]]
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds[:, : mask_info["max_seqlen"]]
            inputs_embeds = split_varlen_zigzag(
                inputs_embeds,
                mask_info["cu_seqlens"],
                sp_group,
                mask_info["max_seqlen"],
                is_2d=is_2d,
                is_label=is_label,
            )
        attention_mask = split_varlen_zigzag(
            attention_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_2d=is_2d
        )

        if position_ids is not None:
            indices = torch.tensor([sp_rank, 2 * sp_size - sp_rank - 1], device=inputs_embeds.device)
            position_ids = (
                position_ids[..., : mask_info["max_seqlen"]]  # unpad
                .view(-1, sp_size * 2, mask_info["max_seqlen"] // (sp_size * 2))
                .index_select(-2, indices)
                .view(-1, mask_info["max_seqlen"] // sp_size)
            )

        mask_info["max_seqlen"] //= sp_size
        mask_info["valid_indices"] = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
        mask_info["cu_seqlens"] //= sp_size
        mask_info["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
        return inputs_embeds, mask_info, position_ids
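The position-id handling above is just the zigzag split applied to a 1D range. A runnable single-process illustration (sp_rank and sp_size are fixed here instead of coming from `torch.distributed`):

```python
import torch

sp_size, sp_rank, max_seqlen = 4, 1, 16
position_ids = torch.arange(max_seqlen)
indices = torch.tensor([sp_rank, 2 * sp_size - sp_rank - 1])  # chunks 1 and 6
local = (
    position_ids.view(-1, sp_size * 2, max_seqlen // (sp_size * 2))
    .index_select(-2, indices)
    .view(-1, max_seqlen // sp_size)
)
# local == tensor([[ 2,  3, 12, 13]])  -> rank 1 keeps its front chunk and its mirror chunk
```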