[shardformer] integrate with data parallelism (#4103)

This commit is contained in:
Frank Lee
2023-06-30 09:58:08 +08:00
parent f3b6aaa6b7
commit 6a88bae4ec
11 changed files with 97 additions and 50 deletions

View File

@@ -42,7 +42,7 @@ def check_bert(rank, world_size, port):
sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(world_size, model_fn)
org_model, sharded_model = build_model(model_fn)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()