[Fix] Fix Inference Example, Tests, and Requirements (#5688)

* clean requirements

* modify example inference struct

* add test ci scripts

* mark test_infer as submodule

* rm deprecated cls & deps

* import of HAS_FLASH_ATTN (see the import-guard sketch after this list)

* prune inference tests to be run

* prune triton kernel tests

* increment pytest timeout mins

* revert import path in openmoe
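
The `HAS_FLASH_ATTN` item above refers to an optional-dependency guard. As a rough sketch of that pattern only (the module that actually defines the flag in ColossalAI is not reproduced here, and `attention_backend` is a hypothetical helper), such a flag is usually set by a guarded import and then branched on instead of importing `flash_attn` directly:

```python
# Sketch of an optional-dependency guard; illustrative only, not the exact
# ColossalAI definition of HAS_FLASH_ATTN.
try:
    import flash_attn  # noqa: F401

    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False


def attention_backend() -> str:
    # Hypothetical helper: callers branch on the flag, so machines without the
    # flash-attn wheel can still import this module and fall back cleanly.
    return "flash-attn" if HAS_FLASH_ATTN else "torch-sdpa"
```

Tests that need the dependency can then be gated with `pytest.mark.skipif(not HAS_FLASH_ATTN, reason=...)`, so CI runners without the wheel skip them instead of failing at import time.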
Author: Yuanheng Zhao
Date: 2024-05-08 11:30:15 +08:00
Committed by: GitHub
Parent: f9afe0addd
Commit: 55cc7f3df7
23 changed files with 46 additions and 328 deletions


@@ -1,3 +1,4 @@
+import pytest
 import torch
 from transformers.cache_utils import DynamicCache
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
@@ -7,6 +8,7 @@ from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotar
 from colossalai.inference.modeling.layers.attention import PagedAttention, convert_kvcache, copy_to_cache


+@pytest.mark.skip(reason="This test is not used in the current version.")
 def test_copy_to_cache():
     key = torch.ones((2, 11, 3, 3))
     key[0, 9, :, :] = 0
@@ -24,6 +26,7 @@ def test_copy_to_cache():
     assert cache[3, 0, 0, 0] == 1


+@pytest.mark.skip(reason="This test is not used in the current version.")
 def test_convert_kvcache():
     cache = torch.ones(8, 3, 8, 3)
     key = torch.ones(2, 1, 3, 3) + 1
@@ -34,6 +37,7 @@ def test_convert_kvcache():
     assert converted_cache.shape == (2, 10, 3, 3)


+@pytest.mark.skip(reason="This test is not used in the current version.")
 def test_context_attention():
     """
     test config: head_num = 4, head_size = 4
@@ -86,6 +90,7 @@ def test_context_attention():
     assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-3)


+@pytest.mark.skip(reason="This test is not used in the current version.")
 def test_decoding_attention():
     # test the pipeline of decoding attention
     attn = PagedAttention()
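
The hunks above import `pytest` and add `@pytest.mark.skip` decorators, so the deprecated PagedAttention tests are still collected but reported as skipped rather than executed; this keeps the pruned inference suite green without deleting the test code. A minimal, self-contained sketch of the same mechanism (the test names here are hypothetical):

```python
import pytest


@pytest.mark.skip(reason="This test is not used in the current version.")
def test_deprecated_path():
    # Collected by pytest but never executed; the reason string appears in reports.
    raise AssertionError("unreachable")


def test_active_path():
    # Unmarked tests in the same module continue to run normally.
    assert 1 + 1 == 2
```

Running `pytest -rs` on the file lists each skipped test together with its reason, so the kept-but-skipped tests stay visible in CI logs.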