[hotfix] Add layer norm gradients all-reduce for sequence parallel (#4926)

* [hotfix] Add layer norm gradients all-reduce for sequence parallel. (#4915)

* Add layer norm gradients all-reduce for sequence parallel.

* skip pipeline inference test

* [hotfix] fixing polices of sequence parallel (#4922)

* Add layer norm gradients all-reduce for sequence parallel.

* fix parameter passing when calling get_autopolicy

---------

Co-authored-by: littsk <1214689160@qq.com>

* Hotfix/add grad all reduce for sequence parallel (#4927)

* Add layer norm gradients all-reduce for sequence parallel.


* fix parameter passing when calling get_autopolicy

* fix bug using wrong variables

---------

Co-authored-by: littsk <1214689160@qq.com>

* fix policy initialization

* fix bloom and chatglm policices

* polish code of handling layernorm

* fix moe module

* polish code of class initializing

---------

Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
This commit is contained in:
littsk
2023-11-03 13:32:43 +08:00
committed by GitHub
parent d99b2c961a
commit 1a3315e336
30 changed files with 1120 additions and 552 deletions

View File

@@ -5,7 +5,7 @@ from typing import Callable, Dict, List
import torch.nn as nn
from torch import Tensor, nn
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from .._utils import getattr_
from ..modeling.jit import get_jit_fused_dropout_add_func
@@ -42,6 +42,12 @@ class OPTPolicy(Policy):
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = FusedLayerNorm
else:
norm_cls = LayerNorm
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("OPT dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
@@ -94,26 +100,25 @@ class OPTPolicy(Policy):
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
),
policy=policy,
target_key=OPTDecoder,
)
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="self_attn_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
),
policy=policy,
target_key=OPTDecoder,
)
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
),
SubModuleReplacementDescription(
suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
),
],
policy=policy,
target_key=OPTDecoderLayer,
)
SubModuleReplacementDescription(
suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
),
],
policy=policy,
target_key=OPTDecoderLayer,
)
# use flash attention
if self.shard_config.enable_flash_attention:
@@ -183,9 +188,6 @@ class OPTPolicy(Policy):
class OPTModelPolicy(OPTPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.opt.modeling_opt import OPTModel
@@ -253,9 +255,6 @@ class OPTForCausalLMPolicy(OPTPolicy):
class OPTForSequenceClassificationPolicy(OPTPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.opt.modeling_opt import OPTForSequenceClassification
@@ -281,9 +280,6 @@ class OPTForSequenceClassificationPolicy(OPTPolicy):
class OPTForQuestionAnsweringPolicy(OPTPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.opt.modeling_opt import OPTForQuestionAnswering