[Fix] Fix Inference Example, Tests, and Requirements (#5688)

* clean requirements

* modify example inference struct

* add test ci scripts

* mark test_infer as submodule

* remove deprecated classes and dependencies

* update the import of HAS_FLASH_ATTN

* prune inference tests to be run

* prune triton kernel tests

* increase the pytest timeout (in minutes)

* revert import path in openmoe
This commit is contained in:
Yuanheng Zhao
2024-05-08 11:30:15 +08:00
committed by GitHub
parent f9afe0addd
commit 55cc7f3df7
23 changed files with 46 additions and 328 deletions

View File

@@ -2,7 +2,7 @@ import pytest
import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.struct import BatchInfo, RequestStatus, Sequence
from colossalai.inference.struct import RequestStatus, Sequence
from colossalai.testing import rerun_if_address_is_in_use, spawn
@@ -20,27 +20,6 @@ def check_config_and_inference():
max_output_len=256,
)
sequence2 = Sequence(
request_id=2,
prompt="bcd",
input_token_id=[4, 5, 6],
block_size=16,
sample_params=None,
eos_token_id=2,
pad_token_id=2,
max_output_len=256,
)
sequence3 = Sequence(
request_id=3,
prompt="efg",
input_token_id=[7, 8, 9],
block_size=16,
sample_params=None,
eos_token_id=2,
pad_token_id=2,
max_output_len=256,
)
sequence.mark_running()
assert sequence.status == RequestStatus.RUNNING
sequence.recycle()
@@ -51,33 +30,6 @@ def check_config_and_inference():
assert sequence.output_len == 0
assert sequence.check_finish() == False
batch = BatchInfo(
max_batch_size=8,
kv_max_split_num=16,
num_heads=2,
head_dim=128,
)
batch.add_seqs([sequence])
batch.add_seqs([sequence2, sequence3])
# add duplicated sequence to test that it will not be counted twice
batch.add_seqs([sequence])
assert batch.is_empty == False
assert batch.get_batch_size() == 3
batch.update_batch_tokens([1, 2, 3])
seq = batch.abort_seq(sequence)
seq2 = batch.fliter_batch()[0]
assert batch.get_batch_size() == 1
assert seq.output_len == 1
assert seq.output_token_id == [1]
assert seq2.output_len == 1
assert seq2.output_token_id == [2]
batch.clear_batch()
assert batch.is_empty == True
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")