diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py
index e54804fc5..2952dcc81 100644
--- a/tests/test_zero/test_gemini/test_inference.py
+++ b/tests/test_zero/test_gemini/test_inference.py
@@ -10,7 +10,7 @@ import colossalai
 from colossalai.accelerator import get_accelerator
 from colossalai.legacy.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import DummyDataloader, clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils import set_seed
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
@@ -53,6 +53,7 @@ def single_chunk_init(model: torch.nn.Module, placement_config: dict):
     return model
 
 
+@clear_cache_before_run()
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("model_init_func", [single_chunk_init, multi_chunk_init])
@@ -104,6 +105,7 @@ def exam_inference(placement_config: dict, model_name: str, model_init_func: Cal
     train_iter()
     inference_iter()
     train_iter()
+    torch.cuda.empty_cache()
 
 
 def run_dist(rank, world_size, port):