[shardformer] Add layernorm (#4072)

* add layernorm to bert

* add layernorm test

* add layernorm test with load state dict

* add use_mixedfusedLN in shard config

* refactor policy to support fused_layernorm
Author:    FoolPlayer
Date:      2023-06-23 18:00:22 +08:00
Committed: Frank Lee
Parent:    70c58cfd4f
Commit:    92f6791095

7 changed files with 252 additions and 17 deletions


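In short: `BertPolicy.module_policy()` now builds a named `base_policy` dict, and a new flag on the shard config gates swapping each `nn.LayerNorm` for the fused `col_nn.LayerNorm1D`. A minimal sketch of that flag, assuming a dataclass-style config (only the attribute name `fused_layernorm` is taken from this diff; the class below is a stand-in, not the real ShardConfig):

```python
# Stand-in for the shard config that the policy reads below; only the
# `fused_layernorm` attribute name comes from this diff, the rest is assumed.
from dataclasses import dataclass

@dataclass
class ShardConfig:
    # When True, BertPolicy appends LayerNorm1D replacements for
    # attention.output.LayerNorm, output.LayerNorm, and the embedding LayerNorm.
    fused_layernorm: bool = False
```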
@@ -1,8 +1,14 @@
 import torch.nn as nn
-from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead
+from transformers.models.bert.modeling_bert import (
+    BertEmbeddings,
+    BertForMultipleChoice,
+    BertForSequenceClassification,
+    BertForTokenClassification,
+    BertLayer,
+    BertLMPredictionHead,
+)
 
 import colossalai.shardformer.layer as col_nn
-from colossalai.shardformer.layer.dropout import Dropout1D
 
 from .._utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -24,7 +30,7 @@ class BertPolicy(Policy):
         return self.model
 
     def module_policy(self):
-        return {
+        base_policy = {
             BertLayer:
                 ModulePolicyDescription(
                     attribute_replacement={
@@ -53,10 +59,18 @@ class BertPolicy(Policy):
                         suffix="attention.self.value",
                         target_module=col_nn.Linear1D_Col,
                     ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.self.dropout",
+                        target_module=col_nn.Dropout1D,
+                    ),
                     SubModuleReplacementDescription(
                         suffix="attention.output.dense",
                         target_module=col_nn.Linear1D_Row,
                     ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.output.dropout",
+                        target_module=col_nn.Dropout1D,
+                    ),
                     SubModuleReplacementDescription(
                         suffix="intermediate.dense",
                         target_module=col_nn.Linear1D_Col,
@@ -66,12 +80,8 @@ class BertPolicy(Policy):
                         target_module=col_nn.Linear1D_Row,
                     ),
                     SubModuleReplacementDescription(
-                        suffix="attention.self.dropout",
-                        target_module=Dropout1D,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="attention.output.dropout",
-                        target_module=Dropout1D,
+                        suffix="output.dropout",
+                        target_module=col_nn.Dropout1D,
                     )
                 ]),
             BertEmbeddings:
@@ -81,10 +91,32 @@ class BertPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="word_embeddings",
                         target_module=col_nn.VocabParallelEmbedding1D,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="dropout",
+                        target_module=col_nn.Dropout1D,
                     )
                 ])
         }
 
+        if self.shard_config.fused_layernorm:
+            base_policy[BertLayer].sub_module_replacement.append(
+                SubModuleReplacementDescription(
+                    suffix="attention.output.LayerNorm",
+                    target_module=col_nn.LayerNorm1D,
+                ))
+            base_policy[BertLayer].sub_module_replacement.append(
+                SubModuleReplacementDescription(
+                    suffix="output.LayerNorm",
+                    target_module=col_nn.LayerNorm1D,
+                ))
+            base_policy[BertEmbeddings].sub_module_replacement.append(
+                SubModuleReplacementDescription(
+                    suffix="LayerNorm",
+                    target_module=col_nn.LayerNorm1D,
+                ),)
+        return base_policy
+
     def new_model_class(self):
         # do nothing
         return self.model
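The pattern above — build `base_policy` unconditionally, then append extra `SubModuleReplacementDescription` entries in place when the flag is set — is easy to exercise in isolation. A self-contained sketch with stand-in types (the real descriptions live in `.basepolicy` and also carry `attribute_replacement` and `param_replacement`):

```python
# Self-contained mock of the conditional-append pattern used above.
from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass
class SubModuleReplacementDescription:
    suffix: str
    target_module: Any

@dataclass
class ModulePolicyDescription:
    sub_module_replacement: List[SubModuleReplacementDescription] = field(default_factory=list)

def module_policy(fused_layernorm: bool) -> Dict[str, ModulePolicyDescription]:
    base_policy = {"BertLayer": ModulePolicyDescription()}
    if fused_layernorm:
        # Mirror the diff: mutate the description already stored in the dict.
        base_policy["BertLayer"].sub_module_replacement.append(
            SubModuleReplacementDescription(suffix="output.LayerNorm",
                                            target_module="LayerNorm1D"))
    return base_policy

assert len(module_policy(True)["BertLayer"].sub_module_replacement) == 1
assert not module_policy(False)["BertLayer"].sub_module_replacement
```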
@@ -115,9 +147,15 @@ class BertForPretrainingPolicy(BertPolicy):
                          sub_module_replacement=[
                              SubModuleReplacementDescription(suffix="decoder",
                                                              target_module=col_nn.Linear1D_Col,
-                                                             kwargs={"gather_output": True})
+                                                             kwargs={"gather_output": True}),
                          ])
         }
 
+        if self.shard_config.fused_layernorm:
+            addon_module[BertLMPredictionHead].sub_module_replacement.append(
+                SubModuleReplacementDescription(
+                    suffix="transform.LayerNorm",
+                    target_module=col_nn.LayerNorm1D,
+                ))
         module_policy.update(addon_module)
         return module_policy
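Note that the `decoder` head keeps `Linear1D_Col` with `gather_output=True`: each rank computes only a shard of the vocabulary logits, and the shards are gathered back to full width so downstream code sees ordinary logits. A shapes-only illustration of that semantics (no real communication here; `Linear1D_Col` is the actual implementation):

```python
# Shapes-only illustration of gather_output=True on a column-parallel linear:
# each of `world_size` ranks owns vocab // world_size output features, and
# gathering concatenates the per-rank outputs along the feature dimension.
import torch

world_size, hidden, vocab = 4, 8, 32
x = torch.randn(2, hidden)                            # replicated input
weight_shards = torch.randn(world_size, vocab // world_size, hidden)
partial_outputs = [x @ w.t() for w in weight_shards]  # per-rank output shards
logits = torch.cat(partial_outputs, dim=-1)           # what gather_output=True yields
assert logits.shape == (2, vocab)
```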
@@ -146,9 +184,15 @@ class BertLMHeadModelPolicy(BertPolicy):
                          sub_module_replacement=[
                              SubModuleReplacementDescription(suffix="decoder",
                                                              target_module=col_nn.Linear1D_Col,
-                                                             kwargs={"gather_output": True})
+                                                             kwargs={"gather_output": True}),
                          ])
         }
 
+        if self.shard_config.fused_layernorm:
+            addon_module[BertLMPredictionHead].sub_module_replacement.append(
+                SubModuleReplacementDescription(
+                    suffix="transform.LayerNorm",
+                    target_module=col_nn.LayerNorm1D,
+                ))
         module_policy.update(addon_module)
         return module_policy
@@ -177,9 +221,15 @@ class BertForMaskedLMPolicy(BertPolicy):
                          sub_module_replacement=[
                              SubModuleReplacementDescription(suffix="decoder",
                                                              target_module=col_nn.Linear1D_Col,
-                                                             kwargs={"gather_output": True})
+                                                             kwargs={"gather_output": True}),
                          ])
         }
 
+        if self.shard_config.fused_layernorm:
+            addon_module[BertLMPredictionHead].sub_module_replacement.append(
+                SubModuleReplacementDescription(
+                    suffix="transform.LayerNorm",
+                    target_module=col_nn.LayerNorm1D,
+                ))
         module_policy.update(addon_module)
         return module_policy
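Review note: the `transform.LayerNorm` append above now appears verbatim in `BertForPretrainingPolicy`, `BertLMHeadModelPolicy`, and `BertForMaskedLMPolicy`. A possible follow-up, sketched here but not part of this commit, is a small helper on `BertPolicy` (name and placement are hypothetical; `SubModuleReplacementDescription` and `col_nn` are the ones already imported in this file):

```python
# Hypothetical helper, not part of this commit: hoist the repeated
# fused-layernorm append into the base class, so each LM-head policy can call
# self._append_fused_layernorm(addon_module[BertLMPredictionHead]).
def _append_fused_layernorm(self, description):
    if self.shard_config.fused_layernorm:
        description.sub_module_replacement.append(
            SubModuleReplacementDescription(
                suffix="transform.LayerNorm",
                target_module=col_nn.LayerNorm1D,
            ))
```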
@@ -199,6 +249,22 @@ class BertForSequenceClassificationPolicy(BertPolicy):
     def __init__(self) -> None:
         super().__init__()
 
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            BertForSequenceClassification:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="dropout",
+                                                target_module=col_nn.Dropout1D,
+                                            )
+                                        ])
+        }
+        module_policy.update(addon_module)
+        return module_policy
+
 
 # BertForTokenClassification
 class BertForTokenClassificationPolicy(BertPolicy):
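This policy and the ones below also replace the classifier `dropout` with `col_nn.Dropout1D`. The likely motivation is RNG management under tensor parallelism: a replicated activation must see the same dropout mask on every rank, while a sharded one must not. A single-process sketch of seed-controlled masking (illustrative only; this is not `Dropout1D`'s actual implementation):

```python
# Illustrative seed-controlled dropout: use the same seed on every rank for
# replicated tensors, and a per-rank-offset seed for sharded ones.
import torch

def seeded_dropout(x: torch.Tensor, p: float, seed: int) -> torch.Tensor:
    gen = torch.Generator().manual_seed(seed)
    mask = (torch.rand(x.shape, generator=gen) >= p).to(x.dtype)
    return x * mask / (1.0 - p)  # inverted dropout scaling

x = torch.ones(4, 4)
# Same seed -> identical masks, as required for replicated activations.
assert torch.equal(seeded_dropout(x, 0.1, 42), seeded_dropout(x, 0.1, 42))
```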
@@ -206,6 +272,22 @@ class BertForTokenClassificationPolicy(BertPolicy):
     def __init__(self) -> None:
         super().__init__()
 
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            BertForTokenClassification:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="dropout",
+                                                target_module=col_nn.Dropout1D,
+                                            )
+                                        ])
+        }
+        module_policy.update(addon_module)
+        return module_policy
+
 
 # BertForNextSentencePrediction
 class BertForNextSentencePredictionPolicy(BertPolicy):
@@ -219,3 +301,19 @@ class BertForMultipleChoicePolicy(BertPolicy):
 
     def __init__(self) -> None:
         super().__init__()
+
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            BertForMultipleChoice:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="dropout",
+                                                target_module=col_nn.Dropout1D,
+                                            )
+                                        ])
+        }
+        module_policy.update(addon_module)
+        return module_policy
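Finally, the commit message mentions a layernorm test with load state dict; the property being protected is that the replacement module keeps the original parameter names and shapes, so existing checkpoints load unchanged. A generic sketch of that check (`FusedLayerNorm` is a placeholder here; the real test targets `col_nn.LayerNorm1D`):

```python
# Generic state-dict round-trip check for a drop-in LayerNorm replacement.
import torch
import torch.nn as nn

class FusedLayerNorm(nn.LayerNorm):  # placeholder with identical parameters
    pass

ref = nn.LayerNorm(16)
fused = FusedLayerNorm(16)
fused.load_state_dict(ref.state_dict())  # keys and shapes must match exactly

x = torch.randn(2, 16)
assert torch.allclose(ref(x), fused(x), atol=1e-5)
```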