[shardformer] Add dropout layer in shard model and refactor policy api (#3949)

* add dist dropout in model

* update docstring and bert policy with dropout

* refactor basepolicy and sharded, update bert

* update format

* update gpt2 policy

* update bert policy

* remove unused code

* update readme for new policy usage
Authored by FoolPlayer on 2023-06-12 16:52:18 +08:00; committed by Frank Lee
parent a73130482d
commit 45927d5527
7 changed files with 266 additions and 197 deletions
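
The policy refactor visible in the hunks below replaces full parameter paths (e.g. `weight="attn.c_attn.weight"`) with a module `suffix` plus short attribute names (`suffix="attn.c_attn"`, `weight="weight"`, `bias="bias"`). A minimal, self-contained sketch of that addressing scheme; the `LayerDescriptor` class and `full_param_name` helper here are hypothetical stand-ins for illustration, not the repo's actual `Col_Layer`/`Row_Layer` definitions:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LayerDescriptor:
    # Hypothetical stand-in for the Col_Layer/Row_Layer descriptors in this PR.
    suffix: str                      # module path inside the block, e.g. "attn.c_attn"
    weight: Optional[str] = None     # attribute name on that module, e.g. "weight"
    bias: Optional[str] = None       # attribute name on that module, e.g. "bias"
    replace_layer: Optional[type] = None


def full_param_name(desc: LayerDescriptor) -> str:
    # Old-style policies spelled this string out by hand ("attn.c_attn.weight");
    # suffix-based descriptors let the framework join module path + attribute name.
    return f"{desc.suffix}.{desc.weight}"


desc = LayerDescriptor(suffix="attn.c_attn", weight="weight", bias="bias")
assert full_param_name(desc) == "attn.c_attn.weight"
```
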


@@ -40,19 +40,22 @@ class GPT2Policy(Policy):
     @staticmethod
     def attn_in() -> List:
         return [
-            Col_Layer(weight="attn.c_attn.weight",
-                      bias="attn.c_attn.bias",
+            Col_Layer(suffix="attn.c_attn",
+                      weight="weight",
+                      bias="bias",
                       n_cast=3,
                       reversed=True,
                       replace_layer=col_nn.Linear1D_Col),
-            Col_Layer(weight="crossattention.c_attn.weight",
-                      bias="crossattention.c_attn.bias",
+            Col_Layer(suffix="crossattention.c_attn",
+                      weight="weight",
+                      bias="bias",
                       n_cast=2,
                       reversed=True,
                       ignore=True,
                       replace_layer=col_nn.Linear1D_Col),
-            Col_Layer(weight="crossattention.q_attn.weight",
-                      bias="crossattention.q_attn.bias",
+            Col_Layer(suffix="crossattention.q_attn",
+                      weight="weight",
+                      bias="bias",
                       reversed=True,
                       ignore=True,
                       replace_layer=col_nn.Linear1D_Col)
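
Judging from the values in this hunk (`n_cast=3` for GPT-2's fused QKV projection `attn.c_attn`, `n_cast=2` for the fused KV cross-attention projection), `n_cast` presumably records how many projections are fused into one weight, so the column split can be applied per projection rather than across the raw fused matrix. A small NumPy sketch of that splitting idea, an illustration only and not the shardformer implementation:

```python
import numpy as np

hidden, world_size = 4, 2
# Fused QKV weight as in GPT-2's attn.c_attn (Conv1D): shape [hidden, 3 * hidden].
w = np.arange(hidden * 3 * hidden).reshape(hidden, 3 * hidden)

# Naive column split: rank 0 would hold all of Q plus part of K, and none of V.
naive_rank0 = np.split(w, world_size, axis=1)[0]

# n_cast-aware split: chunk into the 3 fused projections first, then split each chunk,
# so every rank holds a slice of Q, K and V.
q, k, v = np.split(w, 3, axis=1)
rank0 = np.concatenate([np.split(p, world_size, axis=1)[0] for p in (q, k, v)], axis=1)

# Both shards have the same size, but only the n_cast-aware one covers all three projections.
assert naive_rank0.shape == rank0.shape == (hidden, 3 * hidden // world_size)
```
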
@@ -61,12 +64,14 @@ class GPT2Policy(Policy):
     @staticmethod
     def attn_out() -> List:
         return [
-            Row_Layer(weight="attn.c_proj.weight",
-                      bias="attn.c_proj.bias",
+            Row_Layer(suffix="attn.c_proj",
+                      weight="weight",
+                      bias="bias",
                       reversed=True,
                       replace_layer=col_nn.Linear1D_Row),
-            Row_Layer(weight="crossattention.c_proj.weight",
-                      bias="crossattention.c_proj.bias",
+            Row_Layer(suffix="crossattention.c_proj",
+                      weight="weight",
+                      bias="bias",
                       reversed=True,
                       ignore=True,
                       replace_layer=col_nn.Linear1D_Row)
@@ -75,21 +80,23 @@ class GPT2Policy(Policy):
     @staticmethod
     def mlp_in() -> List:
         return [
-            Col_Layer(weight="mlp.c_fc.weight", bias="mlp.c_fc.bias", reversed=True, replace_layer=col_nn.Linear1D_Col),
+            Col_Layer(suffix="mlp.c_fc", weight="weight", bias="bias", reversed=True,
+                      replace_layer=col_nn.Linear1D_Col),
         ]
 
     @staticmethod
     def mlp_out() -> List:
         return [
-            Row_Layer(weight="mlp.c_proj.weight",
-                      bias="mlp.c_proj.bias",
+            Row_Layer(suffix="mlp.c_proj",
+                      weight="weight",
+                      bias="bias",
                       reversed=True,
                       replace_layer=col_nn.Linear1D_Row)
         ]
 
     @staticmethod
     def embedding() -> List:
-        return [Col_Layer(weight="wte.weight", replace_layer=col_nn.VocabParallelEmbedding1D)]
+        return [Col_Layer(suffix="wte", weight="weight", replace_layer=col_nn.VocabParallelEmbedding1D)]
 
 
 from transformers import GPT2LMHeadModel
@@ -111,8 +118,9 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
     @staticmethod
     def unembedding() -> List:
         return [
-            Col_Layer(weight="lm_head.weight",
-                      bias="lm_head.bias",
+            Col_Layer(suffix="lm_head",
+                      weight="weight",
+                      bias="bias",
                       replace_layer=col_nn.Linear1D_Col,
                       gather_output=True)
         ]
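
The LM head descriptor above keeps `gather_output=True` on `Linear1D_Col`. In Megatron-style column parallelism, each rank computes logits only for its shard of the vocabulary, and gathering concatenates those shards along the last dimension so the full-vocab logits are available for the loss. A single-process NumPy sketch of that behaviour; this is illustrative and not the `Linear1D_Col` code:

```python
import numpy as np

hidden, vocab, world_size = 8, 6, 2
x = np.random.randn(3, hidden)              # a batch of hidden states
w = np.random.randn(hidden, vocab)          # full lm_head weight

shards = np.split(w, world_size, axis=1)    # column-parallel: each rank owns vocab / world_size columns
partial = [x @ shard for shard in shards]   # each rank's logits cover only its vocabulary shard
gathered = np.concatenate(partial, axis=-1) # gather_output=True: concatenate shards along the vocab dim

assert np.allclose(gathered, x @ w)         # gathered logits match the unsharded computation
```
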