mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-25 06:52:46 +00:00
Merge branch 'upgrade-transformers' of github.com:flybird11111/ColossalAI into upgrade-transformers
This commit is contained in:
commit
5c56a7fd7b
@ -11,6 +11,7 @@ from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generat
|
|||||||
|
|
||||||
inference_ops = InferenceOpsLoader().load()
|
inference_ops = InferenceOpsLoader().load()
|
||||||
|
|
||||||
|
from colossalai.testing import clear_cache_before_run
|
||||||
from tests.test_infer.test_kernels.triton.kernel_utils import (
|
from tests.test_infer.test_kernels.triton.kernel_utils import (
|
||||||
convert_kv_unpad_to_padded,
|
convert_kv_unpad_to_padded,
|
||||||
create_attention_mask,
|
create_attention_mask,
|
||||||
@ -18,7 +19,6 @@ from tests.test_infer.test_kernels.triton.kernel_utils import (
|
|||||||
generate_caches_and_block_tables_vllm,
|
generate_caches_and_block_tables_vllm,
|
||||||
torch_attn_ref,
|
torch_attn_ref,
|
||||||
)
|
)
|
||||||
from colossalai.testing import clear_cache_before_run
|
|
||||||
|
|
||||||
q_len = 1
|
q_len = 1
|
||||||
PARTITION_SIZE = 512
|
PARTITION_SIZE = 512
|
||||||
@ -56,6 +56,7 @@ def numpy_allclose(x, y, rtol, atol):
|
|||||||
|
|
||||||
np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol)
|
np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol)
|
||||||
|
|
||||||
|
|
||||||
@clear_cache_before_run()
|
@clear_cache_before_run()
|
||||||
@pytest.mark.parametrize("BATCH_SIZE", [1, 4, 7, 32])
|
@pytest.mark.parametrize("BATCH_SIZE", [1, 4, 7, 32])
|
||||||
@pytest.mark.parametrize("BLOCK_SIZE", [8, 16, 32])
|
@pytest.mark.parametrize("BLOCK_SIZE", [8, 16, 32])
|
||||||
@ -197,6 +198,7 @@ except ImportError:
|
|||||||
HAS_VLLM = False
|
HAS_VLLM = False
|
||||||
print("The subsequent test requires vllm. Please refer to https://github.com/vllm-project/vllm")
|
print("The subsequent test requires vllm. Please refer to https://github.com/vllm-project/vllm")
|
||||||
|
|
||||||
|
|
||||||
@clear_cache_before_run()
|
@clear_cache_before_run()
|
||||||
@pytest.mark.skipif(not HAS_VLLM, reason="requires vllm")
|
@pytest.mark.skipif(not HAS_VLLM, reason="requires vllm")
|
||||||
@pytest.mark.parametrize("BATCH_SIZE", [1, 7, 32])
|
@pytest.mark.parametrize("BATCH_SIZE", [1, 7, 32])
|
||||||
|
Loading…
Reference in New Issue
Block a user