Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-26 20:23:26 +00:00)
[shardformer] Add dropout layer in shard model and refactor policy api (#3949)
* add dist dropout in model
* update docstring and bert policy with dropout
* refactor basepolicy and sharded, update bert
* update format
* update gpt2 policy
* update bert policy
* remove unused code
* update readme for new policy usage
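The heart of the refactor is visible in the diff below: a policy no longer spells out the full dotted path of every parameter (e.g. `weight="attention.self.query.weight"`). Instead, each descriptor names the target submodule once via `suffix` and refers to attributes inside it (`weight="weight"`, `bias="bias"`), and a new `Dropout_Layer` descriptor lets a policy replace `nn.Dropout` modules with a parallel-aware dropout. A minimal sketch of what the refactored descriptors in `basepolicy.py` might look like, with field names taken from this diff (the actual definitions may differ):

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class Layer:
    """Locate a submodule by suffix; subclasses name its parameters."""
    suffix: str = None           # dotted path of the submodule to replace
    replace_layer: Any = None    # parallel layer class substituted in place
    ignore: bool = False         # skip quietly if the submodule is absent

@dataclass
class Col_Layer(Layer):
    weight: str = None           # attribute name of the weight, e.g. "weight"
    bias: str = None             # attribute name of the bias, e.g. "bias"
    gather_output: bool = False  # all-gather the column-parallel output

@dataclass
class Dropout_Layer(Layer):
    p: str = "p"                 # attribute holding the dropout probability
```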
@@ -5,7 +5,7 @@ from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, Be
 
 import colossalai.shardformer.layer.layers as col_nn
 
-from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer
+from .basepolicy import Argument, Col_Layer, Dropout_Layer, Policy, Row_Layer
 
 
 class BertPolicy(Policy):
@@ -28,123 +28,126 @@ class BertPolicy(Policy):
             Argument(
                 attr_dict={
                     # 1. shard vocab size
                     # "word_embeddings.num_embeddings": config.vocab_size // world_size,
                     # 2. add the size of the sliced embedding layer excluding the last slice
                     "word_embeddings.dim_size": (config.vocab_size + world_size - 1) // world_size,
                 },
                 param_funcs=[
                     BertPolicy.embedding,
                 ]),
-            BertLMPredictionHead:
-                Argument(
-                    attr_dict={
-                        # 1. shard vocab size
-                        # "word_embeddings.num_embeddings": config.vocab_size // world_size,
-                        # 2. add the size of the sliced embedding layer excluding the last slice
-                    },
-                    param_funcs=[
-                        BertPolicy.unembedding,
-                    ])
         }
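The `dim_size` expression above is ceiling division: `(config.vocab_size + world_size - 1) // world_size` rounds the per-rank embedding shard up, so any vocabulary padding is confined to the last slice. For example:

```python
vocab_size, world_size = 30522, 4  # BERT-base vocab across 4 ranks
dim_size = (vocab_size + world_size - 1) // world_size
print(dim_size)                            # 7631 rows per rank
print(dim_size * world_size - vocab_size)  # 2 padded rows overall
```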
 
     @staticmethod
-    def binding_policy() -> Dict:
+    def binding_policy():
         return {
             "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
         }
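`binding_policy` declares that the word embedding and the masked-LM decoder must keep sharing one tensor after sharding. A sketch of how such a mapping could be applied (`apply_binding` is a hypothetical helper; shardformer resolves these dotted paths internally):

```python
from functools import reduce

def apply_binding(model, binding):
    for src_path, dst_path in binding.items():
        # Resolve the dotted path to the source parameter.
        src = reduce(getattr, src_path.split("."), model)
        # Re-point the destination attribute at the same parameter object.
        *parents, leaf = dst_path.split(".")
        setattr(reduce(getattr, parents, model), leaf, src)
```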
 
     @staticmethod
-    def attn_in() -> List:
+    def attn_in():
         return [
             Col_Layer(
-                weight="attention.self.query.weight",
-                bias="attention.self.query.bias",
+                suffix="attention.self.query",
+                weight="weight",
+                bias="bias",
                 replace_layer=col_nn.Linear1D_Col,
             ),
             Col_Layer(
-                weight="attention.self.key.weight",
-                bias="attention.self.key.bias",
+                suffix="attention.self.key",
+                weight="weight",
+                bias="bias",
                 replace_layer=col_nn.Linear1D_Col,
             ),
             Col_Layer(
-                weight="attention.self.value.weight",
-                bias="attention.self.value.bias",
+                suffix="attention.self.value",
+                weight="weight",
+                bias="bias",
                 replace_layer=col_nn.Linear1D_Col,
             ),
+            Dropout_Layer(
+                suffix="attention.self.dropout",
+                p="p",
+                replace_layer=col_nn.Dropout1D,
+            ),
             Col_Layer(
-                weight="crossattention.self.query.weight",
-                bias="crossattention.self.query.bias",
+                suffix="crossattention.self.query",
+                weight="weight",
+                bias="bias",
                 replace_layer=col_nn.Linear1D_Col,
                 ignore=True,
             ),
             Col_Layer(
-                weight="crossattention.self.key.weight",
-                bias="crossattention.self.key.bias",
+                suffix="crossattention.self.key",
+                weight="weight",
+                bias="bias",
                 replace_layer=col_nn.Linear1D_Col,
                 ignore=True,
             ),
             Col_Layer(
-                weight="crossattention.self.value.weight",
-                bias="crossattention.self.value.bias",
+                suffix="crossattention.self.value",
+                weight="weight",
+                bias="bias",
                 replace_layer=col_nn.Linear1D_Col,
                 ignore=True,
             ),
         ]
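The `Dropout_Layer` entry is the "dist dropout" from the commit message. Plain `nn.Dropout` is unsafe under tensor parallelism: the attention activations here are sharded across ranks, and if every rank draws from the same RNG stream the shards end up with correlated masks. A rough sketch of the idea behind a rank-aware dropout such as `col_nn.Dropout1D` (the real layer manages seeds through ColossalAI's parallel context; this toy version just offsets a per-module generator by rank):

```python
import torch
import torch.distributed as dist
from torch import nn

class RankAwareDropout(nn.Module):
    """Dropout whose mask differs across tensor-parallel ranks."""

    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p
        self._base_seed = torch.initial_seed()
        self._rank = dist.get_rank() if dist.is_initialized() else 0
        self._gen = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.p == 0.0:
            return x
        if self._gen is None or self._gen.device != x.device:
            # Build the generator lazily on the input's device, offset by
            # rank so each activation shard draws an independent mask.
            self._gen = torch.Generator(device=x.device)
            self._gen.manual_seed(self._base_seed + self._rank)
        keep = 1.0 - self.p
        mask = torch.empty_like(x).bernoulli_(keep, generator=self._gen)
        return x * mask / keep  # inverted dropout scaling
```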
 
     @staticmethod
-    def attn_out() -> List:
+    def attn_out():
         return [
             Row_Layer(
-                weight="attention.output.dense.weight",
-                bias="attention.output.dense.bias",
+                suffix="attention.output.dense",
+                weight="weight",
+                bias="bias",
                 replace_layer=col_nn.Linear1D_Row,
             ),
+            Dropout_Layer(
+                suffix="attention.output.dropout",
+                p="p",
+                replace_layer=col_nn.Dropout1D,
+            ),
             Row_Layer(
-                weight="crossattention.output.dense.weight",
-                bias="crossattention.output.dense.bias",
+                suffix="crossattention.output.dense",
+                weight="weight",
+                bias="bias",
                 replace_layer=col_nn.Linear1D_Row,
                 ignore=True,
             ),
         ]
 
     @staticmethod
-    def mlp_in() -> List:
+    def mlp_in():
         return [
             Col_Layer(
-                weight="intermediate.dense.weight",
-                bias="intermediate.dense.bias",
+                suffix="intermediate.dense",
+                weight="weight",
+                bias="bias",
                 replace_layer=col_nn.Linear1D_Col,
             ),
         ]
 
     @staticmethod
-    def mlp_out() -> List:
+    def mlp_out():
         return [
             Row_Layer(
-                weight="output.dense.weight",
-                bias="output.dense.bias",
+                suffix="output.dense",
+                weight="weight",
+                bias="bias",
                 replace_layer=col_nn.Linear1D_Row,
             ),
+            Dropout_Layer(
+                suffix="output.dropout",
+                p="p",
+                replace_layer=col_nn.Dropout1D,
+            )
         ]
 
     @staticmethod
-    def embedding() -> List:
-        return [Col_Layer(
-            weight="word_embeddings.weight",
-            replace_layer=col_nn.VocabParallelEmbedding1D,
-        )]
-
-    @staticmethod
-    def unembedding() -> List:
-        return [
-            Col_Layer(
-                weight="decoder.weight",
-                bias="decoder.bias",
-                replace_layer=col_nn.Linear1D_Col,
-                gather_output=True,
-            )
-        ]
+    def embedding():
+        return [Col_Layer(
+            suffix="word_embeddings",
+            weight="weight",
+            replace_layer=col_nn.VocabParallelEmbedding1D,
+        )]
 
 from transformers import BertForMaskedLM
 
@@ -154,18 +157,36 @@ from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_
 class BertForMaskedLMPolicy(BertPolicy):
 
     @staticmethod
-    def inject_policy() -> Tuple[nn.Module, nn.Module]:
+    def argument_policy(config, world_size):
+        base_argument = BertPolicy.argument_policy(config, world_size)
+        argument = {
+            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
+                BertForMaskedLMPolicy.unembedding,
+            ]),
+        }
+        argument.update(base_argument)
+        return argument
+
+    @staticmethod
+    def inject_policy():
+        # return (BertForMaskedLM, BertForMaskedLM_)
+        return None
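`argument_policy` now composes instead of being monolithic: the subclass contributes its own `BertLMPredictionHead` entry and merges in the base entries for `BertEmbeddings` and `BertLayer`. A quick check of the merged mapping (a usage sketch assuming a stock `BertConfig`):

```python
from transformers import BertConfig

config = BertConfig()  # default BERT-base hyperparameters
args = BertForMaskedLMPolicy.argument_policy(config, world_size=2)
# Expect BertEmbeddings, BertLayer and BertLMPredictionHead among the keys.
print([cls.__name__ for cls in args])
```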
 
+    @staticmethod
+    def unembedding():
+        return [
+            Col_Layer(
+                suffix="decoder",
+                weight="weight",
+                bias="bias",
+                replace_layer=col_nn.Linear1D_Col,
+                gather_output=True,
+            )
+        ]
 
 
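The LM head's `Col_Layer` sets `gather_output=True` because `Linear1D_Col` splits `decoder.weight` along the vocabulary dimension: each rank produces logits only for its own vocab shard, and the shards must be all-gathered before the loss sees a full `[batch, seq, vocab]` tensor. Conceptually (a sketch, not the actual `Linear1D_Col` internals):

```python
import torch
import torch.distributed as dist

def gather_logits(partial: torch.Tensor) -> torch.Tensor:
    """All-gather per-rank vocab shards into full logits on every rank."""
    shards = [torch.empty_like(partial) for _ in range(dist.get_world_size())]
    dist.all_gather(shards, partial)
    return torch.cat(shards, dim=-1)  # concatenate along the vocab axis
```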
 class BertForSequenceClassificationPolicy(BertPolicy):
 
     @staticmethod
-    def inject_policy() -> Dict:
-        return {}
+    def inject_policy():
+        return None
 
 
 # model = BertForMaskedLM.from_pretrained("bert-base-uncased")
 # _ = BertForMaskedLMPolicy(model)
 # print(isinstance(model,list(_.inject_policy().keys())[0]))