Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 11:32:10 +00:00)
[Ring Attention] Improve comments (#6085)
* improve comments

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
@@ -295,8 +295,8 @@ def split_batch_zigzag(
     batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False
 ) -> Union[torch.Tensor, List[torch.Tensor]]:
     """
-    Split the input along the sequence dimension for Ring Attention. Naively spliting the attention mask
-    in the causal setting will result in the preceding ranks having much less workload.
+    Split the input sequence batch. Naively splitting the attention mask in the causal setting
+    will result in the preceding ranks having much less workload.
     We split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2).
     For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |.
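To see concretely what the rewritten docstring describes, here is a minimal standalone sketch of the zigzag split (an illustration under assumed names, not the library function itself): each rank keeps one chunk from the front half of the sequence and its mirror image from the back half.

    import torch

    def zigzag_split(batch, sp_size: int, sp_rank: int, seq_dim: int = 1):
        # Fold the sequence into 2 * sp_size chunks and keep this rank's mirrored pair.
        chunks = batch.chunk(2 * sp_size, dim=seq_dim)
        return torch.cat([chunks[sp_rank], chunks[2 * sp_size - 1 - sp_rank]], dim=seq_dim)

    # Reproduces the docstring example with sp_size = 4 and seq_len = 8:
    x = torch.arange(8).unsqueeze(0)  # shape (B=1, Sq=8)
    for rank in range(4):
        print(rank, zigzag_split(x, sp_size=4, sp_rank=rank).tolist())
    # 0 [[0, 7]]
    # 1 [[1, 6]]
    # 2 [[2, 5]]
    # 3 [[3, 4]]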
@@ -346,40 +346,42 @@ def split_varlen_zigzag(
     cu_seqlens: torch.Tensor,
     sp_group: ProcessGroup,
     max_seqlen: int = 0,
-    is_2d: bool = False,
+    is_batched_seq: bool = False,
     is_label: bool = False,
 ) -> Union[List[torch.Tensor], torch.Tensor]:
-    """Split each sequence in a batch of packed sequences in a zigzag fashion.
-    For each tensor in batch, return packed sequences if is_2d is False;
-    else return a padded batch of sequences.
+    """Split a packed seq/batch of padded sequences in a zigzag fashion.
+    Different from split_batch_zigzag, inputs here have variable sequence lengths.
     Args:
-        batch (List[torch.Tensor]): Packed sequences of shape (B * Sq, ...), or (B, Sq, ...) if is_2d.
+        batch (List[torch.Tensor]): Packed sequences of shape (T, ...), or (B, Sq, ...) if is_batched_seq,
+            where T is the total number of tokens.
         cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape (B + 1) before splitting.
         sp_group (ProcessGroup): The process group for sequence parallelism.
         max_seqlen (int): The maximum sequence length in the batch before splitting.
-        is_2d (bool): If True, then input has batch size and sequence length split into two dimensions.
+        is_batched_seq (bool): If True, then the input is a batch of sequences padded to the same length.
         is_label (bool): If True, mask out the first token in each sequence (<Start of Sentence>).

     Returns:
-        batch (List[torch.Tensor]): Packed sequences of shape (B * max_seqlen // sp_size)
-            or (B, max_seqlen // sp_size, ...) if is_2d
+        batch (List[torch.Tensor]): Packed sequences of shape (T, ...)
+            or (B, max_seqlen // sp_size, ...) if is_batched_seq
     """
     sp_size = dist.get_world_size(sp_group)
     sp_rank = dist.get_rank(sp_group)
     if sp_size == 1:
         return batch

-    if is_2d:
+    if is_batched_seq:
         assert max_seqlen > 0, "max_seqlen must be provided for 2D input"

     if isinstance(batch, torch.Tensor):
         batch = [batch]
     # seq: (B, Sq, h, n)
     # seq = seq[:, :rank * (seqlen // sp_size), ...]

     for i, packed_seq in enumerate(batch):
         device = packed_seq.device
         dtype = packed_seq.dtype

-        if is_2d:
+        if is_batched_seq:
             assert max_seqlen % (sp_size * 2) == 0
             # Recreate a padded tensor with the new max seqlen
             shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:])
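The hunks below apply the same renaming inside the function body. As a hedged sketch of the packed (variable-length) branch they modify, the code walks cu_seqlens, splits each sequence into 2 * sp_size chunks, and keeps this rank's mirrored pair; the function name and signature here are illustrative, not the library API:

    import torch

    def zigzag_split_varlen(packed, cu_seqlens, sp_size: int, sp_rank: int):
        local_seq = []
        for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
            seq = packed[start:end]  # one sequence of length end - start
            assert (end - start) % (2 * sp_size) == 0, "seqlen must be divisible by 2 * sp_size"
            chunks = seq.chunk(2 * sp_size)
            local_seq.extend([chunks[sp_rank], chunks[2 * sp_size - 1 - sp_rank]])
        return torch.cat(local_seq, dim=0)

    # Two packed sequences of lengths 8 and 4, with sp_size = 2, on rank 0:
    packed = torch.arange(12)
    cu = torch.tensor([0, 8, 12])
    print(zigzag_split_varlen(packed, cu, sp_size=2, sp_rank=0).tolist())
    # [0, 1, 6, 7, 8, 11]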
@@ -398,7 +400,7 @@ def split_varlen_zigzag(
                 seqlen % (2 * sp_size) == 0
             ), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting"

-            if is_2d:
+            if is_batched_seq:
                 seq = packed_seq[j][:seqlen]
                 if is_label:
                     # Shift one position to the right for next token prediction
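On the is_label branch: per the docstring, the first token of each label sequence (<Start of Sentence>) is masked out, since it has no preceding token to predict it. A minimal sketch of that masking, assuming the conventional cross-entropy ignore index of -100 (the constant's actual name in the library may differ):

    import torch

    _IGNORE_IDX = -100  # assumption: the standard ignore index for cross-entropy loss

    def mask_first_token(labels: torch.Tensor) -> torch.Tensor:
        labels = labels.clone()  # avoid mutating the caller's tensor
        labels[0] = _IGNORE_IDX  # <SOS> is never a prediction target
        return labels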
@@ -415,7 +417,7 @@ def split_varlen_zigzag(
             seq = seq.chunk(sp_size * 2)
             local_seq.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]])

-        if is_2d:
+        if is_batched_seq:
             batch[i] = local_seq.contiguous()
         else:
             batch[i] = torch.cat(local_seq, dim=0)
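Why the mirrored pairing balances work: under a causal mask, chunk k (0-indexed) attends to chunks 0..k, i.e. k + 1 chunks. A rank holding chunks r and 2 * sp_size - 1 - r therefore performs (r + 1) + (2 * sp_size - r) = 2 * sp_size + 1 chunk-attentions, the same total on every rank, whereas a naive contiguous split would leave the last rank with roughly sp_size times the work of the first.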