diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py index 2952dcc81..902745e9e 100644 --- a/tests/test_zero/test_gemini/test_inference.py +++ b/tests/test_zero/test_gemini/test_inference.py @@ -54,6 +54,7 @@ def single_chunk_init(model: torch.nn.Module, placement_config: dict): @rerun_if_address_is_in_use() +@clear_cache_before_run() @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("model_name", ["transformers_gpt_lm"]) @parameterize("model_init_func", [single_chunk_init, multi_chunk_init]) @@ -114,7 +115,6 @@ def run_dist(rank, world_size, port): @pytest.mark.dist -@clear_cache_before_run() @pytest.mark.parametrize("world_size", [1, 4]) def test_inference(world_size): spawn(run_dist, world_size)