[Feature] Zigzag Ring attention (#5905)
* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* precision tests passed
* fix typos and remove misc files
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* add sp_mode to benchmark; fix varlen interface
* update softmax_lse shape by new interface
* change tester name
* remove buffer clone; support packed seq layout
* add varlen tests
* fix typo
* all tests passed
* add dkv_group; fix mask
* remove debug statements

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -1,5 +1,5 @@
from contextlib import contextmanager
-from typing import List
+from typing import List, Optional, Union

import torch
import torch.distributed as dist
@@ -289,3 +289,199 @@ def create_randomizer_with_offset(
    Randomizer.increment_index()

    return Randomizer(seed=base_seed)

def split_batch_zigzag(
    batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """
    Split the input along the sequence dimension for Ring Attention. Naively splitting the attention mask
    in the causal setting will result in the preceding ranks having much less workload.
    We split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2).
    For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |.

    Args:
        batch (List[torch.Tensor] or Tensor): The input tensor(s) to split.
        sp_group (ProcessGroup): The process group for sequence parallelism.
        seq_dim (int): The sequence dimension to split.
        is_label (bool): If True, mask and shift the tensor for next-token prediction.

    """
    sp_size = dist.get_world_size(sp_group)
    sp_rank = dist.get_rank(sp_group)
    if isinstance(batch, torch.Tensor):
        batch = [batch]
    seq_dim = seq_dim if seq_dim != -1 else batch[0].dim() - 1

    if sp_size > 1:
        for idx, tensor in enumerate(batch):
            assert (
                tensor.shape[seq_dim] // (sp_size * 2) > 1 and tensor.shape[seq_dim] % (sp_size * 2) == 0
            ), f"The sequence length {tensor.shape[seq_dim]} of tensor {idx} cannot be split by {sp_size * 2}!"
            if is_label:
                assert tensor.dim() == 2, "Label shape should be (B, Seqlen)"
                tensor = torch.cat([tensor[:, 1:], torch.full_like(tensor[:, :1], -100)], dim=1)

            tensor = tensor.view(
                *tensor.shape[:seq_dim],
                2 * sp_size,
                tensor.shape[seq_dim] // (2 * sp_size),
                *tensor.shape[seq_dim + 1 :],
            )
            indices = torch.tensor([sp_rank, 2 * sp_size - 1 - sp_rank], device=tensor.device)
            tensor = tensor.index_select(seq_dim, indices).contiguous()
            # (B, 2, Sq // (2 * sp_size), ...) -> (B, Sq // sp_size, ...)
            batch[idx] = tensor.view(*tensor.shape[:seq_dim], -1, *tensor.shape[seq_dim + 2 :])

    if len(batch) == 1:
        return batch[0]
    return batch

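# --- Editor's sketch (not part of this commit) ------------------------------
# The zigzag layout, emulated locally without a process group. With sp_size = 2
# and seq_len = 8 the chunks are |01|23|45|67|, and rank r keeps chunks r and
# 2 * sp_size - 1 - r, pairing a short-context chunk with a long-context chunk
# so every rank does roughly equal causal-attention work:
#
#     sp_size = 2
#     x = torch.arange(8).unsqueeze(0)          # (B=1, Sq=8)
#     chunks = x.view(1, 2 * sp_size, -1)       # (1, 4, 2)
#     for rank in range(sp_size):
#         idx = torch.tensor([rank, 2 * sp_size - 1 - rank])
#         print(rank, chunks.index_select(1, idx).reshape(1, -1))
#     # rank 0 -> tensor([[0, 1, 6, 7]]); rank 1 -> tensor([[2, 3, 4, 5]])
# -----------------------------------------------------------------------------

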
def split_varlen_zigzag(
    batch: Union[List[torch.Tensor], torch.Tensor],
    cu_seqlens: torch.Tensor,
    sp_group: ProcessGroup,
    max_seqlen: int = 0,
    is_2d: bool = False,
    is_label: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
    """Split each sequence in a batch of packed sequences in a zigzag fashion.
    For each tensor in batch, return packed sequences if is_2d is False;
    else return a padded batch of sequences.

    Args:
        batch (List[torch.Tensor]): Packed sequences of shape (B * Sq, ...), or (B, Sq, ...) if is_2d.
        cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape (B + 1) before splitting.
        sp_group (ProcessGroup): The process group for sequence parallelism.
        max_seqlen (int): The maximum sequence length in the batch before splitting.
        is_2d (bool): If True, the input has batch size and sequence length split into two dimensions.
        is_label (bool): If True, mask out the first token in each sequence (<Start of Sentence>).

    Returns:
        batch (List[torch.Tensor]): Packed sequences of shape (B * max_seqlen // sp_size),
            or (B, max_seqlen // sp_size, ...) if is_2d.
    """
    sp_size = dist.get_world_size(sp_group)
    sp_rank = dist.get_rank(sp_group)
    if is_2d:
        assert max_seqlen > 0, "max_seqlen must be provided for 2D input"

    if isinstance(batch, torch.Tensor):
        batch = [batch]
    for i, packed_seq in enumerate(batch):
        device = packed_seq.device
        dtype = packed_seq.dtype

        if is_2d:
            assert max_seqlen % (sp_size * 2) == 0
            # Recreate a padded tensor with the new max seqlen
            shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:])
            local_seq = torch.zeros(shape, dtype=dtype, device=device)
        else:
            total_seqlen = cu_seqlens[-1]
            assert (
                total_seqlen % (2 * sp_size) == 0
            ), f"total_seqlen {total_seqlen} must be divisible by 2 * sp_size = {2 * sp_size}"
            local_seq = []

        for j in range(len(cu_seqlens) - 1):
            start, end = cu_seqlens[j], cu_seqlens[j + 1]
            seqlen = end - start
            assert (
                seqlen % (2 * sp_size) == 0
            ), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting"

            if is_2d:
                seq = packed_seq[j][:seqlen]
                if is_label:
                    # Shift labels by one position for next-token prediction
                    seq = torch.cat([seq[1:], torch.tensor([-100], dtype=dtype, device=device)])

                seq = seq.chunk(2 * sp_size, dim=0)
                half = seqlen // sp_size // 2
                local_seq[j][:half] = seq[sp_rank]
                local_seq[j][half : seqlen // sp_size] = seq[2 * sp_size - 1 - sp_rank]
            else:
                seq = packed_seq[start:end]
                if is_label:
                    seq = torch.cat([seq[1:], torch.tensor([-100], dtype=dtype, device=device)])
                seq = seq.chunk(sp_size * 2)
                local_seq.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]])

        if is_2d:
            batch[i] = local_seq.contiguous()
        else:
            batch[i] = torch.cat(local_seq, dim=0)

    if len(batch) == 1:
        batch = batch[0]
    return batch

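# --- Editor's sketch (not part of this commit) ------------------------------
# Varlen zigzag, emulated locally: two packed sequences of lengths 4 and 8
# (cu_seqlens = [0, 4, 12]) with sp_size = 2. Each sequence is cut into
# 2 * sp_size chunks, and rank r keeps chunks r and 2 * sp_size - 1 - r:
#
#     sp_size = 2
#     packed = torch.arange(12)                  # seq A = 0..3, seq B = 4..11
#     cu_seqlens = [0, 4, 12]
#     for rank in range(sp_size):
#         local = []
#         for j in range(len(cu_seqlens) - 1):
#             seq = packed[cu_seqlens[j] : cu_seqlens[j + 1]].chunk(2 * sp_size)
#             local.extend([seq[rank], seq[2 * sp_size - 1 - rank]])
#         print(rank, torch.cat(local))
#     # rank 0 -> tensor([0, 3, 4, 5, 10, 11]); rank 1 -> tensor([1, 2, 6, 7, 8, 9])
# -----------------------------------------------------------------------------

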
def is_share_sp_tp(sp_mode: str):
    """sp_mode "ring" and "split_gather" use the TP group as the SP group
    to split both the vocab and the sequence, so we must gather the sequence
    to correctly get the logits at each position.
    """
    return sp_mode in ["ring", "split_gather"]

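# --- Editor's sketch (not part of this commit) ------------------------------
# Hypothetical use in a logits path; `sp_mode`, `hidden_states`, `sp_group`
# and the gather helper below are illustrative stand-ins, not names this
# commit introduces:
#
#     if is_share_sp_tp(sp_mode):
#         # gather the sequence dim so every rank holds logits for all positions
#         hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
# -----------------------------------------------------------------------------

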
class RingComm:
    """P2P communicator for Ring Attention: each rank sends to the next rank
    in the group and receives from the previous one."""

    def __init__(self, process_group: dist.ProcessGroup):
        self._process_group = process_group
        self._ops = []
        self.rank = dist.get_rank(self._process_group)
        self.world_size = dist.get_world_size(self._process_group)
        self._reqs = []

        self.send_rank = (self.rank + 1) % self.world_size
        self.recv_rank = (self.rank - 1) % self.world_size

        # Map group-local neighbor ranks to global ranks for the P2P ops
        self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
        self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)

    def send_recv(
        self,
        send_tensor: torch.Tensor,
        recv_tensor: Optional[torch.Tensor] = None,
        commit: bool = True,
    ) -> torch.Tensor:
        if recv_tensor is None:
            res = torch.empty_like(send_tensor)
        else:
            res = recv_tensor

        # batch_isend_irecv appears not to deadlock even when the send/recv
        # ops are not reordered based on rank
        send_op = dist.P2POp(dist.isend, send_tensor, self.send_rank, group=self._process_group)
        recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
        self._ops.extend([send_op, recv_op])

        if commit:
            self._reqs = dist.batch_isend_irecv(self._ops)
        return res

    def commit(self):
        assert len(self._ops) > 0, "No ops to commit"
        self._reqs = dist.batch_isend_irecv(self._ops)

    def wait(self):
        assert len(self._reqs) > 0, "No requests to wait for"
        for req in self._reqs:
            req.wait()
        self._reqs = []
        self._ops = []

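# --- Editor's sketch (not part of this commit) ------------------------------
# The intended ring pattern, assuming torch.distributed is already initialized
# (e.g. via torchrun), `sp_group` is the sequence-parallel group, and `kv`
# holds this rank's local key/value block. Each of the world_size - 1 steps
# overlaps the next-hop transfer with attention compute on the block held:
#
#     comm = RingComm(sp_group)
#     for _ in range(comm.world_size - 1):
#         next_kv = comm.send_recv(kv)   # posts isend/irecv and commits them
#         # ... compute partial attention with the current `kv` block ...
#         comm.wait()                    # transfer finished; rotate blocks
#         kv = next_kv
# -----------------------------------------------------------------------------

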
@torch.jit.script
def get_half_index(cu_seqlens, *, front: bool):
    # Boolean mask selecting the front (or back) half of every sequence
    # in a packed batch.
    index = torch.zeros(cu_seqlens[-1], dtype=torch.bool, device=cu_seqlens.device)
    for i in range(len(cu_seqlens) - 1):
        start, end = cu_seqlens[i], cu_seqlens[i + 1]
        if front:
            end = (start + end) // 2
        else:
            start = (start + end) // 2
        index[start:end] = True
    return index
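
# --- Editor's sketch (not part of this commit) ------------------------------
# For a packed batch with cu_seqlens = [0, 4, 10], the front-half mask covers
# indices [0, 2) of the first sequence and [4, 7) of the second:
#
#     cu = torch.tensor([0, 4, 10])
#     get_half_index(cu, front=True)
#     # -> [True, True, False, False, True, True, True, False, False, False]
# -----------------------------------------------------------------------------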