Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-27 12:43:02 +00:00)
[hotfix] Add layer norm gradients all-reduce for sequence parallel (#4926)
* [hotfix] Add layer norm gradients all-reduce for sequence parallel. (#4915)
  * Add layer norm gradients all-reduce for sequence parallel.
  * Skip pipeline inference test.
* [hotfix] Fix policies of sequence parallel (#4922)
  * Add layer norm gradients all-reduce for sequence parallel.
  * Fix parameter passing when calling get_autopolicy.
* Hotfix/add grad all-reduce for sequence parallel (#4927)
  * Add layer norm gradients all-reduce for sequence parallel.
  * Fix parameter passing when calling get_autopolicy.
  * Fix bug using wrong variables.
* Fix policy initialization.
* Fix BLOOM and ChatGLM policies.
* Polish code of handling layernorm.
* Fix MoE module.
* Polish code of class initializing.

Co-authored-by: littsk <1214689160@qq.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
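The motivating issue: with sequence parallelism, each rank processes only a slice of the sequence, so the gradients of layer norm parameters (which are replicated on every rank) come out as partial sums and must be summed across the parallel group before the optimizer step. Below is a minimal sketch of that idea, assuming a PyTorch process group `tp_group`; the helper name is illustrative, not ColossalAI's actual API:

```python
import torch.distributed as dist

def allreduce_layernorm_grads(norm_modules, tp_group):
    """Sum partial LayerNorm/RMSNorm parameter gradients across the
    tensor/sequence parallel group after backward (illustrative sketch)."""
    for module in norm_modules:
        for param in module.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=tp_group)
```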
```diff
@@ -11,6 +11,7 @@ from colossalai.shardformer.layer import (
     FusedRMSNorm,
     Linear1D_Col,
     Linear1D_Row,
+    RMSNorm,
     VocabParallelEmbedding1D,
 )
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription
```
```diff
@@ -58,6 +59,11 @@ class T5BasePolicy(Policy):
 
         policy = {}
 
+        if self.shard_config.enable_fused_normalization:
+            norm_cls = FusedRMSNorm
+        else:
+            norm_cls = RMSNorm
+
         if self.shard_config.enable_sequence_parallelism:
             self.shard_config.enable_sequence_parallelism = False
             warnings.warn("T5 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
```
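Two things happen in this hunk: the policy resolves a `norm_cls` once (the fused kernel if enabled, plain `RMSNorm` otherwise), and it degrades gracefully when sequence parallelism is requested for T5, which does not support it yet. A self-contained sketch of the same guard pattern; the `ShardConfig` stand-in below is illustrative, not ColossalAI's real class:

```python
import warnings
from dataclasses import dataclass

@dataclass
class ShardConfig:  # stand-in for ColossalAI's ShardConfig
    enable_fused_normalization: bool = False
    enable_sequence_parallelism: bool = False

def sanitize(config: ShardConfig) -> ShardConfig:
    if config.enable_sequence_parallelism:
        # Clear the flag up front instead of failing deep inside sharding.
        config.enable_sequence_parallelism = False
        warnings.warn("T5 doesn't support sequence parallelism now, "
                      "will ignore the sequence parallelism flag.")
    return config
```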
```diff
@@ -169,38 +175,37 @@ class T5BasePolicy(Policy):
         )
 
         # optimization configuration
-        if self.shard_config.enable_fused_normalization:
-            self.append_or_create_submodule_replacement(
-                description=SubModuleReplacementDescription(
-                    suffix="layer_norm",
-                    target_module=FusedRMSNorm,
-                ),
-                policy=policy,
-                target_key=T5LayerFF,
-            )
-            self.append_or_create_submodule_replacement(
-                description=SubModuleReplacementDescription(
-                    suffix="layer_norm",
-                    target_module=FusedRMSNorm,
-                ),
-                policy=policy,
-                target_key=T5LayerFF,
-            )
-            self.append_or_create_submodule_replacement(
-                description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm),
-                policy=policy,
-                target_key=T5LayerSelfAttention,
-            )
-            self.append_or_create_submodule_replacement(
-                description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm),
-                policy=policy,
-                target_key=T5LayerCrossAttention,
-            )
-            self.append_or_create_submodule_replacement(
-                description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm),
-                policy=policy,
-                target_key=T5Stack,
-            )
+        self.append_or_create_submodule_replacement(
+            description=SubModuleReplacementDescription(
+                suffix="layer_norm",
+                target_module=norm_cls,
+            ),
+            policy=policy,
+            target_key=T5LayerFF,
+        )
+        self.append_or_create_submodule_replacement(
+            description=SubModuleReplacementDescription(
+                suffix="layer_norm",
+                target_module=norm_cls,
+            ),
+            policy=policy,
+            target_key=T5LayerFF,
+        )
+        self.append_or_create_submodule_replacement(
+            description=SubModuleReplacementDescription(suffix="layer_norm", target_module=norm_cls),
+            policy=policy,
+            target_key=T5LayerSelfAttention,
+        )
+        self.append_or_create_submodule_replacement(
+            description=SubModuleReplacementDescription(suffix="layer_norm", target_module=norm_cls),
+            policy=policy,
+            target_key=T5LayerCrossAttention,
+        )
+        self.append_or_create_submodule_replacement(
+            description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=norm_cls),
+            policy=policy,
+            target_key=T5Stack,
+        )
 
         # use flash attention
         if self.shard_config.enable_flash_attention:
```
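With `norm_cls` resolved up front, the five replacement descriptions no longer need to live under the `enable_fused_normalization` branch; the same descriptions now serve both code paths. Conceptually, each `SubModuleReplacementDescription` asks the sharder to swap the attribute named `suffix` on every module of type `target_key` for the given `target_module`. A simplified sketch of that mechanism, assuming the replacement class exposes a `from_native_module` constructor as ColossalAI's shardformer layers do; the real `ModulePolicyDescription` machinery handles much more:

```python
def apply_submodule_replacement(model, target_key, suffix, target_module):
    """Illustrative only: replace `suffix` on every `target_key` module."""
    for module in model.modules():
        if isinstance(module, target_key):
            native = getattr(module, suffix)
            setattr(module, suffix, target_module.from_native_module(native))
```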
```diff
@@ -363,9 +368,6 @@ class T5BasePolicy(Policy):
 
 
 class T5ModelPolicy(T5BasePolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
     def module_policy(self):
         from transformers import T5Model
 
```
```diff
@@ -402,9 +404,6 @@ class T5ModelPolicy(T5BasePolicy):
 
 
 class T5ForConditionalGenerationPolicy(T5BasePolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
     def module_policy(self):
         from transformers import T5ForConditionalGeneration
 
```
```diff
@@ -466,9 +465,6 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
 
 
 class T5EncoderPolicy(T5BasePolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
     def module_policy(self):
         from transformers import T5EncoderModel
 
```
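The last three hunks delete `__init__` overrides that only forwarded to `super().__init__()` (the "polish code of class initializing" item in the commit message). A Python subclass that defines no `__init__` inherits the base constructor automatically, so the overrides were dead weight:

```python
class Base:
    def __init__(self) -> None:
        self.initialized = True

class Child(Base):
    # No __init__ needed: Base.__init__ is inherited and runs as-is.
    pass

assert Child().initialized
```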