From 5a1a095b925a78f00185954bb3783c73086dbade Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Fri, 15 Apr 2022 00:33:04 +0800
Subject: [PATCH] [test] refactored with the new rerun decorator (#763)

* [test] refactored with the new rerun decorator

* polish test case
---
 tests/test_amp/test_naive_fp16.py | 4 ++--
 tests/test_comm/test_comm.py | 4 ++--
 tests/test_context/test_hybrid_parallel.py | 4 ++--
 tests/test_data/test_data_parallel_sampler.py | 4 ++--
 tests/test_data/test_deterministic_dataloader.py | 4 ++--
 .../test_cifar_with_data_pipeline_tensor.py | 11 ++++++-----
 tests/test_engine/test_engine.py | 4 ++--
 tests/test_layers/test_1d/test_1d.py | 4 ++--
 tests/test_layers/test_2d/test_2d.py | 4 ++--
 tests/test_layers/test_2p5d/test_2p5d.py | 4 ++--
 tests/test_layers/test_3d/test_3d.py | 4 ++--
 tests/test_layers/test_sequence/test_sequence.py | 4 ++--
 tests/test_moe/test_grad_handler.py | 5 ++---
 tests/test_moe/test_kernel.py | 4 ++--
 tests/test_moe/test_moe_group.py | 4 ++--
 tests/test_moe/test_moe_zero_init.py | 4 ++--
 tests/test_moe/test_moe_zero_model.py | 4 ++--
 tests/test_moe/test_moe_zero_optim.py | 4 ++--
 .../test_trainer_with_non_pipe_schedule.py | 4 ++--
 tests/test_trainer/test_trainer_with_pipe_schedule.py | 9 ++++++---
 tests/test_utils/test_commons.py | 4 ++--
 tests/test_utils/test_gradient_accumluation.py | 5 +++--
 tests/test_utils/test_zero_gradient_clippling.py | 4 ++--
 tests/test_zero/test_found_inf.py | 4 ++--
 tests/test_zero/test_init_context.py | 4 ++--
 tests/test_zero/test_mem_collector.py | 4 ++--
 tests/test_zero/test_shard_model_v2.py | 4 ++--
 tests/test_zero/test_shard_param.py | 7 +++----
 tests/test_zero/test_sharded_optim_v2.py | 4 ++--
 tests/test_zero/test_sharded_optim_with_sync_bn.py | 4 ++--
 tests/test_zero/test_state_dict.py | 4 ++--
 tests/test_zero/test_stateful_tensor_mgr.py | 6 +++---
 tests/test_zero/test_tensor_utils.py | 2 ++
 tests/test_zero/test_zero_engine.py | 6 +++---
 34 files changed, 80 insertions(+), 75 deletions(-)

diff --git a/tests/test_amp/test_naive_fp16.py b/tests/test_amp/test_naive_fp16.py
index 02f25dc99..95c5686ae 100644
--- a/tests/test_amp/test_naive_fp16.py
+++ b/tests/test_amp/test_naive_fp16.py
@@ -3,7 +3,7 @@ import colossalai
 import torch.multiprocessing as mp
 from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
 from tests.components_to_test.registry import non_distributed_component_funcs
-from colossalai.testing import assert_close_loose, rerun_on_exception
+from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
@@ -84,7 +84,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_naive_amp():
     world_size = 1
     run_func = partial(run_dist, world_size=world_size, port=free_port())
diff --git a/tests/test_comm/test_comm.py b/tests/test_comm/test_comm.py
index 7ce56d9de..07cb67730 100644
--- a/tests/test_comm/test_comm.py
+++ b/tests/test_comm/test_comm.py
@@ -9,7 +9,7 @@ from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
 from colossalai.utils import free_port, get_current_device
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1)))
@@ -64,7 +64,7 @@ def check_layer(rank, world_size, port):
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_comm():
     world_size = 4
     run_func = partial(check_layer, world_size=world_size, port=free_port())
diff --git a/tests/test_context/test_hybrid_parallel.py b/tests/test_context/test_hybrid_parallel.py
index 4a35be7df..f311b1d2e 100644
--- a/tests/test_context/test_hybrid_parallel.py
+++ b/tests/test_context/test_hybrid_parallel.py
@@ -13,7 +13,7 @@ from colossalai.core import global_context as gpc
 from colossalai.utils import free_port
 from colossalai.context import reset_seeds
 from colossalai.global_variables import tensor_parallel_env as tp_env
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 CONFIG_PATH_LIST = list(Path(__file__).parent.glob('configs/*.py'))
@@ -141,7 +141,7 @@ def run_dist(rank, world_size, backend, port_list, host):
 @pytest.mark.cpu
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_context():
     """
     As no computation or communication is done, we can run this test on CPU.
diff --git a/tests/test_data/test_data_parallel_sampler.py b/tests/test_data/test_data_parallel_sampler.py
index a3fa4ea5f..05967f7ce 100644
--- a/tests/test_data/test_data_parallel_sampler.py
+++ b/tests/test_data/test_data_parallel_sampler.py
@@ -17,7 +17,7 @@ from torchvision import transforms
 from colossalai.context import ParallelMode, Config
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_dataloader, free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 CONFIG = Config(
     dict(
@@ -67,7 +67,7 @@ def run_data_sampler(rank, world_size, port):
 @pytest.mark.cpu
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_data_sampler():
     world_size = 4
     test_func = partial(run_data_sampler, world_size=world_size, port=free_port())
diff --git a/tests/test_data/test_deterministic_dataloader.py b/tests/test_data/test_deterministic_dataloader.py
index cff6a410a..1ae4b0091 100644
--- a/tests/test_data/test_deterministic_dataloader.py
+++ b/tests/test_data/test_deterministic_dataloader.py
@@ -17,7 +17,7 @@ from colossalai.builder import build_dataset, build_transform
 from colossalai.context import ParallelMode, Config
 from colossalai.core import global_context as gpc
 from colossalai.utils import free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 CONFIG = Config(
     dict(
@@ -79,7 +79,7 @@ def run_data_sampler(rank, world_size, port):
 @pytest.mark.cpu
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_data_sampler():
     world_size = 4
     test_func = partial(run_data_sampler, world_size=world_size, port=free_port())
diff --git a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py
index 6ac069135..ed5753d98 100644
--- a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py
+++ b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py
@@ -15,7 +15,7 @@ from colossalai.nn.loss import CrossEntropyLoss
 from colossalai.trainer import Trainer, hooks
 from colossalai.utils import free_port, get_dataloader
 from colossalai.utils.gradient_accumulation import GradAccumLrSchedulerByStep
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from model_zoo.vit import vit_tiny_patch4_32
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
@@ -23,9 +23,10 @@ from torchvision.datasets import CIFAR10
 BATCH_SIZE = 4
 NUM_EPOCHS = 60
 WARMUP_EPOCHS = 5
-CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
-              fp16=dict(mode=AMP_TYPE.NAIVE),
-              gradient_accumulation=2)
+CONFIG = dict(NUM_MICRO_BATCHES=2,
+              parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
+              fp16=dict(mode=AMP_TYPE.NAIVE),
+              gradient_accumulation=2)
 def run_trainer(rank, world_size, port):
@@ -79,7 +80,7 @@ def run_trainer(rank, world_size, port):
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_hybrid_parallel():
     world_size = 8
     run_func = partial(run_trainer, world_size=world_size, port=free_port())
diff --git a/tests/test_engine/test_engine.py b/tests/test_engine/test_engine.py
index 19dbaa0eb..fb5bd1e16 100644
--- a/tests/test_engine/test_engine.py
+++ b/tests/test_engine/test_engine.py
@@ -7,7 +7,7 @@ from colossalai.amp import AMP_TYPE
 from colossalai.core import global_context as gpc
 from colossalai.utils import free_port
 from tests.components_to_test.registry import non_distributed_component_funcs
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)),
               fp16=dict(mode=None),
@@ -56,7 +56,7 @@ def run_engine(rank, world_size, port):
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_engine():
     world_size = 2
     run_func = partial(run_engine, world_size=world_size, port=free_port())
diff --git a/tests/test_layers/test_1d/test_1d.py b/tests/test_layers/test_1d/test_1d.py
index 4ee2b9419..cbdcb1b72 100644
--- a/tests/test_layers/test_1d/test_1d.py
+++ b/tests/test_layers/test_1d/test_1d.py
@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers
 from colossalai.initialize import launch
 from colossalai.utils import free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from checks_1d.check_layer_1d import *
 CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),)
@@ -35,7 +35,7 @@ def check_layer(rank, world_size, port):
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_1d():
     world_size = 4
     run_func = partial(check_layer, world_size=world_size, port=free_port())
diff --git a/tests/test_layers/test_2d/test_2d.py b/tests/test_layers/test_2d/test_2d.py
index 0b72e3583..da235d0cf 100644
--- a/tests/test_layers/test_2d/test_2d.py
+++ b/tests/test_layers/test_2d/test_2d.py
@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from checks_2d.check_layer_2d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
                                       check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
                                       check_vocab_parallel_classifier_given_embed_weight,
@@ -55,7 +55,7 @@ def check_layer_and_operation(rank, world_size, port):
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_2d():
     world_size = 4
     run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
diff --git a/tests/test_layers/test_2p5d/test_2p5d.py b/tests/test_layers/test_2p5d/test_2p5d.py
index eef5f5beb..365e2d934 100644
--- a/tests/test_layers/test_2p5d/test_2p5d.py
+++ b/tests/test_layers/test_2p5d/test_2p5d.py
@@ -7,7 +7,7 @@ from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from checks_2p5d.check_layer_2p5d import *
 from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB
@@ -51,7 +51,7 @@ def check_layer_and_operation(rank, world_size, port):
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_2p5d():
     world_size = 4
     run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
diff --git a/tests/test_layers/test_3d/test_3d.py b/tests/test_layers/test_3d/test_3d.py
index c19d550cc..7962d7dca 100644
--- a/tests/test_layers/test_3d/test_3d.py
+++ b/tests/test_layers/test_3d/test_3d.py
@@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from checks_3d.check_layer_3d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
                                       check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
                                       check_vocab_parallel_classifier_given_embed_weight,
@@ -51,7 +51,7 @@ def check_layer_and_operation(rank, world_size, port):
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_3d():
     world_size = 8
     run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
diff --git a/tests/test_layers/test_sequence/test_sequence.py b/tests/test_layers/test_sequence/test_sequence.py
index b06d99870..3862c4ccd 100644
--- a/tests/test_layers/test_sequence/test_sequence.py
+++ b/tests/test_layers/test_sequence/test_sequence.py
@@ -7,7 +7,7 @@ import pytest
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from functools import partial
 CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence')))
@@ -132,7 +132,7 @@ def run_test(rank, world_size):
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_sequence():
     world_size = 4
     run_func = partial(run_test, world_size=world_size)
diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py
index 462e2694e..b2770f64d 100644
--- a/tests/test_moe/test_grad_handler.py
+++ b/tests/test_moe/test_grad_handler.py
@@ -10,8 +10,7 @@ from colossalai.nn.layer.moe import Top1Router, UniformNoiseGenerator, MoeLayer,
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils.moe import sync_moe_model_param
 from colossalai.engine.gradient_handler import MoeGradientHandler
-from colossalai.testing import assert_equal_in_group
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use
 BATCH_SIZE = 4
 DIM = 16
@@ -63,7 +62,7 @@ def run_test(rank, world_size, port):
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_grad_handler():
     world_size = 4
     run_func = partial(run_test, world_size=world_size, port=free_port())
diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py
index 3e18ab1f6..e5b5aa68d 100644
--- a/tests/test_moe/test_kernel.py
+++ b/tests/test_moe/test_kernel.py
@@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc
 from colossalai.utils import free_port, get_current_device
 from colossalai.nn.layer.moe import Top1Router, Top2Router, MoeLayer, Experts
 from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 BATCH_SIZE = 16
 NUM_EXPERTS = 4
@@ -87,7 +87,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
 @pytest.mark.parametrize("hidden_size", [32, 144])
 @pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
 @pytest.mark.parametrize("router", [Top1Router, Top2Router])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_moe_kernel(rs, hidden_size, data_type, router):
     world_size = 4
     run_func = partial(run_routing,
diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py
index 03d81df76..3126f59e2 100644
--- a/tests/test_moe/test_moe_group.py
+++ b/tests/test_moe/test_moe_group.py
@@ -8,7 +8,7 @@ from colossalai.utils import free_port, get_current_device
 from colossalai.nn.layer.moe import Experts
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils.moe import sync_moe_model_param
-from colossalai.testing import assert_equal_in_group, rerun_on_exception
+from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use
 D_MODEL = 4
 D_FF = 8
@@ -60,7 +60,7 @@ def run_test(rank, port):
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_moe_initialization():
     world_size = 4
     run_func = partial(run_test, port=free_port())
diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py
index 50963e641..b6bc08006 100644
--- a/tests/test_moe/test_moe_zero_init.py
+++ b/tests/test_moe/test_moe_zero_init.py
@@ -14,7 +14,7 @@ from colossalai.nn.layer import MoeModule
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import get_current_device
 from tests.test_zero.common import CONFIG
@@ -91,7 +91,7 @@ def _run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [2, 4])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_moe_zero_init(world_size):
     run_func = partial(_run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py
index 945f8ba3c..778bf6d26 100644
--- a/tests/test_moe/test_moe_zero_model.py
+++ b/tests/test_moe/test_moe_zero_model.py
@@ -4,7 +4,7 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -65,7 +65,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_moe_zero_model(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py
index b0562da7c..08a36cb36 100644
--- a/tests/test_moe/test_moe_zero_optim.py
+++ b/tests/test_moe/test_moe_zero_optim.py
@@ -6,7 +6,7 @@ import torch
 import torch.multiprocessing as mp
 from colossalai.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import CPUAdam
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -120,7 +120,7 @@ def _run_dist(rank, world_size, port):
 # use_cpuadam = True can be used with cpu_offload = False
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_moe_zero_optim(world_size):
     run_func = partial(_run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
diff --git a/tests/test_trainer/test_trainer_with_non_pipe_schedule.py b/tests/test_trainer/test_trainer_with_non_pipe_schedule.py
index 54b6eba6e..b01343329 100644
--- a/tests/test_trainer/test_trainer_with_non_pipe_schedule.py
+++ b/tests/test_trainer/test_trainer_with_non_pipe_schedule.py
@@ -9,7 +9,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.trainer import Trainer
 from colossalai.utils import MultiTimer, free_port
 from tests.components_to_test.registry import non_distributed_component_funcs
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 BATCH_SIZE = 4
 IMG_SIZE = 32
@@ -51,7 +51,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_trainer_no_pipeline():
     world_size = 4
     run_func = partial(run_dist, world_size=world_size, port=free_port())
diff --git a/tests/test_trainer/test_trainer_with_pipe_schedule.py b/tests/test_trainer/test_trainer_with_pipe_schedule.py
index 593c572e2..3698526a8 100644
--- a/tests/test_trainer/test_trainer_with_pipe_schedule.py
+++ b/tests/test_trainer/test_trainer_with_pipe_schedule.py
@@ -17,13 +17,16 @@ from torch.optim import Adam
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
 from torchvision.models import resnet18
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 BATCH_SIZE = 4
 IMG_SIZE = 32
 NUM_EPOCHS = 200
-CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2),)
+CONFIG = dict(
+    NUM_MICRO_BATCHES=2,
+    parallel=dict(pipeline=2),
+)
 def run_trainer_with_pipeline(rank, world_size, port):
@@ -85,7 +88,7 @@ def run_trainer_with_pipeline(rank, world_size, port):
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_trainer_with_pipeline():
     world_size = 4
     run_func = partial(run_trainer_with_pipeline, world_size=world_size, port=free_port())
diff --git a/tests/test_utils/test_commons.py b/tests/test_utils/test_commons.py
index d588f42d9..a193d9d12 100644
--- a/tests/test_utils/test_commons.py
+++ b/tests/test_utils/test_commons.py
@@ -1,6 +1,6 @@
 from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
 from colossalai.utils import free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.zero.sharded_param import ShardedTensor
 import colossalai
@@ -35,7 +35,7 @@ def run_tensor_move(rank):
     assert (tgt_t.device.type == 'cpu')
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_tensor_move():
     mp.spawn(run_tensor_move, nprocs=1)
diff --git a/tests/test_utils/test_gradient_accumluation.py b/tests/test_utils/test_gradient_accumluation.py
index eddfc421e..7f5ee47be 100644
--- a/tests/test_utils/test_gradient_accumluation.py
+++ b/tests/test_utils/test_gradient_accumluation.py
@@ -3,6 +3,7 @@ from functools import partial
 from pathlib import Path
 import colossalai
+from colossalai.testing.utils import rerun_if_address_is_in_use
 import pytest
 import torch
 import torch.multiprocessing as mp
@@ -10,7 +11,7 @@ import torch.nn as nn
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import free_port, get_dataloader
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from torch.optim import Adam
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
@@ -87,7 +88,7 @@ def run_no_pipeline(rank, world_size, port):
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_engine():
     world_size = 4
     func = partial(run_no_pipeline, world_size=world_size, port=free_port())
diff --git a/tests/test_utils/test_zero_gradient_clippling.py b/tests/test_utils/test_zero_gradient_clippling.py
index 3f299afe8..8bdae8846 100644
--- a/tests/test_utils/test_zero_gradient_clippling.py
+++ b/tests/test_utils/test_zero_gradient_clippling.py
@@ -16,7 +16,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
 from functools import partial
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 def checkpoint_wrapper(module, enable=True):
@@ -102,7 +102,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_zero_clip_grad():
     world_size = 4
     run_func = partial(run_dist, world_size=world_size, port=free_port())
diff --git a/tests/test_zero/test_found_inf.py b/tests/test_zero/test_found_inf.py
index 897038355..34283f501 100644
--- a/tests/test_zero/test_found_inf.py
+++ b/tests/test_zero/test_found_inf.py
@@ -6,7 +6,7 @@ import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import BucketTensorShardStrategy
@@ -62,7 +62,7 @@ def _run_dist(rank, world_size, port):
 # use_cpuadam = True can be used with cpu_offload = False
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_found_inf(world_size):
     run_func = partial(_run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
diff --git a/tests/test_zero/test_init_context.py b/tests/test_zero/test_init_context.py
index cbbc6a7f3..c3eb5067b 100644
--- a/tests/test_zero/test_init_context.py
+++ b/tests/test_zero/test_init_context.py
@@ -8,7 +8,7 @@ import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.logging import get_dist_logger
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.memory_tracer.model_data_memtracer import \
@@ -64,7 +64,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 4])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_zero_init_context(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
diff --git a/tests/test_zero/test_mem_collector.py b/tests/test_zero/test_mem_collector.py
index 62c367701..d528bdfc6 100644
--- a/tests/test_zero/test_mem_collector.py
+++ b/tests/test_zero/test_mem_collector.py
@@ -10,7 +10,7 @@ from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.shard_utils import BucketTensorShardStrategy
 from colossalai.utils import free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from functools import partial
@@ -64,7 +64,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_mem_collector(world_size=2):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
diff --git a/tests/test_zero/test_shard_model_v2.py b/tests/test_zero/test_shard_model_v2.py
index 1f46883a8..654c82a46 100644
--- a/tests/test_zero/test_shard_model_v2.py
+++ b/tests/test_zero/test_shard_model_v2.py
@@ -7,7 +7,7 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -59,7 +59,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_shard_model_v2(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
diff --git a/tests/test_zero/test_shard_param.py b/tests/test_zero/test_shard_param.py
index 91c669af3..4df5f3400 100644
--- a/tests/test_zero/test_shard_param.py
+++ b/tests/test_zero/test_shard_param.py
@@ -5,12 +5,11 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.testing import parameterize
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_param import ShardedTensor
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
-from colossalai.testing import rerun_on_exception
 from tests.test_zero.common import CONFIG, allclose
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
@@ -37,7 +36,7 @@ def _run_shard_tensor(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_shard_tensor(world_size):
     run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
@@ -85,7 +84,7 @@ def _run_shard_param_v2(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_shard_param_v2(world_size):
     run_func = partial(_run_shard_param_v2, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
diff --git a/tests/test_zero/test_sharded_optim_v2.py b/tests/test_zero/test_sharded_optim_v2.py
index 2e94df7de..2b42a7128 100644
--- a/tests/test_zero/test_sharded_optim_v2.py
+++ b/tests/test_zero/test_sharded_optim_v2.py
@@ -8,7 +8,7 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import CPUAdam
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -105,7 +105,7 @@ def _run_dist(rank, world_size, port):
 # use_cpuadam = True can be used with cpu_offload = False
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_sharded_optim_v2(world_size):
     run_func = partial(_run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
diff --git a/tests/test_zero/test_sharded_optim_with_sync_bn.py b/tests/test_zero/test_sharded_optim_with_sync_bn.py
index 7a52e437a..ea5b31518 100644
--- a/tests/test_zero/test_sharded_optim_with_sync_bn.py
+++ b/tests/test_zero/test_sharded_optim_with_sync_bn.py
@@ -10,7 +10,7 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import TensorShardStrategy
@@ -71,7 +71,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_sharded_optim_with_sync_bn():
     """
     This test is to make sure that buffers are synchronized between ranks
diff --git a/tests/test_zero/test_state_dict.py b/tests/test_zero/test_state_dict.py
index 818d05ffd..188bc5968 100644
--- a/tests/test_zero/test_state_dict.py
+++ b/tests/test_zero/test_state_dict.py
@@ -8,7 +8,7 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -49,7 +49,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_zero_state_dict(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
diff --git a/tests/test_zero/test_stateful_tensor_mgr.py b/tests/test_zero/test_stateful_tensor_mgr.py
index 6449c285f..93ef3af8e 100644
--- a/tests/test_zero/test_stateful_tensor_mgr.py
+++ b/tests/test_zero/test_stateful_tensor_mgr.py
@@ -10,7 +10,7 @@ from colossalai.gemini import StatefulTensorMgr
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.zero.sharded_param.tensorful_state import TensorState
 from colossalai.utils import free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from torch.nn.parameter import Parameter
 from typing import List
 from functools import partial
@@ -120,8 +120,8 @@ def run_dist(rank, world_size, port):
     run_stm()
-@pytest.mark.gpu
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
 def test_stateful_tensor_manager(world_size=1):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
diff --git a/tests/test_zero/test_tensor_utils.py b/tests/test_zero/test_tensor_utils.py
index d3b50d2dc..0b4201fe5 100644
--- a/tests/test_zero/test_tensor_utils.py
+++ b/tests/test_zero/test_tensor_utils.py
@@ -6,6 +6,7 @@ from colossalai.zero.sharded_param import (StatefulTensor, colo_tensor_mem_usage
                                            colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu,
                                            colo_model_tensor_clone)
 from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use
 import torch
@@ -84,6 +85,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4, 5])
+@rerun_if_address_is_in_use()
 def test_zero_tensor_utils(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
diff --git a/tests/test_zero/test_zero_engine.py b/tests/test_zero/test_zero_engine.py
index 82e0f69ca..08910e8f5 100644
--- a/tests/test_zero/test_zero_engine.py
+++ b/tests/test_zero/test_zero_engine.py
@@ -9,7 +9,7 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.core import global_context as gpc
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
@@ -96,7 +96,7 @@ def run_dist(rank, world_size, port, parallel_config):
 @pytest.mark.skip
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [2, 4])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_mp_engine(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=MP_PARALLEL_CONFIG)
     mp.spawn(run_func, nprocs=world_size)
@@ -104,7 +104,7 @@ def test_mp_engine(world_size):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_zero_engine(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=ZERO_PARALLEL_CONFIG)
     mp.spawn(run_func, nprocs=world_size)
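Every call site in this patch replaces the verbose
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
with the shorthand @rerun_if_address_is_in_use(). A minimal sketch of what such a convenience
wrapper can look like, assuming only the rerun_on_exception(exception_type=..., pattern=...)
factory visible in the removed lines; the actual helper shipped in colossalai.testing may be
implemented differently:

    # Hypothetical sketch, not the real colossalai.testing implementation.
    import torch.multiprocessing as mp

    from colossalai.testing import rerun_on_exception


    def rerun_if_address_is_in_use():
        """Rerun a spawned distributed test whose rendezvous port is already taken."""
        return rerun_on_exception(exception_type=mp.ProcessRaisedException,
                                  pattern=".*Address already in use.*")

It is applied exactly as at the new call sites above, stacked under @pytest.mark.dist on the
test function, so each test keeps the retry-on-port-collision behaviour without repeating the
exception type and regex at every site.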