[hotfix] Add layer norm gradients all-reduce for sequence parallel (#4926)

* [hotfix] Add layer norm gradients all-reduce for sequence parallel. (#4915)

* Add layer norm gradients all-reduce for sequence parallel.

* skip pipeline inference test

* [hotfix] fixing polices of sequence parallel (#4922)

* Add layer norm gradients all-reduce for sequence parallel.

* fix parameter passing when calling get_autopolicy

---------

Co-authored-by: littsk <1214689160@qq.com>

* Hotfix/add grad all reduce for sequence parallel (#4927)

* Add layer norm gradients all-reduce for sequence parallel.


* fix parameter passing when calling get_autopolicy

* fix bug using wrong variables

---------

Co-authored-by: littsk <1214689160@qq.com>

* fix policy initialization

* fix bloom and chatglm policices

* polish code of handling layernorm

* fix moe module

* polish code of class initializing

---------

Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
This commit is contained in:
littsk
2023-11-03 13:32:43 +08:00
committed by GitHub
parent d99b2c961a
commit 1a3315e336
30 changed files with 1120 additions and 552 deletions

View File

@@ -42,6 +42,10 @@ class BloomPolicy(Policy):
policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
norm_cls = col_nn.LayerNorm
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
@@ -97,38 +101,39 @@ class BloomPolicy(Policy):
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
# handle bloom model
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="ln_f",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="word_embeddings_layernorm",
target_module=col_nn.FusedLayerNorm,
),
],
policy=policy,
target_key=BloomModel,
)
# handle bloom model
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="ln_f",
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="word_embeddings_layernorm",
target_module=norm_cls,
),
],
policy=policy,
target_key=BloomModel,
)
# handle bloom block
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=col_nn.FusedLayerNorm,
),
],
policy=policy,
target_key=BloomBlock,
)
# handle bloom block
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel},
),
],
policy=policy,
target_key=BloomBlock,
)
if use_sequence_parallel:
self.append_or_create_method_replacement(
@@ -225,9 +230,6 @@ class BloomPolicy(Policy):
class BloomModelPolicy(BloomPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
from transformers.models.bloom.modeling_bloom import BloomModel