Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-26 12:14:02 +00:00
[hotfix] Add layer norm gradients all-reduce for sequence parallel (#4926)
* [hotfix] Add layer norm gradients all-reduce for sequence parallel (#4915)
  * Add layer norm gradients all-reduce for sequence parallel.
  * Skip the pipeline inference test.
* [hotfix] Fix policies for sequence parallel (#4922)
  * Add layer norm gradients all-reduce for sequence parallel.
  * Fix parameter passing when calling get_autopolicy.
* Hotfix: add gradient all-reduce for sequence parallel (#4927)
  * Add layer norm gradients all-reduce for sequence parallel.
  * Fix parameter passing when calling get_autopolicy.
  * Fix a bug caused by using the wrong variables.
* Fix policy initialization.
* Fix the BLOOM and ChatGLM policies.
* Polish the layer norm handling code.
* Fix the MoE module.
* Polish the class initialization code.

Co-authored-by: littsk <1214689160@qq.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
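Why the all-reduce is needed: with sequence parallelism the activations are split along the sequence dimension, so each rank computes only a partial gradient for the replicated layer norm weights and biases, and those partial gradients must be summed across the sequence-parallel group before the optimizer step. A minimal sketch of that step, assuming an already-initialized torch.distributed process group for the sequence-parallel ranks (the helper name and the way the parameters are collected are illustrative, not the actual ColossalAI API):

    import torch.distributed as dist

    def allreduce_layernorm_grads(layernorm_params, sp_group):
        # Sum the partial layer norm gradients computed on each
        # sequence-parallel rank's shard of the sequence.
        for param in layernorm_params:
            if param.grad is not None:
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=sp_group)

The diff below shows the Whisper policy side of the change.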
@@ -52,6 +52,11 @@ class WhisperPolicy(Policy):
 
         policy = {}
 
+        if self.shard_config.enable_fused_normalization:
+            norm_cls = col_nn.FusedLayerNorm
+        else:
+            norm_cls = col_nn.LayerNorm
+
         if self.shard_config.enable_sequence_parallelism:
             self.shard_config.enable_sequence_parallelism = False
             warnings.warn(
@@ -161,62 +166,61 @@ class WhisperPolicy(Policy):
         )
 
         # optimization configuration
-        if self.shard_config.enable_fused_normalization:
-            # Handle encoder layer
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="self_attn_layer_norm",
-                        target_module=col_nn.FusedLayerNorm,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="final_layer_norm",
-                        target_module=col_nn.FusedLayerNorm,
-                    ),
-                ],
-                policy=policy,
-                target_key=WhisperEncoderLayer,
-            )
+        # Handle encoder layer
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="self_attn_layer_norm",
+                    target_module=norm_cls,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="final_layer_norm",
+                    target_module=norm_cls,
+                ),
+            ],
+            policy=policy,
+            target_key=WhisperEncoderLayer,
+        )
 
-            # Handle decoder layer
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="self_attn_layer_norm",
-                        target_module=col_nn.FusedLayerNorm,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="final_layer_norm",
-                        target_module=col_nn.FusedLayerNorm,
-                    ),
-                ],
-                policy=policy,
-                target_key=WhisperDecoderLayer,
-            )
+        # Handle decoder layer
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="self_attn_layer_norm",
+                    target_module=norm_cls,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="final_layer_norm",
+                    target_module=norm_cls,
+                ),
+            ],
+            policy=policy,
+            target_key=WhisperDecoderLayer,
+        )
 
-            # handle encoder layer
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="layer_norm",
-                        target_module=col_nn.FusedLayerNorm,
-                    )
-                ],
-                policy=policy,
-                target_key=WhisperEncoder,
-            )
+        # handle encoder layer
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="layer_norm",
+                    target_module=norm_cls,
+                )
+            ],
+            policy=policy,
+            target_key=WhisperEncoder,
+        )
 
-            # handle decoder layer
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="layer_norm",
-                        target_module=col_nn.FusedLayerNorm,
-                    )
-                ],
-                policy=policy,
-                target_key=WhisperDecoder,
-            )
+        # handle decoder layer
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="layer_norm",
+                    target_module=norm_cls,
+                )
+            ],
+            policy=policy,
+            target_key=WhisperDecoder,
+        )
 
         # enable flash attention
         if self.shard_config.enable_flash_attention:
@@ -416,9 +420,6 @@ class WhisperPolicy(Policy):
 
 # WhisperModel
 class WhisperModelPolicy(WhisperPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
     def module_policy(self):
         from transformers import WhisperModel
 
@@ -441,9 +442,6 @@ class WhisperModelPolicy(WhisperPolicy):
 
 # WhisperForConditionalGeneration
 class WhisperForConditionalGenerationPolicy(WhisperPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
     def module_policy(self):
         from transformers import WhisperForConditionalGeneration
 
@@ -502,9 +500,6 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy):
 
 # WhisperForAudioClassification
 class WhisperForAudioClassificationPolicy(WhisperPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
     def preprocess(self):
         return self.model
 