mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-14 21:51:57 +00:00
[Feature] Split cross-entropy computation in SP (#5959)
* halfway * fix cross-PP-stage position id length diff bug * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * unified cross entropy func for all shardformer models * remove redundant lines * add basic ring attn; debug cross entropy * fwd bwd logic complete * fwd bwd logic complete; add experimental triton rescale * precision tests passed * precision tests passed * fix typos and remove misc files * update softmax_lse shape by new interface * change tester name * remove buffer clone; support packed seq layout * add varlen tests * fix typo * all tests passed * add dkv_group; fix mask * remove debug statements * adapt chatglm, command-R, qwen * debug * halfway * fix cross-PP-stage position id length diff bug * fix typo * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * unified cross entropy func for all shardformer models * remove redundant lines * add basic ring attn; debug cross entropy * fwd bwd logic complete * fwd bwd logic complete; add experimental triton rescale * precision tests passed * precision tests passed * fix typos and remove misc files * add sp_mode to benchmark; fix varlen interface * update softmax_lse shape by new interface * add varlen tests * fix typo * all tests passed * add dkv_group; fix mask * remove debug statements * add comments * q1 index only once * remove events to simplify stream sync * simplify forward/backward logic * 2d ring forward passed * 2d ring backward passed * fixes * fix ring attn loss * 2D ring backward + llama passed * merge * update logger * fix typo * rebase * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typo * remove typos * fixes * support GPT --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -433,7 +433,6 @@ class RingAttention(torch.autograd.Function):
|
||||
assert (
|
||||
sp_size % inner_ring_size == 0
|
||||
), f"sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
|
||||
|
||||
logger = get_dist_logger()
|
||||
logger.info(
|
||||
f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!",
|
||||
@@ -898,6 +897,7 @@ class RingAttention(torch.autograd.Function):
|
||||
|
||||
local_sp_rank = dist.get_rank(sp_group)
|
||||
sp_size = dist.get_world_size(sp_group)
|
||||
|
||||
# Using separate streams (pg) for concurrent kv and dkv comm may
|
||||
# cause NCCL "software caused connection abort" here...
|
||||
local_kv_comm = RingComm(local_kv_group)
|
||||
@@ -1119,9 +1119,14 @@ class RingAttention(torch.autograd.Function):
|
||||
the batch dim to a packed 1d sequence. Contingent on model forward shape definitions.
|
||||
|
||||
Returns:
|
||||
inputs_embeds: Packed input embeddings of shape [B, Sq // sp_size, ...].
|
||||
mask_info: A dictionary of mask info.
|
||||
position_ids: Packed position ids of shape [..., Sq // sp_size].
|
||||
torch.Tensor:
|
||||
Packed input embeddings of shape [B, Sq // sp_size, ...].
|
||||
|
||||
Dict[str, Any]:
|
||||
A dictionary containing mask info.
|
||||
|
||||
torch.Tensor:
|
||||
Packed position ids of shape [..., Sq // sp_size].
|
||||
|
||||
"""
|
||||
_load_varlen_helpers()
|
||||
|
Reference in New Issue
Block a user