Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-09 04:50:17 +00:00)
[Feature] Split cross-entropy computation in SP (#5959)
* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* precision tests passed
* fix typos and remove misc files
* update softmax_lse shape by new interface
* change tester name
* remove buffer clone; support packed seq layout
* add varlen tests
* fix typo
* all tests passed
* add dkv_group; fix mask
* remove debug statements
* adapt chatglm, command-R, qwen
* debug
* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* precision tests passed
* fix typos and remove misc files
* add sp_mode to benchmark; fix varlen interface
* update softmax_lse shape by new interface
* add varlen tests
* fix typo
* all tests passed
* add dkv_group; fix mask
* remove debug statements
* add comments
* q1 index only once
* remove events to simplify stream sync
* simplify forward/backward logic
* 2d ring forward passed
* 2d ring backward passed
* fixes
* fix ring attn loss
* 2D ring backward + llama passed
* merge
* update logger
* fix typo
* rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix typo
* remove typos
* fixes
* support GPT

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
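The headline change is splitting the cross-entropy computation across sequence-parallel ranks so each rank only materializes logits for its own sequence shard. The following is a minimal, self-contained sketch of that idea, not the ColossalAI implementation: the function name, arguments, and the separate "reported" value are illustrative assumptions. Each rank sums the loss over its local tokens, the token count is all-reduced over the SP group, and backward on the local share already yields the correct gradient for the local logits shard.

# Illustrative sketch only (assumed names); shows per-shard cross entropy
# reduced over a sequence-parallel process group.
import torch
import torch.distributed as dist
import torch.nn.functional as F


def sp_cross_entropy(logits_shard, labels_shard, sp_group, ignore_index=-100):
    """Cross entropy for one rank's sequence shard, reduced over the SP group.

    logits_shard: [local_seq_len, vocab_size] logits for this rank's tokens.
    labels_shard: [local_seq_len] targets for the same tokens.
    """
    # Summed loss over the local tokens only; autograd stays entirely local.
    loss_sum = F.cross_entropy(
        logits_shard, labels_shard, ignore_index=ignore_index, reduction="sum"
    )

    # Total number of non-ignored tokens across the whole sequence.
    num_tokens = (labels_shard != ignore_index).sum()
    dist.all_reduce(num_tokens, group=sp_group)

    # local_sum / global_count is this rank's share of the global mean loss;
    # backward on it gives the correct gradient for the local logits shard,
    # since the other ranks' terms do not depend on these logits.
    loss = loss_sum / num_tokens.clamp(min=1)

    # For logging, reduce a detached copy so every rank sees the full-sequence loss.
    reported = loss.detach().clone()
    dist.all_reduce(reported, group=sp_group)
    return loss, reported

Before that split loss can be computed, models that shard the sequence dimension still need the refactored gather helper below when the loss is taken over the full sequence.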
@@ -1097,13 +1097,19 @@ def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8
     return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication)
 
 
-def gather_sp_output(hidden_states, sp_group, sp_mode, fp8_communication=False):
+def gather_sp_output(hidden_states, shard_config, sp_dim=1):
     """
     Gather the output of the last layer for cross entropy computation
     """
+    sp_group = shard_config.sequence_parallel_process_group
+    sp_mode = shard_config.sequence_parallelism_mode
+    fp8_comm = shard_config.fp8_communication
+    if dist.get_world_size(sp_group) == 1:
+        return hidden_states
+
     # Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group)
     scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group)
     hidden_states = gather_forward_split_backward(
-        hidden_states, 1, sp_group, grad_scale=scale, fp8_communication=fp8_communication
+        hidden_states, sp_dim, sp_group, grad_scale=scale, fp8_communication=fp8_comm
     )
     return hidden_states
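With the new signature, a call site passes the whole shard_config instead of unpacking the SP group, mode, and fp8 flag itself, and the gather becomes a no-op when the SP world size is 1. A minimal sketch of such a call site follows; forward_tail and lm_head are hypothetical names, and enable_sequence_parallelism is assumed to be the flag on shard_config that gates whether the sequence dimension was sharded.

# Hypothetical call site for the refactored helper shown in the diff above.
def forward_tail(hidden_states, lm_head, shard_config):
    # hidden_states: [batch, seq_shard, hidden]; gather along sp_dim=1 (the default)
    # so the LM head and cross-entropy loss see the full sequence on every SP rank.
    if shard_config.enable_sequence_parallelism:
        hidden_states = gather_sp_output(hidden_states, shard_config)
    return lm_head(hidden_states)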