Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-27 12:43:02 +00:00)
[hotfix] Add layer norm gradients all-reduce for sequence parallel (#4926)
* [hotfix] Add layer norm gradients all-reduce for sequence parallel. (#4915)

  * Add layer norm gradients all-reduce for sequence parallel.
  * skip pipeline inference test

* [hotfix] fixing policies of sequence parallel (#4922)

  * Add layer norm gradients all-reduce for sequence parallel.
  * fix parameter passing when calling get_autopolicy

  Co-authored-by: littsk <1214689160@qq.com>

* Hotfix/add grad all reduce for sequence parallel (#4927)

  * Add layer norm gradients all-reduce for sequence parallel.
  * fix parameter passing when calling get_autopolicy
  * fix bug using wrong variables

  Co-authored-by: littsk <1214689160@qq.com>

* fix policy initialization
* fix bloom and chatglm policies
* polish code of handling layernorm
* fix moe module
* polish code of class initializing

Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
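For context on what "layer norm gradients all-reduce" means here: under sequence parallelism the activations are sharded along the sequence dimension, so each rank derives the layer norm weight/bias gradients from only its own token shard. Those partial gradients must be summed across the sequence-parallel group before the optimizer step. A minimal sketch of that synchronization (illustrative only; the function name and the `sp_group` handle are assumptions, not ColossalAI's API):

import torch
import torch.distributed as dist


def allreduce_layernorm_grads(model: torch.nn.Module, sp_group: dist.ProcessGroup) -> None:
    """All-reduce the partial gradients of every LayerNorm parameter.

    Hypothetical helper: each rank's grad covers only its sequence shard,
    so summing over the sequence-parallel group recovers the full gradient.
    """
    for module in model.modules():
        if isinstance(module, torch.nn.LayerNorm):
            for param in module.parameters():
                if param.grad is not None:
                    dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=sp_group)

Such a helper would run once per iteration, after loss.backward() and before optimizer.step().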
@@ -42,6 +42,10 @@ class BloomPolicy(Policy):
         policy = {}
 
+        if self.shard_config.enable_fused_normalization:
+            norm_cls = col_nn.FusedLayerNorm
+        else:
+            norm_cls = col_nn.LayerNorm
+
         use_sequence_parallel = self.shard_config.enable_sequence_parallelism
         overlap = self.shard_config.enable_sequence_overlap
         if self.shard_config.enable_tensor_parallelism:
@@ -97,38 +101,39 @@ class BloomPolicy(Policy):
             )
 
         # optimization configuration
-        if self.shard_config.enable_fused_normalization:
-            # handle bloom model
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="ln_f",
-                        target_module=col_nn.FusedLayerNorm,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="word_embeddings_layernorm",
-                        target_module=col_nn.FusedLayerNorm,
-                    ),
-                ],
-                policy=policy,
-                target_key=BloomModel,
-            )
+        # handle bloom model
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="ln_f",
+                    target_module=norm_cls,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="word_embeddings_layernorm",
+                    target_module=norm_cls,
+                ),
+            ],
+            policy=policy,
+            target_key=BloomModel,
+        )
 
-            # handle bloom block
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="input_layernorm",
-                        target_module=col_nn.FusedLayerNorm,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="post_attention_layernorm",
-                        target_module=col_nn.FusedLayerNorm,
-                    ),
-                ],
-                policy=policy,
-                target_key=BloomBlock,
-            )
+        # handle bloom block
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="input_layernorm",
+                    target_module=norm_cls,
+                    kwargs={"sp_partial_derived": use_sequence_parallel},
+                ),
+                SubModuleReplacementDescription(
+                    suffix="post_attention_layernorm",
+                    target_module=norm_cls,
+                    kwargs={"sp_partial_derived": use_sequence_parallel},
+                ),
+            ],
+            policy=policy,
+            target_key=BloomBlock,
+        )
 
         if use_sequence_parallel:
            self.append_or_create_method_replacement(
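The new sp_partial_derived kwarg marks these layer norms as producing partial gradients under sequence parallelism, so the replaced module knows to synchronize them. A hedged sketch of how a norm class could honor such a flag with a gradient hook (the class name, the sp_group parameter, and the hook-based design are assumptions; ColossalAI's actual wrapper may defer the all-reduce instead):

import torch
import torch.distributed as dist


class SPAwareLayerNorm(torch.nn.LayerNorm):
    """Hypothetical LayerNorm whose weight/bias gradients are summed over
    the sequence-parallel group when they are only partially derived."""

    def __init__(self, normalized_shape, sp_partial_derived: bool = False,
                 sp_group: dist.ProcessGroup = None, **kwargs):
        super().__init__(normalized_shape, **kwargs)
        if sp_partial_derived:
            def _sync(grad: torch.Tensor) -> torch.Tensor:
                # Each rank holds the gradient from its sequence shard only;
                # summing across the group yields the full-sequence gradient.
                dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=sp_group)
                return grad

            for p in self.parameters():
                p.register_hook(_sync)

A hook like this fires eagerly during backward; an alternative design is to merely tag the parameters and batch the all-reduces after backward completes, which overlaps communication with compute better.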
@@ -225,9 +230,6 @@ class BloomPolicy(Policy):
 
 
 class BloomModelPolicy(BloomPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
     def module_policy(self):
         policy = super().module_policy()
         from transformers.models.bloom.modeling_bloom import BloomModel