From 964f9a7974b59fe72c1fdcce46472530d604d5c2 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 10 Apr 2025 02:20:40 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../test_kernels/cuda/test_flash_decoding_attention.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py b/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py
index c4267d49f..d656c4834 100644
--- a/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py
+++ b/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py
@@ -11,6 +11,7 @@ from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generat
 
 inference_ops = InferenceOpsLoader().load()
 
+from colossalai.testing import clear_cache_before_run
 from tests.test_infer.test_kernels.triton.kernel_utils import (
     convert_kv_unpad_to_padded,
     create_attention_mask,
@@ -18,7 +19,6 @@ from tests.test_infer.test_kernels.triton.kernel_utils import (
     generate_caches_and_block_tables_vllm,
     torch_attn_ref,
 )
-from colossalai.testing import clear_cache_before_run
 
 q_len = 1
 PARTITION_SIZE = 512
@@ -56,6 +56,7 @@ def numpy_allclose(x, y, rtol, atol):
     np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol)
 
 
+
 @clear_cache_before_run()
 @pytest.mark.parametrize("BATCH_SIZE", [1, 4, 7, 32])
 @pytest.mark.parametrize("BLOCK_SIZE", [8, 16, 32])
@@ -197,6 +198,7 @@ except ImportError:
     HAS_VLLM = False
     print("The subsequent test requires vllm. Please refer to https://github.com/vllm-project/vllm")
 
+
 @clear_cache_before_run()
 @pytest.mark.skipif(not HAS_VLLM, reason="requires vllm")
 @pytest.mark.parametrize("BATCH_SIZE", [1, 7, 32])
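
The hooks in this fix only relocate the clear_cache_before_run import so it sits with the other
first-party imports in sorted order, and adjust the blank lines in front of the decorated
top-level tests. A minimal sketch of the resulting pattern is below; the test name and the
parametrize values are illustrative assumptions, not copied from the real file.

    import pytest

    from colossalai.testing import clear_cache_before_run


    @clear_cache_before_run()
    @pytest.mark.parametrize("BATCH_SIZE", [1, 4])  # illustrative values, not the file's full matrix
    def test_flash_decoding_attention(BATCH_SIZE):
        # clear_cache_before_run() wraps the test so cached state (for example the CUDA
        # allocator cache) is intended to be cleared before each parametrized case executes.
        assert BATCH_SIZE > 0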