diff --git a/tests/test_zero_data_parallel/test_init_context.py b/tests/test_zero_data_parallel/test_init_context.py index b3f214d98..9a0b72a10 100644 --- a/tests/test_zero_data_parallel/test_init_context.py +++ b/tests/test_zero_data_parallel/test_init_context.py @@ -33,11 +33,11 @@ def run_dist(rank, world_size, port): @pytest.mark.dist -def test_zero_init_context(): - world_size = 2 +@pytest.mark.parametrize("world_size", [1, 2, 4]) +def test_zero_init_context(world_size): run_func = partial(run_dist, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) if __name__ == '__main__': - test_zero_init_context() + test_zero_init_context(2)