mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-29 21:49:54 +00:00
[Shardformer] Downstream bert (#3979)
* add dist dropout in model
* update docstring and bert policy with dropout
* refactor base policy and sharder, update bert
* update format
* update gpt2 policy
* update bert policy
* remove unused code
* update readme for new policy usage
* add downstream models of bert
* remove unused code
This commit is contained in:
@@ -10,11 +10,31 @@ def build_policies():
|
||||
"""
|
||||
auto_policy_dict = {}
|
||||
|
||||
from transformers import BertModel
|
||||
|
||||
from .bert import BertModelPolicy
|
||||
auto_policy_dict[BertModel] = BertModelPolicy
|
||||
|
||||
from transformers import BertForPreTraining
|
||||
|
||||
from .bert import BertForPretrainingPolicy
|
||||
auto_policy_dict[BertForPreTraining] = BertForPretrainingPolicy
|
||||
|
||||
from transformers import BertLMHeadModel
|
||||
|
||||
from .bert import BertLMHeadModelPolicy
|
||||
auto_policy_dict[BertLMHeadModel] = BertLMHeadModelPolicy
|
||||
|
||||
from transformers import BertForMaskedLM
|
||||
|
||||
from .bert import BertForMaskedLMPolicy
|
||||
auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy
|
||||
|
||||
from transformers import BertForNextSentencePrediction
|
||||
|
||||
from .bert import BertForNextSentencePredictionPolicy
|
||||
auto_policy_dict[BertForNextSentencePrediction] = BertForNextSentencePredictionPolicy
|
||||
|
||||
from transformers import BertForSequenceClassification
|
||||
|
||||
from .bert import BertForSequenceClassificationPolicy
|
||||
@@ -34,6 +54,11 @@ def build_policies():
|
||||
from .llama import LlamaForCausalLMPolicy
|
||||
auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy
|
||||
|
||||
from transformers import BertForMultipleChoice
|
||||
|
||||
from .bert import BertForMultipleChoicePolicy
|
||||
auto_policy_dict[BertForMultipleChoice] = BertForMultipleChoicePolicy
|
||||
|
||||
from transformers import GPT2Model
|
||||
|
||||
from .gpt2 import GPT2Policy
|
||||
|
||||
@@ -35,12 +35,6 @@ class BertPolicy(Policy):
|
||||
]),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def binding_policy():
|
||||
return {
|
||||
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def attn_in():
|
||||
return [
|
||||
@@ -148,30 +142,6 @@ class BertPolicy(Policy):
|
||||
replace_layer=col_nn.VocabParallelEmbedding1D,
|
||||
)]
|
||||
|
||||
|
||||
from transformers import BertForMaskedLM
|
||||
|
||||
from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_
|
||||
|
||||
|
||||
class BertForMaskedLMPolicy(BertPolicy):
|
||||
|
||||
@staticmethod
|
||||
def argument_policy(config, world_size):
|
||||
base_argument = BertPolicy.argument_policy(config, world_size)
|
||||
argument = {
|
||||
BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
|
||||
BertForMaskedLMPolicy.unembedding,
|
||||
]),
|
||||
}
|
||||
argument.update(base_argument)
|
||||
return argument
|
||||
|
||||
@staticmethod
|
||||
def inject_policy():
|
||||
# return (BertForMaskedLM, BertForMaskedLM_)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def unembedding():
|
||||
return [
|
||||
@@ -185,8 +155,112 @@ class BertForMaskedLMPolicy(BertPolicy):
|
||||
]
|
||||
|
||||
|
||||
class BertForSequenceClassificationPolicy(BertPolicy):
|
||||
# BertModel
|
||||
class BertModelPolicy(BertPolicy):
    """Sharding policy for the bare ``transformers.BertModel``.

    The backbone has no task head, so the shared BERT arguments are
    used as-is.
    """

    @staticmethod
    def argument_policy(config, world_size):
        # Delegate straight to the common BERT policy — nothing to add.
        return BertPolicy.argument_policy(config, world_size)
|
||||
|
||||
|
||||
# BertForPretraining
|
||||
class BertForPretrainingPolicy(BertPolicy):
    """Sharding policy for ``transformers.BertForPreTraining``.

    Extends the shared BERT arguments with the LM prediction head so
    its unembedding layer is sharded, and ties the input embedding to
    the decoder weight.
    """

    @staticmethod
    def argument_policy(config, world_size):
        # Register the prediction head, then merge in the base
        # arguments; base entries win on key collision, matching the
        # shared-policy behavior.
        argument = {
            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
                BertPolicy.unembedding,
            ]),
        }
        argument.update(BertPolicy.argument_policy(config, world_size))
        return argument

    @staticmethod
    def inject_policy():
        # No custom forward replacement for this downstream model.
        return None

    @staticmethod
    def binding_policy():
        # Weight tying: input word embeddings share storage with the
        # output decoder projection.
        return {
            "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
        }
|
||||
|
||||
|
||||
# BertForMaskedLM
|
||||
from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_
|
||||
|
||||
|
||||
class BertForMaskedLMPolicy(BertPolicy):
    """Sharding policy for ``transformers.BertForMaskedLM``.

    Adds the LM prediction head on top of the shared BERT arguments
    and ties the embedding/decoder weights.
    """

    @staticmethod
    def argument_policy(config, world_size):
        head_argument = {
            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
                BertPolicy.unembedding,
            ]),
        }
        # Base-policy entries override the head entry on collision,
        # same precedence as dict.update on the head dict.
        return {**head_argument, **BertPolicy.argument_policy(config, world_size)}

    @staticmethod
    def inject_policy():
        # NOTE(review): model injection is intentionally disabled for
        # now; the replacement pair was (BertForMaskedLM, BertForMaskedLM_).
        return None

    @staticmethod
    def binding_policy():
        # Tie input word embeddings to the MLM decoder weight.
        return {
            "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
        }
|
||||
|
||||
|
||||
# BertLMHeadModel
|
||||
class BertLMHeadModelPolicy(BertPolicy):
    """Sharding policy for ``transformers.BertLMHeadModel``.

    Identical head handling to the other LM-headed BERT variants:
    shard the unembedding layer and tie embedding/decoder weights.
    """

    @staticmethod
    def argument_policy(config, world_size):
        head_argument = {
            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
                BertPolicy.unembedding,
            ]),
        }
        # Merge so that shared BERT arguments take precedence.
        return {**head_argument, **BertPolicy.argument_policy(config, world_size)}

    @staticmethod
    def inject_policy():
        # No custom forward injection.
        return None

    @staticmethod
    def binding_policy():
        # Embedding/decoder weight tying.
        return {
            "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
        }
|
||||
|
||||
|
||||
# BertForNextSentencePrediction
|
||||
class BertForNextSentencePredictionPolicy(BertPolicy):
    """Sharding policy for ``transformers.BertForNextSentencePrediction``.

    The NSP head needs no extra sharding arguments, so the shared
    BERT policy applies unchanged.
    """

    @staticmethod
    def argument_policy(config, world_size):
        # Pure delegation to the common BERT arguments.
        return BertPolicy.argument_policy(config, world_size)
|
||||
|
||||
|
||||
# BertForSequenceClassification
|
||||
class BertForSequenceClassificationPolicy(BertPolicy):
    """Sharding policy for ``transformers.BertForSequenceClassification``.

    The classification head is left unsharded; only the shared BERT
    arguments are applied.
    """

    @staticmethod
    def argument_policy(config, world_size):
        # No head-specific entries — reuse the base policy verbatim.
        return BertPolicy.argument_policy(config, world_size)
|
||||
|
||||
|
||||
# BertForMultipleChoice
|
||||
class BertForMultipleChoicePolicy(BertPolicy):
    """Sharding policy for ``transformers.BertForMultipleChoice``.

    Delegates entirely to the shared BERT policy; the multiple-choice
    head gets no additional sharding arguments.
    """

    @staticmethod
    def argument_policy(config, world_size):
        # Same arguments as the BERT backbone.
        return BertPolicy.argument_policy(config, world_size)
|
||||
|
||||
@@ -13,6 +13,6 @@ class ShardConfig:
|
||||
world_size (int): The world size of the distributed process
|
||||
gather_output (bool): Whether to gather the output of the model of the last layer
|
||||
"""
|
||||
rank: int
|
||||
world_size: int = 2
|
||||
rank: int = None
|
||||
world_size: int = None
|
||||
gather_output: bool = True
|
||||
|
||||
@@ -276,6 +276,7 @@ def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Poli
|
||||
shard_config (`ShardConfig`): the config for distribute information
|
||||
policy (`Policy`): the custom policy for sharding
|
||||
"""
|
||||
# TODO: init shard_config automatically
|
||||
sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy)
|
||||
sharder.shard()
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user