From 914355604d50553935cf0e273bd6de8168b56e1a Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 5 Jul 2023 14:30:17 +0800 Subject: [PATCH] [test] update shardformer tests --- tests/test_shardformer/test_model/_utils.py | 4 ++-- tests/test_shardformer/test_with_torch_ddp.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index d83d9ecd3..e03014f3f 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -12,8 +12,8 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_paralle enable_tensor_parallelism=enable_tensor_parallelism) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) - sharded_model = shard_former.optimize(model_copy).cuda() - return org_model, sharded_model + sharded_model, shared_params = shard_former.optimize(model_copy) + return org_model, sharded_model.cuda() def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): diff --git a/tests/test_shardformer/test_with_torch_ddp.py b/tests/test_shardformer/test_with_torch_ddp.py index 9f8a5db6c..f29c8d6f6 100644 --- a/tests/test_shardformer/test_with_torch_ddp.py +++ b/tests/test_shardformer/test_with_torch_ddp.py @@ -44,7 +44,7 @@ def check_shardformer_with_ddp(rank, world_size, port): for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): # create and shard model model = model_fn().cuda() - sharded_model = shardformer.optimize(model) + sharded_model, _ = shardformer.optimize(model) # add ddp sharded_ddp_model = DDP(sharded_model, process_group=dp_process_group)