mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-02 08:16:48 +00:00
fix
This commit is contained in:
parent
99298c6a6d
commit
eaef783ec3
@ -18,6 +18,7 @@ from tests.test_infer.test_kernels.triton.kernel_utils import (
|
||||
generate_caches_and_block_tables_vllm,
|
||||
torch_attn_ref,
|
||||
)
|
||||
from colossalai.testing import clear_cache_before_run
|
||||
|
||||
q_len = 1
|
||||
PARTITION_SIZE = 512
|
||||
@ -55,7 +56,7 @@ def numpy_allclose(x, y, rtol, atol):
|
||||
|
||||
np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
@pytest.mark.parametrize("BATCH_SIZE", [1, 4, 7, 32])
|
||||
@pytest.mark.parametrize("BLOCK_SIZE", [8, 16, 32])
|
||||
@pytest.mark.parametrize("MAX_NUM_BLOCKS_PER_SEQ", [1, 8, 32, 256, 512])
|
||||
@ -196,7 +197,7 @@ except ImportError:
|
||||
HAS_VLLM = False
|
||||
print("The subsequent test requires vllm. Please refer to https://github.com/vllm-project/vllm")
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
@pytest.mark.skipif(not HAS_VLLM, reason="requires vllm")
|
||||
@pytest.mark.parametrize("BATCH_SIZE", [1, 7, 32])
|
||||
@pytest.mark.parametrize("BLOCK_SIZE", [6, 32])
|
||||
|
Loading…
Reference in New Issue
Block a user