Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
[Feature] Split cross-entropy computation in SP (#5959)

* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* precision tests passed
* fix typos and remove misc files
* update softmax_lse shape by new interface
* change tester name
* remove buffer clone; support packed seq layout
* add varlen tests
* fix typo
* all tests passed
* add dkv_group; fix mask
* remove debug statements
* adapt chatglm, command-R, qwen
* debug
* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* precision tests passed
* fix typos and remove misc files
* add sp_mode to benchmark; fix varlen interface
* update softmax_lse shape by new interface
* add varlen tests
* fix typo
* all tests passed
* add dkv_group; fix mask
* remove debug statements
* add comments
* q1 index only once
* remove events to simplify stream sync
* simplify forward/backward logic
* 2d ring forward passed
* 2d ring backward passed
* fixes
* fix ring attn loss
* 2D ring backward + llama passed
* merge
* update logger
* fix typo
* rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix typo
* remove typos
* fixes
* support GPT

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
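The change this PR rolls out across the shardformer models is to keep the LM-head logits sharded along the sequence dimension and compute the cross-entropy loss shard-locally, instead of gathering the full sequence before the loss. A minimal sketch of that idea follows; the helper name sharded_cross_entropy and the explicit all-reduce over an sp_group are illustrative assumptions, not the actual dist_cross_entropy implementation, and label shifting across shard boundaries is omitted.

    import torch
    import torch.distributed as dist
    import torch.nn.functional as F

    def sharded_cross_entropy(logits_shard, labels_shard, sp_group, ignore_index=-100):
        # Each rank holds a slice of the sequence: logits_shard is [batch, local_seq, vocab].
        # Sum the loss and count the valid tokens locally ...
        loss = F.cross_entropy(
            logits_shard.view(-1, logits_shard.size(-1)).float(),
            labels_shard.view(-1),
            ignore_index=ignore_index,
            reduction="sum",
        )
        num_tokens = (labels_shard != ignore_index).sum()
        # ... then all-reduce both across the sequence-parallel group, so every rank
        # ends up with the same mean loss without materializing full-sequence logits.
        buf = torch.stack([loss, num_tokens.to(loss.dtype)])
        dist.all_reduce(buf, group=sp_group)
        return buf[0] / buf[1].clamp(min=1)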
@@ -17,14 +17,13 @@ from transformers.models.cohere.modeling_cohere import (
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.layer._operation import (
-    all_to_all_comm,
-    gather_forward_split_backward,
-    split_forward_gather_backward,
-)
+from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig
 
 from ..layer import ColoAttention, dist_cross_entropy
+from ..layer._operation import gather_sp_output, is_share_sp_tp
 
-_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring"]
+_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
 
@@ -52,6 +51,7 @@ class CommandPipelineForwards:
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
+        force_sp_output_gather: bool = True,
     ):
         logger = logging.get_logger(__name__)
 
@@ -93,10 +93,16 @@ class CommandPipelineForwards:
             if not isinstance(past_key_values, StaticCache):
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                 past_seen_tokens = past_key_values.get_seq_length()
 
+        # NOTE: For generating full positions ids
+        # (the states will be gathered along the seq dim before attention fwd).
+        if shard_config.sequence_parallelism_mode != "ring_attn" and not stage_manager.is_first_stage():
+            seq_length *= shard_config.sequence_parallel_size
+
         if cache_position is None:
             if isinstance(past_key_values, StaticCache):
                 raise ValueError("cache_position is a required argument when using StaticCache.")
-            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device)
+            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)
 
         seq_length_with_past = seq_length + past_seen_tokens
 
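The intent of the new branch above: a non-first pipeline stage receives hidden states that are already split along the sequence dimension, yet cache_position (and hence the position ids) must cover the full sequence, because the shards are gathered along the seq dim before the attention forward. A hedged sketch of the resulting shapes, with illustrative dummy numbers only:

    import torch

    # A non-first PP stage holds a sequence shard of the hidden states, but builds
    # cache_position for the full sequence, since shards are gathered before attention.
    batch, local_seq, hidden, sp_size, past_seen_tokens = 2, 1024, 8192, 4, 0
    hidden_states = torch.zeros(batch, local_seq, hidden)
    full_seq = hidden_states.shape[1] * sp_size            # mirrors seq_length *= sequence_parallel_size
    cache_position = torch.arange(past_seen_tokens, past_seen_tokens + full_seq)
    assert cache_position.numel() == 4096                  # positions for the whole (gathered) sequence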
@@ -136,7 +142,7 @@ class CommandPipelineForwards:
             )
             use_cache = False
 
-        if shard_config and shard_config.enable_sequence_parallelism:
+        if stage_manager.is_first_stage() and shard_config.enable_sequence_parallelism:
             if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
                 hidden_states = split_forward_gather_backward(
                     hidden_states,
@@ -208,23 +214,10 @@ class CommandPipelineForwards:
 
         if stage_manager.is_last_stage():
             hidden_states = self.norm(hidden_states)
-
-            if shard_config and shard_config.enable_sequence_parallelism:
-                if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
-                    hidden_states = gather_forward_split_backward(
-                        hidden_states,
-                        dim=1,
-                        process_group=shard_config.tensor_parallel_process_group,
-                        fp8_communication=shard_config.fp8_communication,
-                    )
-                elif shard_config.sequence_parallelism_mode == "all_to_all":
-                    hidden_states = gather_forward_split_backward(
-                        hidden_states,
-                        dim=1,
-                        process_group=shard_config.sequence_parallel_process_group,
-                        grad_scale=shard_config.sequence_parallel_size,
-                        fp8_communication=shard_config.fp8_communication,
-                    )
+            sp_mode = shard_config.sequence_parallelism_mode
+            if shard_config.enable_sequence_parallelism:
+                if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
+                    hidden_states = gather_sp_output(hidden_states, shard_config)
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
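The replacement collapses the per-mode gather branches into one decision: gather the sequence dimension back before the LM head unless the downstream loss can consume sequence-sharded logits. A hedged sketch of that predicate is below; the helper name is illustrative, and the membership test reflects what is_share_sp_tp is assumed to check for the modes listed in _SUPPORTED_SP_MODE.

    def needs_sp_output_gather(parallel_output: bool, force_gather: bool, sp_mode: str) -> bool:
        # Assumption: "split_gather" and "ring" share the TP group for SP, so their outputs
        # are gathered regardless; "all_to_all" / "ring_attn" can keep the sequence split
        # when the loss is computed on sharded logits (parallel_output=True).
        share_sp_tp = sp_mode in ("split_gather", "ring")
        return (not parallel_output) or force_gather or share_sp_tp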
@@ -327,6 +320,7 @@ class CommandPipelineForwards:
             hidden_states=hidden_states,
             stage_index=stage_index,
             shard_config=shard_config,
+            force_sp_output_gather=False,
         )
         past_key_values = None
 
@@ -335,9 +329,10 @@ class CommandPipelineForwards:
             logits = self.lm_head(hidden_states)
             logits = logits * self.logit_scale
             logits = logits.float()
-            loss = dist_cross_entropy(
-                labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
-            )
+
+            loss = None
+            if labels is not None:
+                loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
 
             if not return_dict:
                 output = (logits,) + outputs[1:]
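For a sense of why the loss is now computed without gathering first: with Command-R's large vocabulary the float() logits dominate activation memory, so keeping them sequence-sharded is a sizeable saving. Rough, assumed numbers only (vocab around 256k, sequence length 4096, fp32 logits, sp_size 4):

    # back-of-the-envelope logits memory per sample, illustrative numbers only
    vocab, seq, sp_size, fp32_bytes = 256_000, 4096, 4, 4
    gathered_gb = seq * vocab * fp32_bytes / 1e9      # ~4.2 GB if the full sequence is gathered
    sharded_gb = gathered_gb / sp_size                # ~1.05 GB when cross-entropy stays split
    print(f"gathered: {gathered_gb:.2f} GB, sharded: {sharded_gb:.2f} GB")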
@@ -482,6 +477,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        force_sp_output_gather: bool = True,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -584,14 +580,10 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
 
         hidden_states = self.norm(hidden_states)
 
-        if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(
-                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
-            )
-        elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(
-                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
-            )
+        # Cases that don't support parallelizing cross entropy computation along sequence
+        if shard_config.enable_sequence_parallelism:
+            if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather:
+                hidden_states = gather_sp_output(hidden_states, shard_config)
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
@@ -676,6 +668,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             cache_position=cache_position,
+            force_sp_output_gather=False,
         )
 
         hidden_states = outputs[0]
@@ -683,14 +676,16 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
         logits = self.lm_head(hidden_states)
         logits = logits * self.logit_scale
         logits = logits.float()
-        loss = dist_cross_entropy(
-            labels,
-            logits,
-            shard_config,
-            self.lm_head.out_features,
-            self.config.vocab_size,
-            self.model.dtype,
-        )
+
+        loss = None
+        if labels is not None:
+            loss = dist_cross_entropy(
+                labels,
+                logits,
+                shard_config,
+                self.lm_head.out_features,
+                self.model.dtype,
+            )
 
         if not return_dict:
             output = (logits,) + outputs[1:]