[Feature] Zigzag Ring attention (#5905)

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add sp_mode to benchmark; fix varlen interface

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Edenzzzz
2024-08-16 13:56:38 +08:00
committed by GitHub
parent 887d2d579b
commit f5c84af0b0
50 changed files with 1870 additions and 326 deletions


@@ -1,5 +1,5 @@
 from contextlib import contextmanager
-from typing import List
+from typing import List, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -289,3 +289,199 @@ def create_randomizer_with_offset(
Randomizer.increment_index()
return Randomizer(seed=base_seed)
def split_batch_zigzag(
batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""
    Split the input along the sequence dimension for Ring Attention. Naively splitting the attention mask
in the causal setting will result in the preceding ranks having much less workload.
We split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2).
For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |.
Args:
batch (List[torch.Tensor] or Tensor): The input tensor(s) to split.
sp_group (ProcessGroup): The process group for sequence parallelism.
seq_dim (int): The sequence dimension to split.
is_label (bool): If True, mask and shift the tensor for next token prediction.
"""
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
if isinstance(batch, torch.Tensor):
batch = [batch]
seq_dim = seq_dim if seq_dim != -1 else batch[0].dim() - 1
if sp_size > 1:
for idx, tensor in enumerate(batch):
assert (
tensor.shape[seq_dim] // (sp_size * 2) > 1 and tensor.shape[seq_dim] % (sp_size * 2) == 0
), f"Bro, the seq length {tensor.shape[seq_dim]} for tensor {idx} can't be split by {sp_size * 2}!"
if is_label:
assert tensor.dim() == 2, "Label shape should be (B, Seqlen)"
tensor = torch.cat([tensor[:, 1:], torch.full_like(tensor[:, :1], -100)], dim=1)
tensor = tensor.view(
*tensor.shape[:seq_dim],
2 * sp_size,
tensor.shape[seq_dim] // (2 * sp_size),
*tensor.shape[seq_dim + 1 :],
)
indices = torch.tensor([sp_rank, 2 * sp_size - 1 - sp_rank], device=tensor.device)
tensor = tensor.index_select(seq_dim, indices).contiguous()
# (B, 2, Sq // (2 * sp_size), ...) -> (B, Sq // sp_size, ...)
batch[idx] = tensor.view(*tensor.shape[:seq_dim], -1, *tensor.shape[seq_dim + 2 :])
if len(batch) == 1:
return batch[0]
return batch
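
# --- Illustrative sketch, not part of this PR's changes ---
# A single-process reproduction of the zigzag layout described in the
# split_batch_zigzag docstring above. The helper name `_zigzag_layout_example`
# is hypothetical and only makes the chunk assignment concrete.
def _zigzag_layout_example(seq_len: int = 8, sp_size: int = 4) -> List[List[int]]:
    # Fold the sequence into 2 * sp_size chunks; rank r keeps chunk r plus its
    # mirror chunk (2 * sp_size - 1 - r), which balances causal-attention work.
    chunk = seq_len // (2 * sp_size)
    tokens = list(range(seq_len))
    chunks = [tokens[i * chunk : (i + 1) * chunk] for i in range(2 * sp_size)]
    # For seq_len = 8, sp_size = 4 this returns [[0, 7], [1, 6], [2, 5], [3, 4]],
    # matching the | s0, s7 | s1, s6 | s2, s5 | s3, s4 | example in the docstring.
    return [chunks[r] + chunks[2 * sp_size - 1 - r] for r in range(sp_size)]
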
def split_varlen_zigzag(
batch: Union[List[torch.Tensor], torch.Tensor],
cu_seqlens: torch.Tensor,
sp_group: ProcessGroup,
max_seqlen: int = 0,
is_2d: bool = False,
is_label: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
"""Split each sequence in a batch of packed sequences in a zigzag fashion.
For each tensor in batch, return packed sequences if is_2d is False;
else return a padded batch of sequences.
Args:
batch (List[torch.Tensor]): Packed sequences of shape (B * Sq, ...), or (B, Sq, ...) if is_2d.
cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape (B + 1) before splitting.
sp_group (ProcessGroup): The process group for sequence parallelism.
max_seqlen (int): The maximum sequence length in the batch before splitting.
is_2d (bool): If True, then input has batch size and sequence length split into two dimensions.
        is_label (bool): If True, shift the tokens left by one for next-token prediction and mask the final position with the ignore index (-100).
Returns:
        batch (List[torch.Tensor]): Packed sequences of shape (total_seqlen // sp_size, ...),
            or (B, max_seqlen // sp_size, ...) if is_2d.
"""
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
if is_2d:
assert max_seqlen > 0, "max_seqlen must be provided for 2D input"
if isinstance(batch, torch.Tensor):
batch = [batch]
for i, packed_seq in enumerate(batch):
device = packed_seq.device
dtype = packed_seq.dtype
if is_2d:
assert max_seqlen % (sp_size * 2) == 0
# Recreate a padded tensor with the new max seqlen
shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:])
local_seq = torch.zeros(shape, dtype=dtype, device=device)
else:
total_seqlen = cu_seqlens[-1]
assert (
total_seqlen % (2 * sp_size) == 0
), f"total_seqlen {total_seqlen} must be divisible by 2 * sp_size = {2 * sp_size}"
local_seq = []
for j in range(len(cu_seqlens) - 1):
start, end = cu_seqlens[j], cu_seqlens[j + 1]
seqlen = end - start
assert (
seqlen % (2 * sp_size) == 0
), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting"
if is_2d:
seq = packed_seq[j][:seqlen]
if is_label:
# Shift one position to the right for next token prediction
seq = torch.cat([seq[1:], torch.tensor([-100], dtype=dtype, device=device)])
seq = seq.chunk(2 * sp_size, dim=0)
half = seqlen // sp_size // 2
local_seq[j][:half] = seq[sp_rank]
local_seq[j][half : seqlen // sp_size] = seq[2 * sp_size - 1 - sp_rank]
else:
seq = packed_seq[start:end]
if is_label:
                    seq = torch.cat([seq[1:], torch.tensor([-100], dtype=dtype, device=device)])
seq = seq.chunk(sp_size * 2)
local_seq.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]])
if is_2d:
batch[i] = local_seq.contiguous()
else:
batch[i] = torch.cat(local_seq, dim=0)
if len(batch) == 1:
batch = batch[0]
return batch
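
# --- Illustrative sketch, not part of this PR's changes ---
# Shows, without a process group, which token positions of a packed (1D) batch
# a given rank keeps under split_varlen_zigzag: for every sequence delimited by
# cu_seqlens, rank r takes chunk r and its mirror chunk (2 * sp_size - 1 - r).
# The helper name `_varlen_zigzag_positions_example` is hypothetical.
def _varlen_zigzag_positions_example(cu_seqlens: List[int], sp_size: int, sp_rank: int) -> List[int]:
    kept = []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        chunk = (end - start) // (2 * sp_size)
        # Positions of chunk sp_rank and its mirror chunk within this sequence.
        for c in (sp_rank, 2 * sp_size - 1 - sp_rank):
            kept.extend(range(start + c * chunk, start + (c + 1) * chunk))
    return kept
# e.g. _varlen_zigzag_positions_example([0, 4, 12], sp_size=2, sp_rank=0)
# returns [0, 3, 4, 5, 10, 11].
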
def is_share_sp_tp(sp_mode: str):
"""sp_mode "ring" and "split_gather" use the TP group as SP group
to split both the vocab and sequence, so we must gather the sequence
to correctly get logits at each positions.
"""
return sp_mode in ["ring", "split_gather"]
class RingComm:
def __init__(self, process_group: dist.ProcessGroup):
self._process_group = process_group
self._ops = []
self.rank = dist.get_rank(self._process_group)
self.world_size = dist.get_world_size(self._process_group)
self._reqs = []
self.send_rank = (self.rank + 1) % self.world_size
self.recv_rank = (self.rank - 1) % self.world_size
self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)
def send_recv(
self,
send_tensor: torch.Tensor,
recv_tensor: Optional[torch.Tensor] = None,
commit: bool = True,
) -> torch.Tensor:
if recv_tensor is None:
res = torch.empty_like(send_tensor)
else:
res = recv_tensor
        # Note: batch_isend_irecv does not appear to deadlock even when
        # the send/recv ops are not swapped based on rank.
send_op = dist.P2POp(dist.isend, send_tensor, self.send_rank, group=self._process_group)
recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
self._ops.extend([send_op, recv_op])
if commit:
self._reqs = dist.batch_isend_irecv(self._ops)
return res
def commit(self):
assert len(self._ops) > 0, "No ops to commit"
self._reqs = dist.batch_isend_irecv(self._ops)
def wait(self):
assert len(self._reqs) > 0, "No requests to wait for"
for req in self._reqs:
req.wait()
self._reqs = []
self._ops = []
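
# --- Illustrative sketch, not part of this PR's changes ---
# Minimal usage pattern for RingComm inside one ring-attention pass, assuming
# torch.distributed is initialized and `sp_group`, `kv` are placeholders:
#
#     comm = RingComm(sp_group)
#     for step in range(comm.world_size):
#         if step + 1 < comm.world_size:
#             # Post async send of the local kv block to the next rank and
#             # receive the previous rank's block; overlaps with compute below.
#             next_kv = comm.send_recv(kv, commit=True)
#         ...  # compute attention with the current kv block here
#         if step + 1 < comm.world_size:
#             comm.wait()  # block until the exchanged kv block has arrived
#             kv = next_kv
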
@torch.jit.script
def get_half_index(cu_seqlens, *, front: bool):
index = torch.zeros(cu_seqlens[-1], dtype=torch.bool, device=cu_seqlens.device)
for i in range(len(cu_seqlens) - 1):
start, end = cu_seqlens[i], cu_seqlens[i + 1]
if front:
end = (start + end) // 2
else:
start = (start + end) // 2
index[start:end] = True
return index
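
# --- Illustrative sketch, not part of this PR's changes ---
# get_half_index builds a boolean mask selecting the first (front=True) or
# second (front=False) half of every packed sequence, for example:
#
#     cu_seqlens = torch.tensor([0, 4, 10])
#     get_half_index(cu_seqlens, front=True)
#     # -> True at positions 0-1 (first half of seq 0) and 4-6 (first half
#     #    of seq 1); False everywhere else.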