[test] update shardformer tests

This commit is contained in:
ver217
2023-07-05 14:30:17 +08:00
committed by Hongxin Liu
parent b0b8ad2823
commit 2d6cc07feb
2 changed files with 3 additions and 3 deletions

View File

@@ -12,8 +12,8 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_paralle
enable_tensor_parallelism=enable_tensor_parallelism)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model = shard_former.optimize(model_copy).cuda()
return org_model, sharded_model
sharded_model, shared_params = shard_former.optimize(model_copy)
return org_model, sharded_model.cuda()
def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):