mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 06:30:41 +00:00
[shardformer] add gpt2 test and layer class refactor (#4041)
* add gpt2 test and layer class refactor * add dropout in gpt2 policy
This commit is contained in:
@@ -108,7 +108,7 @@ def check_bert(rank, world_size, port):
|
||||
backward_lsit = [BertForMaskedLM, BertLMHeadModel]
|
||||
|
||||
for model_fn in forward_list:
|
||||
org_model, sharded_model = build_model(model_fn)
|
||||
org_model, sharded_model = build_model(world_size, model_fn)
|
||||
check_forward(org_model, sharded_model)
|
||||
|
||||
if model_fn in backward_lsit:
|
||||
|
Reference in New Issue
Block a user