mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-27 20:46:00 +00:00
support bert with new api
This commit is contained in:
@@ -2,6 +2,7 @@ import torch.nn as nn
|
||||
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead
|
||||
|
||||
import colossalai.shardformer.layer.layers as col_nn
|
||||
from colossalai.shardformer.layer.dropout import Dropout1D
|
||||
|
||||
from ..shard.shard_config import ShardConfig
|
||||
from ..utils import getattr_, setattr_
|
||||
@@ -65,7 +66,24 @@ class BertPolicy(Policy):
|
||||
suffix="output.dense",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
),
|
||||
])
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.self.dropout",
|
||||
target_module=Dropout1D,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dropout",
|
||||
target_module=Dropout1D,
|
||||
)
|
||||
]),
|
||||
BertEmbeddings:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="word_embeddings",
|
||||
target_module=col_nn.VocabParallelEmbedding1D,
|
||||
)
|
||||
])
|
||||
}
|
||||
|
||||
def new_model_class(self):
|
||||
@@ -87,6 +105,21 @@ class BertForMaskedLMPolicy(BertPolicy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self, shard_config: ShardConfig = None):
|
||||
module_policy = super().module_policy(shard_config)
|
||||
addon_module = {
|
||||
BertLMPredictionHead:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="decoder",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"gather_output": True})
|
||||
])
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
|
||||
|
||||
# BertLMHeadModel
|
||||
class BertLMHeadModelPolicy(BertPolicy):
|
||||
|
Reference in New Issue
Block a user