mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[Fix] Fix Inference Example, Tests, and Requirements (#5688)
* clean requirements
* modify example inference struct
* add test ci scripts
* mark test_infer as submodule
* rm deprecated cls & deps
* import of HAS_FLASH_ATTN
* prune inference tests to be run
* prune triton kernel tests
* increment pytest timeout mins
* revert import path in openmoe
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import pytest
|
||||
import torch
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
||||
@@ -7,6 +8,7 @@ from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotar
|
||||
from colossalai.inference.modeling.layers.attention import PagedAttention, convert_kvcache, copy_to_cache
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This test is not used in the current version.")
|
||||
def test_copy_to_cache():
|
||||
key = torch.ones((2, 11, 3, 3))
|
||||
key[0, 9, :, :] = 0
|
||||
@@ -24,6 +26,7 @@ def test_copy_to_cache():
|
||||
assert cache[3, 0, 0, 0] == 1
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This test is not used in the current version.")
|
||||
def test_convert_kvcache():
|
||||
cache = torch.ones(8, 3, 8, 3)
|
||||
key = torch.ones(2, 1, 3, 3) + 1
|
||||
@@ -34,6 +37,7 @@ def test_convert_kvcache():
|
||||
assert converted_cache.shape == (2, 10, 3, 3)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This test is not used in the current version.")
|
||||
def test_context_attention():
|
||||
"""
|
||||
test config: head_num = 4, head_size = 4
|
||||
@@ -86,6 +90,7 @@ def test_context_attention():
|
||||
assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This test is not used in the current version.")
|
||||
def test_decoding_attention():
|
||||
# test the pipeline of decoding attention
|
||||
attn = PagedAttention()
|
||||
|
Reference in New Issue
Block a user