[shardformer] add linearconv1d test (#4067)

* add linearconv1d test

* add linearconv1d test
This commit is contained in:
FoolPlayer
2023-06-22 14:40:37 +08:00
committed by Frank Lee
parent 8eb09a4c69
commit 0803a61412
4 changed files with 122 additions and 34 deletions

View File

@@ -44,29 +44,23 @@ class GPT2Policy(Policy):
suffix="attn.c_attn",
target_module=col_nn.LinearConv1D_Col,
kwargs={
"n_cast": 3,
"n_fused": 3,
},
),
SubModuleReplacementDescription(
suffix="attn.c_proj",
target_module=col_nn.LinearConv1D_Row,
kwargs={
"n_cast": 1,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_fc",
target_module=col_nn.LinearConv1D_Col,
kwargs={
"n_cast": 1,
"n_fused": 1,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.LinearConv1D_Row,
kwargs={
"n_cast": 1,
},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",