From 067e18f7e98a8a47de350bb8af4d6e52aa42e78d Mon Sep 17 00:00:00 2001
From: hxwang
Date: Mon, 22 Jul 2024 05:36:20 +0000
Subject: [PATCH] [test] fix test: test_zero1_2

---
 colossalai/zero/low_level/low_level_optim.py              | 2 +-
 tests/test_shardformer/test_model/test_shard_deepseek.py  | 4 ++--
 tests/test_shardformer/test_model/test_shard_mixtral.py   | 4 ++--
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 947dec51b..51d7d1eaa 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -880,7 +880,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             return None
         grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device)
         dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg)
-        return grad_flat[: working_param.numel()].reshape_as(working_param)
+        return grad_flat.view(-1)[: working_param.numel()].view_as(working_param)
 
     def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]:
         working_grads = []
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek.py b/tests/test_shardformer/test_model/test_shard_deepseek.py
index 16513b2f5..c301777f2 100644
--- a/tests/test_shardformer/test_model/test_shard_deepseek.py
+++ b/tests/test_shardformer/test_model/test_shard_deepseek.py
@@ -179,9 +179,9 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4])
 @rerun_if_address_is_in_use()
-def test_mistral(world_size):
+def test_deepseek(world_size):
     spawn(run_dist, world_size)
 
 
 if __name__ == "__main__":
-    test_mistral(world_size=8)
+    test_deepseek(world_size=4)
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index 2b8623e13..419679797 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -180,9 +180,9 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4])
 @rerun_if_address_is_in_use()
-def test_mistral(world_size):
+def test_mixtral(world_size):
     spawn(run_dist, world_size)
 
 
 if __name__ == "__main__":
-    test_mistral(world_size=8)
+    test_mixtral(world_size=4)
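
Context for the low_level_optim.py change: grad_flat is allocated with shape
(grad_store.world_size, *grad.shape), so the old expression
grad_flat[: working_param.numel()] slices along the world-size dimension
(whole rows) rather than over flattened elements, and the following
reshape_as raises unless the element counts happen to match, i.e. unless the
gathered shards carry no padding. Flattening first makes the slice drop
exactly the ZeRO padding. Below is a minimal sketch of the failure mode; the
sizes are made-up example values and a repeat() stands in for
dist.all_gather_into_tensor, so this is an illustration, not the optimizer's
actual code path.

    import torch

    # Hypothetical sizes: 4 ranks each hold a 2-element padded shard of a
    # 6-element working parameter (8 gathered values, 2 of them padding).
    world_size = 4
    working_param = torch.zeros(2, 3)       # 6 elements
    grad = torch.randn(2)                   # this rank's gradient shard
    grad_flat = grad.repeat(world_size, 1)  # shape (4, 2), as after all-gather

    # Old code: slicing rows keeps all 4 rows (8 values), so reshaping to
    # the 6-element parameter shape raises a RuntimeError.
    try:
        grad_flat[: working_param.numel()].reshape_as(working_param)
    except RuntimeError as e:
        print("old slicing fails:", e)

    # Fixed code: flatten, take exactly numel() elements, drop the padding.
    full_grad = grad_flat.view(-1)[: working_param.numel()].view_as(working_param)
    assert full_grad.shape == working_param.shape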