[shardformer] fix bert and gpt downstream with new api (#4024)
* fix bert downstream with new api
* remove comment line
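The change this commit applies to the downstream policies: ShardConfig is no longer passed into preprocess()/module_policy() on every call; each policy reads it from self.shard_config, which the shared base Policy is expected to hold. A minimal, self-contained sketch of that pattern follows (class and method names are illustrative stand-ins, not ColossalAI's actual base class):

# Minimal sketch of the old-vs-new calling convention adopted in this commit.
# PolicySketch and set_shard_config are illustrative names, not the real API.
from dataclasses import dataclass


@dataclass
class ShardConfig:
    tensor_parallel_size: int = 1


class PolicySketch:

    def __init__(self) -> None:
        self.shard_config = None

    # Old shape:  def module_policy(self, shard_config: ShardConfig = None)
    # New shape:  the config is attached once, then read from self everywhere.
    def set_shard_config(self, shard_config: ShardConfig) -> None:
        self.shard_config = shard_config

    def module_policy(self) -> dict:
        tp = self.shard_config.tensor_parallel_size
        return {"attention.self.num_attention_heads": 12 // tp}


policy = PolicySketch()
policy.set_shard_config(ShardConfig(tensor_parallel_size=4))
print(policy.module_policy())  # {'attention.self.num_attention_heads': 3}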
@@ -4,41 +4,40 @@ from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, Be
 import colossalai.shardformer.layer.layers as col_nn
 from colossalai.shardformer.layer.dropout import Dropout1D
 
 from ..shard.shard_config import ShardConfig
 from ..utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 
 class BertPolicy(Policy):
 
-    def preprocess(self, shard_config: ShardConfig = None):
-        # reshape the embedding layer
+    def preprocess(self):
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
         # TODO:
         vocab_size = self.model.config.vocab_size
-        world_size = shard_config.tensor_parallel_size
+        world_size = self.shard_config.tensor_parallel_size
         if vocab_size % world_size != 0:
             new_vocab_size = vocab_size + world_size - vocab_size % world_size
             self.model.resize_token_embeddings(new_vocab_size)
         return self.model
 
-    def module_policy(self, shard_config: ShardConfig = None):
+    def module_policy(self):
         return {
             BertLayer:
                 ModulePolicyDescription(
                     attribute_replacement={
                         # 1. shard hidden size
                         "attention.self.all_head_size":
-                            self.model.config.hidden_size // shard_config.tensor_parallel_size,
+                            self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                         "crossattention.self.all_head_size":
-                            self.model.config.hidden_size // shard_config.tensor_parallel_size,
+                            self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                         # 2. shard number of heads
                         "attention.self.num_attention_heads":
-                            self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
+                            self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                         "crossattention.self.num_attention_heads":
-                            self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
+                            self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                     },
                     param_replacement=[],
                     sub_module_replacement=[
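For concreteness, here is the arithmetic that preprocess() and the attribute_replacement entries above perform, checked standalone with typical BERT-base numbers (the tensor-parallel size of 4 is chosen only for illustration):

# Standalone check of the two calculations in this hunk (illustrative numbers).
vocab_size = 30522              # bert-base-uncased vocabulary size
hidden_size = 768
num_attention_heads = 12
tensor_parallel_size = 4

# preprocess(): pad the vocabulary so it divides evenly across ranks.
if vocab_size % tensor_parallel_size != 0:
    vocab_size = vocab_size + tensor_parallel_size - vocab_size % tensor_parallel_size
print(vocab_size)                                    # 30524, divisible by 4

# attribute_replacement: each rank keeps a slice of the attention projections.
print(hidden_size // tensor_parallel_size)           # 192 -> all_head_size per rank
print(num_attention_heads // tensor_parallel_size)   # 3   -> heads per rank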
@@ -100,13 +99,43 @@ class BertPolicy(Policy):
         return self.model
 
 
 # BertModel
 class BertModelPolicy(BertPolicy):
 
     def __init__(self) -> None:
         super().__init__()
 
 
 # BertForPreTraining
 class BertForPretrainingPolicy(BertPolicy):
 
     def __init__(self) -> None:
         super().__init__()
 
     def module_policy(self):
         module_policy = super().module_policy()
         addon_module = {
             BertLMPredictionHead:
                 ModulePolicyDescription(attribute_replacement={},
                                         param_replacement=[],
                                         sub_module_replacement=[
                                             SubModuleReplacementDescription(suffix="decoder",
                                                                             target_module=col_nn.Linear1D_Col,
                                                                             kwargs={"gather_output": True})
                                         ])
         }
         module_policy.update(addon_module)
         return module_policy
 
 
 # BertForMaskedLM
 class BertForMaskedLMPolicy(BertPolicy):
 
     def __init__(self) -> None:
         super().__init__()
 
-    def module_policy(self, shard_config: ShardConfig = None):
-        module_policy = super().module_policy(shard_config)
+    def module_policy(self):
+        module_policy = super().module_policy()
         addon_module = {
             BertLMPredictionHead:
                 ModulePolicyDescription(attribute_replacement={},
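The addon_module entries above wrap the LM prediction head's decoder in col_nn.Linear1D_Col with gather_output=True: a column-parallel linear whose per-rank outputs are gathered back so the full vocabulary logits are available for the loss. A single-process sketch of why column sharding plus a gather reproduces the dense result (plain PyTorch; this only mimics the idea and does not use the ColossalAI layer):

# Column-parallel linear + gather, simulated in one process (illustration only).
import torch

torch.manual_seed(0)
hidden, vocab, tp = 8, 12, 4
x = torch.randn(2, hidden)          # a tiny batch of hidden states
w = torch.randn(vocab, hidden)      # full decoder weight: (vocab, hidden)

full = x @ w.t()                    # dense reference logits: (2, vocab)

# Each "rank" holds a slice of the output columns (vocab // tp rows of w).
shards = torch.chunk(w, tp, dim=0)
partials = [x @ shard.t() for shard in shards]   # each (2, vocab // tp)
gathered = torch.cat(partials, dim=-1)           # what gather_output=True restores

print(torch.allclose(full, gathered))            # True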
@@ -124,16 +153,41 @@ class BertForMaskedLMPolicy(BertPolicy):
 # BertLMHeadModel
 class BertLMHeadModelPolicy(BertPolicy):
 
-    @staticmethod
-    def argument_policy(config, world_size):
-        base_argument = BertPolicy.argument_policy(config, world_size)
-        argument = {
-            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
-                BertPolicy.unembedding,
-            ]),
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            BertLMPredictionHead:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="decoder",
+                                                                            target_module=col_nn.Linear1D_Col,
+                                                                            kwargs={"gather_output": True})
+                                        ])
         }
-        argument.update(base_argument)
-        return argument
+        module_policy.update(addon_module)
+        return module_policy
 
 
 # BertForNextSentencePrediction
 class BertForNextSentencePredictionPolicy(BertPolicy):
 
     def __init__(self) -> None:
         super().__init__()
 
 
 # BertForSequenceClassification
 class BertForSequenceClassificationPolicy(BertPolicy):
 
     def __init__(self) -> None:
         super().__init__()
 
 
 # BertForMultipleChoice
 class BertForMultipleChoicePolicy(BertPolicy):
 
     def __init__(self) -> None:
         super().__init__()
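The attribute_replacement dicts in this file address attributes by dotted paths relative to each matched module (e.g. "attention.self.all_head_size" under a BertLayer). A small sketch of how such paths can be resolved and overwritten; the helpers below are standalone stand-ins written for this note, not the getattr_/setattr_ utilities the policy file imports from ..utils:

# Dotted-path attribute replacement, illustrated with stand-in helpers.
from functools import reduce


def getattr_(obj, path):
    # Resolve "a.b.c" one attribute at a time.
    return reduce(getattr, path.split("."), obj)


def setattr_(obj, path, value):
    parent, _, name = path.rpartition(".")
    target = getattr_(obj, parent) if parent else obj
    setattr(target, name, value)


class _Self:                 # stands in for BertSelfAttention
    all_head_size = 768
    num_attention_heads = 12


class _Attention:            # stands in for BertAttention
    self = _Self()           # HF really does name this submodule "self"


class _Layer:                # stands in for one BertLayer
    attention = _Attention()


layer, tp = _Layer(), 2
replacements = {
    "attention.self.all_head_size": 768 // tp,
    "attention.self.num_attention_heads": 12 // tp,
}
for path, value in replacements.items():
    setattr_(layer, path, value)

print(getattr_(layer, "attention.self.all_head_size"))        # 384
print(getattr_(layer, "attention.self.num_attention_heads"))  # 6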