diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index fc3e84473..fe74f83ca 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -2,6 +2,7 @@ import torch.nn as nn
 from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead
 
 import colossalai.shardformer.layer.layers as col_nn
+from colossalai.shardformer.layer.dropout import Dropout1D
 
 from ..shard.shard_config import ShardConfig
 from ..utils import getattr_, setattr_
@@ -65,7 +66,24 @@ class BertPolicy(Policy):
                             suffix="output.dense",
                             target_module=col_nn.Linear1D_Row,
                         ),
-                    ])
+                        SubModuleReplacementDescription(
+                            suffix="attention.self.dropout",
+                            target_module=Dropout1D,
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="attention.output.dropout",
+                            target_module=Dropout1D,
+                        )
+                    ]),
+            BertEmbeddings:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="word_embeddings",
+                                                target_module=col_nn.VocabParallelEmbedding1D,
+                                            )
+                                        ])
         }
 
     def new_model_class(self):
@@ -87,6 +105,21 @@ class BertForMaskedLMPolicy(BertPolicy):
     def __init__(self) -> None:
         super().__init__()
 
+    def module_policy(self, shard_config: ShardConfig = None):
+        module_policy = super().module_policy(shard_config)
+        addon_module = {
+            BertLMPredictionHead:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="decoder",
+                                                                            target_module=col_nn.Linear1D_Col,
+                                                                            kwargs={"gather_output": True})
+                                        ])
+        }
+        module_policy.update(addon_module)
+        return module_policy
+
 
 # BertLMHeadModel
 class BertLMHeadModelPolicy(BertPolicy):
diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index eb8300d59..5c8584595 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -171,12 +171,13 @@ class ModelSharder(object):
         for description in sub_module_replacement:
             suffix = description.suffix
             target_module = description.target_module
-            kwargs = description.kwargs
+            kwargs = {} if description.kwargs is None else description.kwargs
 
             assert target_module is not None, 'target_module should not be None'
 
             # TODO: support different parallel mode
             native_sub_module = getattr_(org_layer, suffix)
-            replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'])
+            replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
+                                                             **kwargs)
 
             setattr_(org_layer, suffix, replace_layer)
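
For context, the sharder change above exists so that per-description kwargs (such as the BertLMPredictionHead descriptor's {"gather_output": True}) can flow through to the replacement layer's constructor. Below is a minimal, self-contained sketch of that kwargs-forwarding pattern, not the actual ColossalAI code: replace_sub_module is a hypothetical free function standing in for the ModelSharder method, and the getattr_/setattr_ helpers are simplified dotted-path re-implementations assumed to match the semantics of the real utilities in ..utils.

import functools
from dataclasses import dataclass
from typing import Any, Dict, Optional, Type


def getattr_(obj: Any, path: str) -> Any:
    # Resolve a dotted attribute path, e.g. "attention.self.dropout".
    return functools.reduce(getattr, path.split('.'), obj)


def setattr_(obj: Any, path: str, value: Any) -> None:
    # Set the attribute named by the last path segment on its parent object.
    parent, _, name = path.rpartition('.')
    setattr(getattr_(obj, parent) if parent else obj, name, value)


@dataclass
class SubModuleReplacementDescription:
    suffix: str
    target_module: Type  # expected to expose a from_native_module classmethod
    kwargs: Optional[Dict] = None  # e.g. {"gather_output": True}; None means no extras


def replace_sub_module(org_layer, description, process_group):
    # Default to {} so **kwargs is always valid, even when no kwargs were given
    # for this description (the bug the sharder.py hunk fixes).
    kwargs = {} if description.kwargs is None else description.kwargs
    native_sub_module = getattr_(org_layer, description.suffix)
    # Extra kwargs are splatted into the target module's converter, so a policy
    # can request e.g. gather_output=True without the sharder knowing about it.
    replace_layer = description.target_module.from_native_module(
        native_sub_module, process_group, **kwargs)
    setattr_(org_layer, description.suffix, replace_layer)

The design choice here is that the sharder stays generic: it never inspects the kwargs, it only forwards them, so each policy (like BertForMaskedLMPolicy above) owns the layer-specific configuration of its replacements.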