mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[Inference] Add the logic of the inference engine (#5173)
* add infer_struct and infer_config * update codes * change InferConfig * Add hf_model_config to the engine * rm _get_hf_model_config * update codes * made adjustments according to the feedback from the reviewer. * update codes * add ci test for config and struct * Add the logic of the inference engine * update engine and test * Recover cache_manager.py * add logger * fix conflict * update codes * update codes * update model and tokenizer * fix add the logic about shardformer * change kvcache_manager docstring * add policy * fix ci bug in test_kvcache_manager.py * remove codes related to tokenizer and move model_policy * fix code style * add ordered_set to requirements-infer.txt * Delete extra empty lines * add ordered_set to requirements-test.txt
This commit is contained in:
committed by
FrankLeeeee
parent
93aeacca34
commit
8daee26989
70
tests/test_infer/test_config_and_struct.py
Normal file → Executable file
70
tests/test_infer/test_config_and_struct.py
Normal file → Executable file
@@ -1,26 +1,45 @@
|
||||
import pytest
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.struct import BatchInfo, RequsetStatus, Sequence
|
||||
from colossalai.inference.struct import BatchInfo, Sequence
|
||||
from colossalai.testing import spawn
|
||||
|
||||
|
||||
def test_config_and_inferenceData():
|
||||
config = InferenceConfig("/llama")
|
||||
assert config.max_batch_size
|
||||
def check_config_and_inference():
|
||||
config = InferenceConfig()
|
||||
assert config.max_batch_size == 8
|
||||
sequence = Sequence(
|
||||
request_id=1,
|
||||
prompt="abc",
|
||||
token_id=[1, 2, 3],
|
||||
input_token_id=[1, 2, 3],
|
||||
block_size=16,
|
||||
sample_params=None,
|
||||
block_table_index=1,
|
||||
block_table=None,
|
||||
eos_token_id=2,
|
||||
max_output_len=256,
|
||||
)
|
||||
|
||||
sequence2 = Sequence(
|
||||
request_id=2,
|
||||
prompt="bcd",
|
||||
token_id=[4, 5, 6],
|
||||
input_token_id=[4, 5, 6],
|
||||
block_size=16,
|
||||
sample_params=None,
|
||||
block_table_index=2,
|
||||
block_table=None,
|
||||
eos_token_id=2,
|
||||
max_output_len=256,
|
||||
)
|
||||
|
||||
sequence3 = Sequence(
|
||||
request_id=3,
|
||||
prompt="efg",
|
||||
input_token_id=[7, 8, 9],
|
||||
block_size=16,
|
||||
sample_params=None,
|
||||
block_table=None,
|
||||
eos_token_id=2,
|
||||
max_output_len=256,
|
||||
)
|
||||
|
||||
assert sequence.get_sentence_len() == 3
|
||||
@@ -29,15 +48,34 @@ def test_config_and_inferenceData():
|
||||
assert sequence.check_finish() == False
|
||||
|
||||
batch = BatchInfo.init_batch([sequence])
|
||||
assert batch.block_table[sequence.request_id] == sequence.block_table_index
|
||||
sequence.status = RequsetStatus.COMPLETED
|
||||
batch.fliter_batch()
|
||||
assert batch.block_table == {}
|
||||
batch.add_seqs([sequence2])
|
||||
assert batch.block_table[sequence2.request_id] == sequence2.block_table_index
|
||||
batch.add_seqs([sequence2, sequence3])
|
||||
batch.add_seqs([sequence])
|
||||
|
||||
assert batch.is_empty() == False
|
||||
assert batch.get_batch_size() == 3
|
||||
batch.update_batch_tokens([1, 2, 3])
|
||||
seq = batch.abort_seq(sequence)
|
||||
seq2 = batch.fliter_batch()[0]
|
||||
|
||||
assert batch.get_batch_size() == 1
|
||||
assert seq.get_output_len() == 1
|
||||
assert seq.output_token_id == [1]
|
||||
assert seq2.get_output_len() == 1
|
||||
assert seq2.output_token_id == [2]
|
||||
|
||||
batch.clear_batch()
|
||||
assert batch.block_table == {}
|
||||
assert batch.is_empty() == True
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
    """Worker entry point: launch ColossalAI, then run the config/struct checks.

    Args:
        rank: forwarded to ``colossalai.launch`` as ``rank``.
        world_size: forwarded to ``colossalai.launch`` as ``world_size``.
        port: forwarded to ``colossalai.launch`` as ``port``; the host is
            fixed to ``"localhost"``.
    """
    # Gather the launch arguments first so the call itself stays short.
    launch_kwargs = dict(
        config={},
        rank=rank,
        world_size=world_size,
        port=port,
        host="localhost",
    )
    colossalai.launch(**launch_kwargs)
    check_config_and_inference()
|
||||
|
||||
|
||||
@pytest.mark.dist
def test_config_and_inference():
    """Run ``run_dist`` through ``spawn`` with a process count of 1."""
    num_procs = 1
    spawn(run_dist, num_procs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_config_and_inferenceData()
|
||||
test_config_and_inference()
|
||||
|
Reference in New Issue
Block a user