Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-27 04:33:04 +00:00)
[shardformer] Add dropout layer in shard model and refactor policy api (#3949)
* add dist dropout in model
* update docstring and bert policy with dropout
* refactor basepolicy and sharded, update bert
* update format
* update gpt2 policy
* update bert policy
* remove unused code
* update readme for new policy usage
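The GPT-2 policy hunks below illustrate the refactored description format: each sharded layer is now identified by a sub-module `suffix` plus attribute names, instead of full parameter paths. A minimal stand-in sketch of such a description is shown here; the field names are taken from the diff, but the dataclass itself, its defaults, and the comments on semantics are assumptions for illustration, not ColossalAI's actual `Col_Layer`:

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class Col_Layer:
    # Hypothetical stand-in mirroring the fields used in the diff below;
    # defaults and semantics are assumptions for illustration only.
    suffix: str                    # sub-module path inside a transformer block, e.g. "attn.c_attn"
    weight: Optional[str] = None   # attribute name of the weight on that sub-module
    bias: Optional[str] = None     # attribute name of the bias on that sub-module
    n_cast: Optional[int] = None   # number of fused projections packed in the weight (e.g. 3 for QKV)
    reversed: bool = False         # GPT-2's Conv1D stores the weight transposed relative to nn.Linear
    ignore: bool = False           # tolerate a missing sub-module (e.g. no cross-attention)
    gather_output: bool = False    # all-gather the column-parallel output (used for the LM head)
    replace_layer: Any = None      # tensor-parallel layer class to substitute, e.g. col_nn.Linear1D_Col

# Old style spelled out a full parameter path per tensor:
#   Col_Layer(weight="attn.c_attn.weight", bias="attn.c_attn.bias", ...)
# New style names the sub-module once and its attributes relative to it:
qkv = Col_Layer(suffix="attn.c_attn", weight="weight", bias="bias", n_cast=3, reversed=True)
```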
@@ -40,19 +40,22 @@ class GPT2Policy(Policy):
     @staticmethod
     def attn_in() -> List:
         return [
-            Col_Layer(weight="attn.c_attn.weight",
-                      bias="attn.c_attn.bias",
+            Col_Layer(suffix="attn.c_attn",
+                      weight="weight",
+                      bias="bias",
                       n_cast=3,
                       reversed=True,
                       replace_layer=col_nn.Linear1D_Col),
-            Col_Layer(weight="crossattention.c_attn.weight",
-                      bias="crossattention.c_attn.bias",
+            Col_Layer(suffix="crossattention.c_attn",
+                      weight="weight",
+                      bias="bias",
                       n_cast=2,
                       reversed=True,
                       ignore=True,
                       replace_layer=col_nn.Linear1D_Col),
-            Col_Layer(weight="crossattention.q_attn.weight",
-                      bias="crossattention.q_attn.bias",
+            Col_Layer(suffix="crossattention.q_attn",
+                      weight="weight",
+                      bias="bias",
                       reversed=True,
                       ignore=True,
                       replace_layer=col_nn.Linear1D_Col)
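One detail worth noting in the hunk above is `n_cast` (3 for the fused QKV projection, 2 for the cross-attention c_attn). A plausible reading, sketched below under that assumption (not confirmed by this diff alone), is that a column-parallel shard must slice each fused projection separately rather than slicing the packed weight once:

```python
import torch

def shard_fused_columns(weight: torch.Tensor, n_cast: int, rank: int, world_size: int) -> torch.Tensor:
    # Assumed semantics of n_cast: GPT-2's c_attn weight packs several projections
    # (Q,K,V -> n_cast=3; cross-attention K,V -> n_cast=2), so each fused chunk is
    # split across ranks individually and this rank's slices are re-fused.
    chunks = weight.chunk(n_cast, dim=-1)                          # split into the fused projections
    shards = [c.chunk(world_size, dim=-1)[rank] for c in chunks]   # slice each projection per rank
    return torch.cat(shards, dim=-1)

# Toy check: hidden size 8, fused QKV weight of shape (8, 24), 2 tensor-parallel ranks.
w = torch.arange(8 * 24, dtype=torch.float32).reshape(8, 24)
print(shard_fused_columns(w, n_cast=3, rank=0, world_size=2).shape)  # torch.Size([8, 12])
```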
@@ -61,12 +64,14 @@ class GPT2Policy(Policy):
     @staticmethod
     def attn_out() -> List:
         return [
-            Row_Layer(weight="attn.c_proj.weight",
-                      bias="attn.c_proj.bias",
+            Row_Layer(suffix="attn.c_proj",
+                      weight="weight",
+                      bias="bias",
                       reversed=True,
                       replace_layer=col_nn.Linear1D_Row),
-            Row_Layer(weight="crossattention.c_proj.weight",
-                      bias="crossattention.c_proj.bias",
+            Row_Layer(suffix="crossattention.c_proj",
+                      weight="weight",
+                      bias="bias",
                       reversed=True,
                       ignore=True,
                       replace_layer=col_nn.Linear1D_Row)
@@ -75,21 +80,23 @@ class GPT2Policy(Policy):
     @staticmethod
     def mlp_in() -> List:
         return [
-            Col_Layer(weight="mlp.c_fc.weight", bias="mlp.c_fc.bias", reversed=True, replace_layer=col_nn.Linear1D_Col),
+            Col_Layer(suffix="mlp.c_fc", weight="weight", bias="bias", reversed=True,
+                      replace_layer=col_nn.Linear1D_Col),
         ]

     @staticmethod
     def mlp_out() -> List:
         return [
-            Row_Layer(weight="mlp.c_proj.weight",
-                      bias="mlp.c_proj.bias",
+            Row_Layer(suffix="mlp.c_proj",
+                      weight="weight",
+                      bias="bias",
                       reversed=True,
                       replace_layer=col_nn.Linear1D_Row)
         ]

     @staticmethod
     def embedding() -> List:
-        return [Col_Layer(weight="wte.weight", replace_layer=col_nn.VocabParallelEmbedding1D)]
+        return [Col_Layer(suffix="wte", weight="weight", replace_layer=col_nn.VocabParallelEmbedding1D)]


 from transformers import GPT2LMHeadModel
@@ -111,8 +118,9 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
     @staticmethod
     def unembedding() -> List:
         return [
-            Col_Layer(weight="lm_head.weight",
-                      bias="lm_head.bias",
+            Col_Layer(suffix="lm_head",
+                      weight="weight",
+                      bias="bias",
                       replace_layer=col_nn.Linear1D_Col,
                       gather_output=True)
         ]
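Across all the hunks the pattern is the same: `suffix` locates one sub-module and `weight`/`bias` name attributes on it. One plausible way a sharder could consume such a description is sketched below; `resolve` and `TinyBlock` are hypothetical helpers for illustration, not ColossalAI's implementation:

```python
import torch.nn as nn

def resolve(block: nn.Module, suffix: str, attr: str) -> nn.Parameter:
    # Look the sub-module up once, then fetch the named attribute from it.
    module = block.get_submodule(suffix)   # e.g. "attn.c_attn" relative to a GPT-2 block
    return getattr(module, attr)           # e.g. its "weight" or "bias" parameter

# Toy module tree standing in for a GPT-2 block.
class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.Module()
        self.attn.c_attn = nn.Linear(8, 24)

print(resolve(TinyBlock(), "attn.c_attn", "weight").shape)  # torch.Size([24, 8])
```

Compared with the old full-path strings, the sub-module found this way can be reused both for reading parameters and for swapping in the `replace_layer`, which is presumably the motivation for moving the policy API to suffix-based descriptions.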