[test] update shardformer tests

This commit is contained in:
ver217
2023-07-05 14:30:17 +08:00
committed by Hongxin Liu
parent b0b8ad2823
commit 2d6cc07feb
2 changed files with 3 additions and 3 deletions

View File

@@ -44,7 +44,7 @@ def check_shardformer_with_ddp(rank, world_size, port):
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
# create and shard model
model = model_fn().cuda()
sharded_model = shardformer.optimize(model)
sharded_model, _ = shardformer.optimize(model)
# add ddp
sharded_ddp_model = DDP(sharded_model, process_group=dp_process_group)