[hotfix] Add layer norm gradients all-reduce for sequence parallel (#4926)
* [hotfix] Add layer norm gradients all-reduce for sequence parallel. (#4915)
* Add layer norm gradients all-reduce for sequence parallel.
* skip pipeline inference test
* [hotfix] fix policies of sequence parallel (#4922)
* Add layer norm gradients all-reduce for sequence parallel.
* fix parameter passing when calling get_autopolicy
* Hotfix/add grad all reduce for sequence parallel (#4927)
* Add layer norm gradients all-reduce for sequence parallel.
* fix parameter passing when calling get_autopolicy
* fix bug using wrong variables
* fix policy initialization
* fix bloom and chatglm policies
* polish code of handling layernorm
* fix moe module
* polish code of class initialization

Co-authored-by: littsk <1214689160@qq.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
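The essence of the fix: under sequence parallelism each tensor-parallel rank back-propagates through a different shard of the sequence, while layer-norm weights and biases are replicated across the group, so their partial gradients must be summed before the optimizer step. A minimal sketch of that pattern, assuming an initialized torch.distributed process group (the function name and traversal are illustrative, not ColossalAI's actual implementation):

import torch.distributed as dist
from torch import nn

def allreduce_layernorm_grads(model: nn.Module, tp_group: dist.ProcessGroup) -> None:
    # Layer-norm parameters are replicated across the tensor-parallel group,
    # but each rank sees only its own sequence shard, so the locally computed
    # gradients are partial sums that must be reduced across the group.
    for module in model.modules():
        if isinstance(module, nn.LayerNorm):
            for param in module.parameters():
                if param.grad is not None:
                    dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=tp_group)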
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -5,7 +5,7 @@ from typing import Callable, Dict, List
 
-import torch.nn as nn
+from torch import Tensor, nn
 
-from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
+from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
 
 from .._utils import getattr_
 from ..modeling.jit import get_jit_fused_dropout_add_func
@@ -42,6 +42,12 @@ class OPTPolicy(Policy):
         from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
 
         policy = {}
 
+        if self.shard_config.enable_fused_normalization:
+            norm_cls = FusedLayerNorm
+        else:
+            norm_cls = LayerNorm
+
         if self.shard_config.enable_sequence_parallelism:
             self.shard_config.enable_sequence_parallelism = False
             warnings.warn("OPT dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
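For reference, both flags read above live on the ShardConfig the policy carries; a hedged usage sketch (the two field names are taken from the diff, but other constructor arguments and defaults may differ across ColossalAI versions):

from colossalai.shardformer import ShardConfig

shard_config = ShardConfig(
    enable_fused_normalization=True,    # policy then selects FusedLayerNorm
    enable_sequence_parallelism=False,  # OPT would reset this to False and warn anyway
)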
@@ -94,26 +100,25 @@ class OPTPolicy(Policy):
         )
 
         # optimization configuration
-        if self.shard_config.enable_fused_normalization:
-            self.append_or_create_submodule_replacement(
-                description=SubModuleReplacementDescription(
-                    suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
-                ),
-                policy=policy,
-                target_key=OPTDecoder,
-            )
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
-                    ),
-                ],
-                policy=policy,
-                target_key=OPTDecoderLayer,
-            )
+        self.append_or_create_submodule_replacement(
+            description=SubModuleReplacementDescription(
+                suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
+            ),
+            policy=policy,
+            target_key=OPTDecoder,
+        )
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="self_attn_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
+                ),
+                SubModuleReplacementDescription(
+                    suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
+                ),
+            ],
+            policy=policy,
+            target_key=OPTDecoderLayer,
+        )
 
         # use flash attention
         if self.shard_config.enable_flash_attention:
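In spirit, each SubModuleReplacementDescription above pairs an attribute suffix (such as final_layer_norm) with the class that should replace it on the target module. A simplified stand-in for that mechanism, using a hypothetical helper rather than shardformer's real replacement engine:

from torch import nn

def replace_submodule_by_suffix(root: nn.Module, suffix: str, target_cls: type, make_replacement) -> None:
    # Collect matches first so the tree is not mutated while iterating it.
    matches = [
        name
        for name, child in root.named_modules()
        if name.endswith(suffix) and not isinstance(child, target_cls)
    ]
    for name in matches:
        parent_path, _, attr = name.rpartition(".")
        parent = root.get_submodule(parent_path) if parent_path else root
        setattr(parent, attr, make_replacement(root.get_submodule(name)))

# e.g. replace_submodule_by_suffix(decoder, "final_layer_norm", FusedLayerNorm,
#      FusedLayerNorm.from_native_module) -- from_native_module is the usual
#      shardformer factory method, but verify the name against your version.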
@@ -183,9 +188,6 @@ class OPTPolicy(Policy):
 
 
 class OPTModelPolicy(OPTPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
     def module_policy(self):
         from transformers.models.opt.modeling_opt import OPTModel
 
@@ -253,9 +255,6 @@ class OPTForCausalLMPolicy(OPTPolicy):
 
 
 class OPTForSequenceClassificationPolicy(OPTPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
     def module_policy(self):
         from transformers.models.opt.modeling_opt import OPTForSequenceClassification
 
@@ -281,9 +280,6 @@ class OPTForSequenceClassificationPolicy(OPTPolicy):
 
 
 class OPTForQuestionAnsweringPolicy(OPTPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
     def module_policy(self):
         from transformers.models.opt.modeling_opt import OPTForQuestionAnswering
 
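The three hunks above all delete the same boilerplate: an __init__ that does nothing but call super().__init__(). Removing it is behavior-preserving, because a Python class that defines no __init__ inherits its parent's automatically:

class Base:
    def __init__(self) -> None:
        self.initialized = True

class Child(Base):
    # No __init__ here: Base.__init__ still runs when Child() is constructed.
    pass

assert Child().initialized  # passes; the deleted methods added nothing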