[shardformer] add gpt2 test and layer class refactor (#4041)

* add gpt2 test and layer class refactor

* add dropout in gpt2 policy
This commit is contained in:
FoolPlayer
2023-06-20 11:45:16 +08:00
committed by Frank Lee
parent d857f3dbba
commit 4021b9a8a2
14 changed files with 1400 additions and 840 deletions

View File

@@ -108,7 +108,7 @@ def check_bert(rank, world_size, port):
backward_lsit = [BertForMaskedLM, BertLMHeadModel]
for model_fn in forward_list:
-org_model, sharded_model = build_model(model_fn)
+org_model, sharded_model = build_model(world_size, model_fn)
check_forward(org_model, sharded_model)
if model_fn in backward_lsit: