[Feature] Split cross-entropy computation in SP (#5959)

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

* adapt chatglm, command-R, qwen

* debug

* add sp_mode to benchmark; fix varlen interface

* add comments

* q1 index only once

* remove events to simplify stream sync

* simplify forward/backward logic

* 2d ring forward passed

* 2d ring backward passed

* fixes

* fix ring attn loss

* 2D ring backward + llama passed

* merge

* update logger

* fix typo

* rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* remove typos

* fixes

* support GPT

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author:    Wenxuan Tan
Date:      2024-09-10 12:06:50 +08:00
Committer: GitHub
Parent:    b3db1058ec
Commit:    8fd25d6e09

25 changed files with 527 additions and 1173 deletions
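For context on the feature named in the commit title: under sequence parallelism (SP) each rank holds only the logits for its own chunk of the sequence, so the cross-entropy loss can be computed on that local chunk and then reduced across the SP group instead of gathering the full logits first. The snippet below is a minimal sketch of that idea, not the code added by this commit; the function name and the `sp_group` process-group argument are illustrative assumptions.

import torch
import torch.distributed as dist
import torch.nn.functional as F

def sp_split_cross_entropy(local_logits, local_labels, sp_group):
    # Illustrative sketch only (not this PR's implementation).
    # local_logits: (local_seq_len, vocab_size) shard held by this SP rank
    # local_labels: (local_seq_len,) matching label shard, -100 = ignored token
    # sp_group:     assumed torch.distributed group spanning the SP ranks
    loss_sum = F.cross_entropy(
        local_logits, local_labels, ignore_index=-100, reduction="sum"
    )
    num_tokens = (local_labels != -100).sum()

    # Sum the partial losses and valid-token counts over the sequence-parallel
    # group, then divide, so every rank ends up with the same global mean loss.
    # (Gradient handling through the collective is omitted in this sketch.)
    stats = torch.stack([loss_sum.float(), num_tokens.float()])
    dist.all_reduce(stats, op=dist.ReduceOp.SUM, group=sp_group)
    return stats[0] / stats[1]

The commit routes all shardformer models through one shared loss helper of this kind (the "unified cross entropy func" mentioned in the bullets above); the GPT-2 policy diff below shows the wiring on the policy side.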

colossalai/shardformer/policies/gpt2.py

@@ -6,14 +6,7 @@ from torch import Tensor, nn
 import colossalai.shardformer.layer as col_nn
 
-from ..modeling.gpt2 import (
-    GPT2PipelineForwards,
-    get_gpt2_flash_attention_forward,
-    get_gpt_model_forward_for_flash_attn,
-    get_jit_fused_gpt2_mlp_forward,
-    get_lm_forward_with_dist_cross_entropy,
-    gpt2_sequence_parallel_forward_fn,
-)
+from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, get_jit_fused_gpt2_mlp_forward
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = [
@@ -71,18 +64,10 @@ class GPT2Policy(Policy):
             warnings.warn(
                 f"For GPT2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather"
             )
-            sp_mode = "split_gather"
+            self.shard_config.sequence_parallelism_mode = sp_mode = "split_gather"
         overlap = self.shard_config.enable_sequence_overlap
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
         use_flash_attention = self.shard_config.enable_flash_attention
-        # todo: currently sp cannot be used with flashattention
-        if sp_mode in ["split_gather", "ring", "all_to_all"]:
-            if use_flash_attention:
-                warnings.warn(
-                    f"Sequence parallelism mode {sp_mode} cannot be used with FlashAttention, will disable FlashAttention automatically."
-                )
-                self.shard_config.enable_flash_attention = False
-                use_flash_attention = False
         if self.shard_config.enable_tensor_parallelism:
             assert (
                 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@@ -211,18 +196,16 @@ class GPT2Policy(Policy):
if use_flash_attention:
self.append_or_create_method_replacement(
description={
"forward": get_gpt2_flash_attention_forward(),
"forward": get_gpt2_flash_attention_forward(shard_config=self.shard_config),
},
policy=policy,
target_key=attn_cls,
)
if not self.shard_config.pipeline_stage_manager:
policy[GPT2Model].method_replacement = {
"forward": get_gpt_model_forward_for_flash_attn(self.shard_config)
}
if sp_mode is not None:
policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
if not self.shard_config.pipeline_stage_manager and self.shard_config.enable_sequence_parallelism:
policy[GPT2Model].method_replacement = {
"forward": partial(GPT2PipelineForwards.gpt2_model_forward, shard_config=self.shard_config)
}
return policy
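A note on the pattern used in the replacement above: `functools.partial` binds the shard configuration into the new forward so it can be attached in place of the model's original method without changing the call signature. A generic illustration of that mechanism follows (not ColossalAI's actual sharder code; the class and config here are made up for the example).

from functools import partial
from types import MethodType

import torch
import torch.nn as nn

def forward_with_config(self, x, shard_config=None):
    # A replacement forward that also needs access to the sharding config.
    # (Placeholder body; a real replacement would consult shard_config here.)
    return self.linear(x)

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

model = TinyModel()
dummy_shard_config = {"sequence_parallelism_mode": "split_gather"}  # made-up config object
# Bind the config, then attach the partial as a bound method on the instance,
# so callers keep using model(x) unchanged while the new forward runs.
model.forward = MethodType(partial(forward_with_config, shard_config=dummy_shard_config), model)
out = model(torch.randn(2, 4))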
@@ -328,40 +311,39 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
         from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
 
         module_policy = super().module_policy()
+        module_policy[GPT2LMHeadModel] = ModulePolicyDescription()
         if self.shard_config.enable_tensor_parallelism:
-            addon_module = {
-                GPT2LMHeadModel: ModulePolicyDescription(
-                    sub_module_replacement=[
-                        SubModuleReplacementDescription(
-                            suffix="lm_head",
-                            target_module=col_nn.VocabParallelLMHead1D,
-                            kwargs={
-                                "gather_output": False,
-                                "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
-                                "fp8_communication": self.shard_config.fp8_communication,
-                            },
-                        )
-                    ],
-                )
-            }
-            if self.shard_config.parallel_output:
-                addon_module[GPT2LMHeadModel].method_replacement = {
-                    "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
-                }
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="lm_head",
+                    target_module=col_nn.VocabParallelLMHead1D,
+                    kwargs={
+                        "gather_output": False,
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                    },
+                ),
+                policy=module_policy,
+                target_key=GPT2LMHeadModel,
+            )
         else:
-            addon_module = {
-                GPT2LMHeadModel: ModulePolicyDescription(
-                    sub_module_replacement=[
-                        SubModuleReplacementDescription(
-                            suffix="lm_head",
-                            target_module=col_nn.PaddingLMHead,
-                            kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
-                        )
-                    ]
-                )
-            }
-        module_policy.update(addon_module)
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="lm_head",
+                    target_module=col_nn.PaddingLMHead,
+                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                ),
+                policy=module_policy,
+                target_key=GPT2LMHeadModel,
+            )
+        if self.shard_config.parallel_output:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, shard_config=self.shard_config)
+                },
+                policy=module_policy,
+                target_key=GPT2LMHeadModel,
+            )
 
         if self.pipeline_stage_manager is not None:
             self.set_pipeline_forward(
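The upshot of the GPT2LMHeadModelPolicy change above is that, when parallel_output is enabled, the LM head keeps its output sharded and the loss is computed through the split cross-entropy path instead of gathering full-vocab logits on every rank. A hypothetical end-user sketch of enabling this through ColossalAI's HybridParallelPlugin follows; the argument names reflect the plugin's public interface in general, and the exact modes and defaults available at this commit may differ.

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Hypothetical configuration sketch, not taken from this commit's tests.
colossalai.launch_from_torch()  # initialize the default process groups (run under torchrun)

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="split_gather",  # assumed mode; others such as "all_to_all" exist
    parallel_output=True,  # keep logits sharded so the split cross-entropy is used
)
booster = Booster(plugin=plugin)
# model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=train_loader)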