From 559f15a4c9b25e5bee2bf7a753a847a72565c018 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 26 May 2025 18:10:57 +0800 Subject: [PATCH] fix (#6328) --- tests/test_zero/test_gemini/test_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py index 2952dcc81..902745e9e 100644 --- a/tests/test_zero/test_gemini/test_inference.py +++ b/tests/test_zero/test_gemini/test_inference.py @@ -54,6 +54,7 @@ def single_chunk_init(model: torch.nn.Module, placement_config: dict): @rerun_if_address_is_in_use() +@clear_cache_before_run() @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("model_name", ["transformers_gpt_lm"]) @parameterize("model_init_func", [single_chunk_init, multi_chunk_init]) @@ -114,7 +115,6 @@ def run_dist(rank, world_size, port): @pytest.mark.dist -@clear_cache_before_run() @pytest.mark.parametrize("world_size", [1, 4]) def test_inference(world_size): spawn(run_dist, world_size)