support bert with new api

FoolPlayer authored 2023-06-16 16:12:27 +08:00, committed by Frank Lee
parent 507c0ad368, commit df018fc305
2 changed files with 37 additions and 3 deletions


@@ -2,6 +2,7 @@ import torch.nn as nn
 from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead
 import colossalai.shardformer.layer.layers as col_nn
+from colossalai.shardformer.layer.dropout import Dropout1D
 from ..shard.shard_config import ShardConfig
 from ..utils import getattr_, setattr_
@@ -65,7 +66,24 @@ class BertPolicy(Policy):
                         suffix="output.dense",
                         target_module=col_nn.Linear1D_Row,
                     ),
-                ])
+                    SubModuleReplacementDescription(
+                        suffix="attention.self.dropout",
+                        target_module=Dropout1D,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.output.dropout",
+                        target_module=Dropout1D,
+                    )
+                ]),
+            BertEmbeddings:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="word_embeddings",
+                                                target_module=col_nn.VocabParallelEmbedding1D,
+                                            )
+                                        ])
         }
 
     def new_model_class(self):
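The suffixes in these descriptions are dotted attribute paths into the stock HuggingFace modules. As a quick illustration (not part of the commit), the snippet below shows what the paths used above resolve to on an unmodified BertLayer and BertEmbeddings; it relies only on the public transformers API.

import torch.nn as nn
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer

config = BertConfig()
layer = BertLayer(config)
embeddings = BertEmbeddings(config)

# "attention.self.dropout" and "attention.output.dropout" are the nn.Dropout
# modules that the policy replaces with Dropout1D.
assert isinstance(layer.attention.self.dropout, nn.Dropout)
assert isinstance(layer.attention.output.dropout, nn.Dropout)

# "output.dense" is the nn.Linear swapped for Linear1D_Row, and
# "word_embeddings" is the nn.Embedding swapped for VocabParallelEmbedding1D.
assert isinstance(layer.output.dense, nn.Linear)
assert isinstance(embeddings.word_embeddings, nn.Embedding)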
@@ -87,6 +105,21 @@ class BertForMaskedLMPolicy(BertPolicy):
     def __init__(self) -> None:
         super().__init__()
 
+    def module_policy(self, shard_config: ShardConfig = None):
+        module_policy = super().module_policy(shard_config)
+        addon_module = {
+            BertLMPredictionHead:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="decoder",
+                                                                            target_module=col_nn.Linear1D_Col,
+                                                                            kwargs={"gather_output": True})
+                                        ])
+        }
+        module_policy.update(addon_module)
+        return module_policy
+
 
 # BertLMHeadModel
 class BertLMHeadModelPolicy(BertPolicy):
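To see how the override composes with the base mapping, here is a rough usage sketch. It assumes the policy classes live in colossalai.shardformer.policies.bert and that ModulePolicyDescription exposes its constructor arguments (such as sub_module_replacement) as plain attributes; neither detail is shown in this diff, and the base policy may need a real ShardConfig rather than None.

from colossalai.shardformer.policies.bert import BertForMaskedLMPolicy

policy = BertForMaskedLMPolicy()
plan = policy.module_policy(shard_config=None)  # None is allowed by the new signature above

# The returned dict maps HuggingFace module classes to their replacement plans.
for module_cls, description in plan.items():
    for sub in description.sub_module_replacement:
        print(f"{module_cls.__name__}.{sub.suffix} -> {sub.target_module.__name__}")

# Expected keys: BertLayer and BertEmbeddings from the base BertPolicy, plus
# BertLMPredictionHead contributed by addon_module. Its "decoder" head is replaced
# with Linear1D_Col using gather_output=True, so every rank ends up with the
# full-vocabulary logits rather than a column-parallel shard.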