Mirror of https://github.com/hpcaitech/ColossalAI.git
[Feature] Split cross-entropy computation in SP (#5959)
* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* fix typos and remove misc files
* update softmax_lse shape by new interface
* change tester name
* remove buffer clone; support packed seq layout
* add varlen tests
* all tests passed
* add dkv_group; fix mask
* remove debug statements
* adapt chatglm, command-R, qwen
* debug
* add sp_mode to benchmark; fix varlen interface
* add comments
* q1 index only once
* remove events to simplify stream sync
* simplify forward/backward logic
* 2d ring forward passed
* 2d ring backward passed
* fixes
* fix ring attn loss
* 2D ring backward + llama passed
* merge
* update logger
* rebase
* remove typos
* support GPT

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
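The headline change is that the cross-entropy loss is now computed in a form that works when the sequence itself is sharded across sequence-parallel (SP) ranks, with one unified loss path shared by the shardformer models. The sketch below illustrates only the general idea and is not ColossalAI's actual implementation; the function name sp_split_cross_entropy and the sp_group argument are invented for the example. Each rank computes a partial loss over its local sequence chunk, and the partial sums and token counts are all-reduced so every rank sees the same global mean loss.

import torch
import torch.distributed as dist
import torch.nn.functional as F

def sp_split_cross_entropy(local_logits, local_labels, sp_group, ignore_index=-100):
    # local_logits: (local_seq_len, vocab_size) for this rank's chunk of the sequence
    # local_labels: (local_seq_len,) labels split (and shifted) the same way
    loss_sum = F.cross_entropy(
        local_logits, local_labels, ignore_index=ignore_index, reduction="sum"
    )
    num_tokens = (local_labels != ignore_index).sum().to(loss_sum.dtype)

    # Combine the partial sums across the sequence-parallel group so the result
    # equals the loss computed over the full, unsplit sequence.
    stats = torch.stack([loss_sum, num_tokens])
    dist.all_reduce(stats, op=dist.ReduceOp.SUM, group=sp_group)
    return stats[0] / stats[1].clamp(min=1)

The same reduction applies whichever sequence-parallel layout produced the chunks; only how local_labels are derived differs. The diff below applies the change to the GPT-2 shardformer policy.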
@@ -6,14 +6,7 @@ from torch import Tensor, nn
 
 import colossalai.shardformer.layer as col_nn
 
-from ..modeling.gpt2 import (
-    GPT2PipelineForwards,
-    get_gpt2_flash_attention_forward,
-    get_gpt_model_forward_for_flash_attn,
-    get_jit_fused_gpt2_mlp_forward,
-    get_lm_forward_with_dist_cross_entropy,
-    gpt2_sequence_parallel_forward_fn,
-)
+from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, get_jit_fused_gpt2_mlp_forward
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = [
@@ -71,18 +64,10 @@ class GPT2Policy(Policy):
             warnings.warn(
                 f"For GPT2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather"
             )
-            sp_mode = "split_gather"
+            self.shard_config.sequence_parallelism_mode = sp_mode = "split_gather"
         overlap = self.shard_config.enable_sequence_overlap
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
         use_flash_attention = self.shard_config.enable_flash_attention
-        # todo: currently sp cannot be used with flashattention
-        if sp_mode in ["split_gather", "ring", "all_to_all"]:
-            if use_flash_attention:
-                warnings.warn(
-                    f"Sequence parallelism mode {sp_mode} cannot be used with FlashAttention, will disable FlashAttention automatically."
-                )
-                self.shard_config.enable_flash_attention = False
-                use_flash_attention = False
         if self.shard_config.enable_tensor_parallelism:
             assert (
                 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
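One detail in the hunk above: when an unsupported mode falls back to split_gather, the new code writes the fallback into shard_config.sequence_parallelism_mode instead of only rebinding the local sp_mode variable, so later consumers of the shard config see the mode the policy actually used. A hypothetical, self-contained illustration (ToyShardConfig is invented for the sketch):

from dataclasses import dataclass

@dataclass
class ToyShardConfig:
    sequence_parallelism_mode: str = "ring"

shard_config = ToyShardConfig()

# Rebinding only the local variable leaves the config stale:
sp_mode = "split_gather"
assert shard_config.sequence_parallelism_mode == "ring"  # downstream code would still see "ring"

# Assigning through the config keeps both names in sync, as the new code does:
shard_config.sequence_parallelism_mode = sp_mode = "split_gather"
assert shard_config.sequence_parallelism_mode == sp_mode == "split_gather"

The block that force-disabled FlashAttention under sequence parallelism is removed, consistent with the attention forward now receiving shard_config (see the next hunk).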
@@ -211,18 +196,16 @@ class GPT2Policy(Policy):
         if use_flash_attention:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_gpt2_flash_attention_forward(),
+                    "forward": get_gpt2_flash_attention_forward(shard_config=self.shard_config),
                 },
                 policy=policy,
                 target_key=attn_cls,
             )
-            if not self.shard_config.pipeline_stage_manager:
-                policy[GPT2Model].method_replacement = {
-                    "forward": get_gpt_model_forward_for_flash_attn(self.shard_config)
-                }
 
-        if sp_mode is not None:
-            policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
+        if not self.shard_config.pipeline_stage_manager and self.shard_config.enable_sequence_parallelism:
+            policy[GPT2Model].method_replacement = {
+                "forward": partial(GPT2PipelineForwards.gpt2_model_forward, shard_config=self.shard_config)
+            }
 
         return policy
 
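The replacement forward above is produced with functools.partial, which pre-binds shard_config while leaving the usual forward call signature unchanged. Below is a hypothetical, minimal illustration of the same pattern; ToyBlock and toy_forward are invented, and the real code binds GPT2PipelineForwards.gpt2_model_forward and installs it through the policy's method_replacement machinery rather than by direct assignment.

from functools import partial
from types import MethodType

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

def toy_forward(self, x, shard_config=None):
    # A replacement forward that can consult the pre-bound shard_config at runtime.
    out = self.proj(x)
    return out

block = ToyBlock()
# Bind the config first, then attach the bound callable as the instance's forward,
# mirroring method_replacement = {"forward": partial(..., shard_config=...)}.
block.forward = MethodType(partial(toy_forward, shard_config={"mode": "split_gather"}), block)
print(block(torch.randn(2, 4)).shape)  # torch.Size([2, 4])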
@@ -328,40 +311,39 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
         from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
 
         module_policy = super().module_policy()
 
+        module_policy[GPT2LMHeadModel] = ModulePolicyDescription()
         if self.shard_config.enable_tensor_parallelism:
-            addon_module = {
-                GPT2LMHeadModel: ModulePolicyDescription(
-                    sub_module_replacement=[
-                        SubModuleReplacementDescription(
-                            suffix="lm_head",
-                            target_module=col_nn.VocabParallelLMHead1D,
-                            kwargs={
-                                "gather_output": False,
-                                "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
-                                "fp8_communication": self.shard_config.fp8_communication,
-                            },
-                        )
-                    ],
-                )
-            }
-            if self.shard_config.parallel_output:
-                addon_module[GPT2LMHeadModel].method_replacement = {
-                    "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
-                }
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="lm_head",
+                    target_module=col_nn.VocabParallelLMHead1D,
+                    kwargs={
+                        "gather_output": False,
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                    },
+                ),
+                policy=module_policy,
+                target_key=GPT2LMHeadModel,
+            )
         else:
-            addon_module = {
-                GPT2LMHeadModel: ModulePolicyDescription(
-                    sub_module_replacement=[
-                        SubModuleReplacementDescription(
-                            suffix="lm_head",
-                            target_module=col_nn.PaddingLMHead,
-                            kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
-                        )
-                    ]
-                )
-            }
-            module_policy.update(addon_module)
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="lm_head",
+                    target_module=col_nn.PaddingLMHead,
+                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                ),
+                policy=module_policy,
+                target_key=GPT2LMHeadModel,
+            )
 
+        if self.shard_config.parallel_output:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, shard_config=self.shard_config)
+                },
+                policy=module_policy,
+                target_key=GPT2LMHeadModel,
+            )
+
         if self.pipeline_stage_manager is not None:
             self.set_pipeline_forward(
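With parallel_output enabled, the VocabParallelLMHead1D above keeps gather_output=False, so each tensor-parallel rank holds logits only for its own vocabulary shard and the loss must be computed without gathering the full vocabulary. Below is a hedged sketch of the standard way to do that (a Megatron-style reduction), not ColossalAI's own implementation; vocab_parallel_cross_entropy and its arguments are invented for the example.

import torch
import torch.distributed as dist

def vocab_parallel_cross_entropy(local_logits, targets, vocab_start, tp_group):
    # local_logits: (num_tokens, local_vocab) shard held by this TP rank
    # targets: (num_tokens,) global vocabulary ids
    # vocab_start: first global vocabulary id owned by this rank
    local_vocab = local_logits.size(-1)

    # Global max over the vocabulary for numerical stability.
    logits_max = local_logits.max(dim=-1).values
    dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=tp_group)
    shifted = local_logits - logits_max.unsqueeze(-1)

    # Global log-sum-exp via an all-reduced sum of exponentials.
    sum_exp = shifted.exp().sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=tp_group)

    # Each rank contributes the target logit only if it owns that vocab id.
    local_targets = targets - vocab_start
    in_shard = (local_targets >= 0) & (local_targets < local_vocab)
    safe_targets = local_targets.clamp(0, local_vocab - 1)
    target_logit = shifted.gather(-1, safe_targets.unsqueeze(-1)).squeeze(-1)
    target_logit = torch.where(in_shard, target_logit, torch.zeros_like(target_logit))
    dist.all_reduce(target_logit, op=dist.ReduceOp.SUM, group=tp_group)

    # Per-token loss = log(sum_exp) - target_logit, averaged over tokens.
    return (sum_exp.log() - target_logit).mean()

Combining this vocabulary-dimension reduction with the sequence-dimension reduction sketched earlier is what keeps the loss correct when tensor parallelism and sequence parallelism are active at the same time.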