[shardformer] Add dropout layer in shard model and refactor policy api (#3949)

* add dist dropout in model

* update docstring and bert policy with dropout

* refactor basepolicy and sharded, update bert

* update format

* update gpt2 policy

* update bert policy

* remove unused code

* update readme for new policy usage
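For context on the refactor described above: layer descriptors no longer take full dotted parameter paths; they take a `suffix` naming the sub-module plus the attribute names on it, and dropout modules get their own `Dropout_Layer` entry so they can be replaced by the distributed dropout. A minimal sketch follows (descriptor and layer names are taken from the diff below; the import path of the base-policy module is an assumption):

```python
# Sketch of the new suffix-based layer description used by this commit.
# Col_Layer, Dropout_Layer and col_nn.* appear in the diff below; the
# import path of the base-policy module is an assumption.
import colossalai.shardformer.layer.layers as col_nn
from colossalai.shardformer.policies.basepolicy import Col_Layer, Dropout_Layer

# Old style addressed parameters with full dotted paths:
#   Col_Layer(weight="attention.self.query.weight",
#             bias="attention.self.query.bias",
#             replace_layer=col_nn.Linear1D_Col)
# New style names the sub-module once via `suffix` and the attributes on it:
query_descriptor = Col_Layer(
    suffix="attention.self.query",   # sub-module to replace
    weight="weight",                 # parameter attribute on that sub-module
    bias="bias",
    replace_layer=col_nn.Linear1D_Col,
)

# Dropout modules now get their own descriptor so they can be swapped for the
# tensor-parallel-aware Dropout1D implementation:
attn_dropout_descriptor = Dropout_Layer(
    suffix="attention.self.dropout",
    p="p",                           # attribute holding the dropout probability
    replace_layer=col_nn.Dropout1D,
)
```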
FoolPlayer authored 2023-06-12 16:52:18 +08:00, committed by Frank Lee
parent a73130482d, commit 45927d5527
7 changed files with 266 additions and 197 deletions


@@ -5,7 +5,7 @@ from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, Be
 import colossalai.shardformer.layer.layers as col_nn
-from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer
+from .basepolicy import Argument, Col_Layer, Dropout_Layer, Policy, Row_Layer
 class BertPolicy(Policy):
@@ -28,123 +28,126 @@ class BertPolicy(Policy):
                 Argument(
                     attr_dict={
                         # 1. shard vocab size
                         # "word_embeddings.num_embeddings": config.vocab_size // world_size,
                         # 2. add the size of the sliced embedding layer excluding the last slice
                         "word_embeddings.dim_size": (config.vocab_size + world_size - 1) // world_size,
                     },
                     param_funcs=[
                         BertPolicy.embedding,
                     ]),
-            BertLMPredictionHead:
-                Argument(
-                    attr_dict={
-                        # 1. shard vocab size
-                        # "word_embeddings.num_embeddings": config.vocab_size // world_size,
-                        # 2. add the size of the sliced embedding layer excluding the last slice
-                    },
-                    param_funcs=[
-                        BertPolicy.unembedding,
-                    ])
         }
     @staticmethod
-    def binding_policy() -> Dict:
+    def binding_policy():
         return {
             "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
         }
     @staticmethod
-    def attn_in() -> List:
+    def attn_in():
         return [
             Col_Layer(
-                weight="attention.self.query.weight",
-                bias="attention.self.query.bias",
+                suffix="attention.self.query",
+                weight="weight",
+                bias="bias",
                 replace_layer=col_nn.Linear1D_Col,
             ),
             Col_Layer(
-                weight="attention.self.key.weight",
-                bias="attention.self.key.bias",
+                suffix="attention.self.key",
+                weight="weight",
+                bias="bias",
                 replace_layer=col_nn.Linear1D_Col,
             ),
             Col_Layer(
-                weight="attention.self.value.weight",
-                bias="attention.self.value.bias",
+                suffix="attention.self.value",
+                weight="weight",
+                bias="bias",
                 replace_layer=col_nn.Linear1D_Col,
             ),
+            Dropout_Layer(
+                suffix="attention.self.dropout",
+                p="p",
+                replace_layer=col_nn.Dropout1D,
+            ),
             Col_Layer(
-                weight="crossattention.self.query.weight",
-                bias="crossattention.self.query.bias",
+                suffix="crossattention.self.query",
+                weight="weight",
+                bias="bias",
                 replace_layer=col_nn.Linear1D_Col,
                 ignore=True,
             ),
             Col_Layer(
-                weight="crossattention.self.key.weight",
-                bias="crossattention.self.key.bias",
+                suffix="crossattention.self.key",
+                weight="weight",
+                bias="bias",
                 replace_layer=col_nn.Linear1D_Col,
                 ignore=True,
             ),
             Col_Layer(
-                weight="crossattention.self.value.weight",
-                bias="crossattention.self.value.bias",
+                suffix="crossattention.self.value",
+                weight="weight",
+                bias="bias",
                 replace_layer=col_nn.Linear1D_Col,
                 ignore=True,
             ),
         ]
     @staticmethod
-    def attn_out() -> List:
+    def attn_out():
         return [
             Row_Layer(
-                weight="attention.output.dense.weight",
-                bias="attention.output.dense.bias",
+                suffix="attention.output.dense",
+                weight="weight",
+                bias="bias",
                 replace_layer=col_nn.Linear1D_Row,
             ),
+            Dropout_Layer(
+                suffix="attention.output.dropout",
+                p="p",
+                replace_layer=col_nn.Dropout1D,
+            ),
             Row_Layer(
-                weight="crossattention.output.dense.weight",
-                bias="crossattention.output.dense.bias",
+                suffix="crossattention.output.dense",
+                weight="weight",
+                bias="bias",
                 replace_layer=col_nn.Linear1D_Row,
                 ignore=True,
             ),
         ]
     @staticmethod
-    def mlp_in() -> List:
+    def mlp_in():
         return [
             Col_Layer(
-                weight="intermediate.dense.weight",
-                bias="intermediate.dense.bias",
+                suffix="intermediate.dense",
+                weight="weight",
+                bias="bias",
                 replace_layer=col_nn.Linear1D_Col,
             ),
         ]
     @staticmethod
-    def mlp_out() -> List:
+    def mlp_out():
         return [
             Row_Layer(
-                weight="output.dense.weight",
-                bias="output.dense.bias",
+                suffix="output.dense",
+                weight="weight",
+                bias="bias",
                 replace_layer=col_nn.Linear1D_Row,
             ),
-        ]
-    @staticmethod
-    def embedding() -> List:
-        return [Col_Layer(
-            weight="word_embeddings.weight",
-            replace_layer=col_nn.VocabParallelEmbedding1D,
-        )]
-    @staticmethod
-    def unembedding() -> List:
-        return [
-            Col_Layer(
-                weight="decoder.weight",
-                bias="decoder.bias",
-                replace_layer=col_nn.Linear1D_Col,
-                gather_output=True,
+            Dropout_Layer(
+                suffix="output.dropout",
+                p="p",
+                replace_layer=col_nn.Dropout1D,
             )
         ]
+    @staticmethod
+    def embedding():
+        return [Col_Layer(
+            suffix="word_embeddings",
+            weight="weight",
+            replace_layer=col_nn.VocabParallelEmbedding1D,
+        )]
 from transformers import BertForMaskedLM
@@ -154,18 +157,36 @@ from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_
 class BertForMaskedLMPolicy(BertPolicy):
     @staticmethod
-    def inject_policy() -> Tuple[nn.Module, nn.Module]:
+    def argument_policy(config, world_size):
+        base_argument = BertPolicy.argument_policy(config, world_size)
+        argument = {
+            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
+                BertForMaskedLMPolicy.unembedding,
+            ]),
+        }
+        argument.update(base_argument)
+        return argument
+    @staticmethod
+    def inject_policy():
         # return (BertForMaskedLM, BertForMaskedLM_)
         return None
+    @staticmethod
+    def unembedding():
+        return [
+            Col_Layer(
+                suffix="decoder",
+                weight="weight",
+                bias="bias",
+                replace_layer=col_nn.Linear1D_Col,
+                gather_output=True,
+            )
+        ]
 class BertForSequenceClassificationPolicy(BertPolicy):
     @staticmethod
-    def inject_policy() -> Dict:
-        return {}
-# model = BertForMaskedLM.from_pretrained("bert-base-uncased")
-# _ = BertForMaskedLMPolicy(model)
-# print(isinstance(model,list(_.inject_policy().keys())[0]))
+    def inject_policy():
+        return None
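Usage note (a sketch, not part of this diff): under the refactored API a model-specific policy keeps the shared descriptors from `BertPolicy` and only overrides `argument_policy`, mirroring `BertForMaskedLMPolicy` above. Class and method names below come from the diff; the import paths are assumptions.

```python
# Hypothetical example of extending the refactored policy API. Only the names
# used in the diff (Argument, Col_Layer, argument_policy, param_funcs, etc.)
# are taken as given; the import paths are assumptions.
import colossalai.shardformer.layer.layers as col_nn
from transformers.models.bert.modeling_bert import BertLMPredictionHead

from colossalai.shardformer.policies.basepolicy import Argument, Col_Layer
from colossalai.shardformer.policies.bert import BertPolicy


class CustomBertPolicy(BertPolicy):
    """Reuses BertPolicy's shared descriptors and adds a prediction-head entry,
    following the BertForMaskedLMPolicy pattern shown in the diff."""

    @staticmethod
    def argument_policy(config, world_size):
        # Start from the shared mapping provided by BertPolicy ...
        argument = BertPolicy.argument_policy(config, world_size)
        # ... then register extra modules with their own param funcs.
        argument[BertLMPredictionHead] = Argument(
            attr_dict={},
            param_funcs=[CustomBertPolicy.unembedding],
        )
        return argument

    @staticmethod
    def unembedding():
        # Column-parallel decoder whose outputs are gathered back to full size.
        return [
            Col_Layer(
                suffix="decoder",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Col,
                gather_output=True,
            )
        ]
```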