[Feature] Split cross-entropy computation in SP (#5959)

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

* adapt chatglm, command-R, qwen

* debug

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* add sp_mode to benchmark; fix varlen interface

* update softmax_lse shape by new interface

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

* add comments

* q1 index only once

* remove events to simplify stream sync

* simplify forward/backward logic

* 2d ring forward passed

* 2d ring backward passed

* fixes

* fix ring attn loss

* 2D ring backward + llama passed

* merge

* update logger

* fix typo

* rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* remove typos

* fixes

* support GPT

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Wenxuan Tan
2024-09-10 12:06:50 +08:00
committed by GitHub
parent b3db1058ec
commit 8fd25d6e09
25 changed files with 527 additions and 1173 deletions

View File

@@ -32,14 +32,12 @@ except ImportError:
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import (
all_to_all_comm,
gather_forward_split_backward,
split_forward_gather_backward,
)
from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig
from ..layer import ColoAttention, dist_cross_entropy
from ..layer._operation import gather_sp_output
from ..layer.utils import is_share_sp_tp
class Qwen2PipelineForwards:
@@ -64,6 +62,7 @@ class Qwen2PipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
force_sp_output_gather: bool = True,
) -> Union[Tuple, BaseModelOutputWithPast]:
logger = logging.get_logger(__name__)
@@ -115,6 +114,14 @@ class Qwen2PipelineForwards:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
# Support SP + PP
sp_size = shard_config.sequence_parallel_size
sp_group = shard_config.sequence_parallel_process_group
sp_mode = shard_config.sequence_parallelism_mode
# For generating full positions ids (the states will be gathered along the seq dim before attention fwd).
if sp_mode != "ring_attn" and not stage_manager.is_first_stage():
seq_length *= sp_size
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
@@ -151,7 +158,6 @@ class Qwen2PipelineForwards:
elif self._attn_implementation == "sdpa" and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
@@ -160,7 +166,6 @@ class Qwen2PipelineForwards:
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
@@ -169,22 +174,21 @@ class Qwen2PipelineForwards:
sliding_window=self.config.sliding_window,
)
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
hidden_states = split_forward_gather_backward(
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = split_forward_gather_backward(
hidden_states,
dim=1,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=1 / shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
if stage_manager.is_first_stage():
if shard_config.enable_sequence_parallelism:
if is_share_sp_tp(sp_mode):
hidden_states = split_forward_gather_backward(
hidden_states,
dim=1,
process_group=sp_group,
)
elif sp_mode == "all_to_all":
hidden_states = split_forward_gather_backward(
hidden_states,
dim=1,
process_group=sp_group,
grad_scale=1 / sp_size,
)
# decoder layers
all_hidden_states = () if output_hidden_states else None
@@ -241,23 +245,10 @@ class Qwen2PipelineForwards:
if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)
if shard_config.enable_sequence_parallelism:
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
hidden_states = gather_sp_output(hidden_states, shard_config)
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
hidden_states = gather_forward_split_backward(
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
hidden_states,
dim=1,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
@@ -351,15 +342,18 @@ class Qwen2PipelineForwards:
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
force_sp_output_gather=False,
)
past_key_values = None
if stage_manager.is_last_stage():
hidden_states = outputs[0]
if hidden_states.shape[1] == 2:
pass
logits = self.lm_head(hidden_states)
loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype
)
loss = None
if labels is not None:
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)
if not return_dict:
output = (logits,) + outputs[1:]
@@ -541,7 +535,6 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
# Because the input can be padded, the absolute sequence length depends on the max position id.
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
@@ -635,6 +628,7 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
force_sp_output_gather: bool = True,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -750,14 +744,9 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
)
if shard_config.enable_sequence_parallelism:
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
hidden_states = gather_sp_output(hidden_states, shard_config)
# add hidden states from the last decoder layer
if output_hidden_states:
@@ -834,14 +823,15 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
force_sp_output_gather=False,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype
)
loss = None
if labels is not None:
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)
if not return_dict:
output = (logits,) + outputs[1:]