[Inference] Add the logic of the inference engine (#5173)

* add infer_struct and infer_config

* update code

* change InferConfig

* Add hf_model_config to the engine

* rm _get_hf_model_config

* update code

* make adjustments according to reviewer feedback

* update code

* add ci test for config and struct

* Add the logic of the inference engine

* update engine and test

* Recover cache_manager.py

* add logger

* fix conflict

* update code

* update code

* update model and tokenizer

* fix: add the logic for shardformer

* change kvcache_manager docstring

* add policy

* fix ci bug in test_kvcache_manager.py

* remove code related to tokenizer and move model_policy

* fix code style

* add ordered_set to requirements-infer.txt

* Delete extra empty lines

* add ordered_set to requirements-test.txt
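
The last two bullets add the third-party ordered_set package to the infer and test requirements. This diff does not show where the inference code consumes it, but the data structure itself is worth a note: it behaves like a set with constant-time membership tests while preserving insertion order, which matters when the sequences in a batch must be iterated deterministically. A minimal sketch of the package's API follows; the request-id values are illustrative, not from the commit:

from ordered_set import OrderedSet

pending = OrderedSet()
pending.add("request-1")
pending.add("request-2")
pending.add("request-1")            # duplicate insert is a no-op
assert list(pending) == ["request-1", "request-2"]

pending.discard("request-1")        # removal preserves the remaining order
assert list(pending) == ["request-2"]
assert "request-2" in pending       # O(1) membership test
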
Authored by yuehuayingxueluo on 2023-12-18 10:40:47 +08:00
Committed by FrankLeeeee
Parent 93aeacca34 · Commit 8daee26989
13 changed files with 555 additions and 172 deletions

tests/test_infer/test_config_and_struct.py · 70 changes · Normal file → Executable file

@@ -1,26 +1,45 @@
 import pytest

 import colossalai
 from colossalai.inference.config import InferenceConfig
-from colossalai.inference.struct import BatchInfo, RequsetStatus, Sequence
+from colossalai.inference.struct import BatchInfo, Sequence
+from colossalai.testing import spawn


-def test_config_and_inferenceData():
-    config = InferenceConfig("/llama")
-    assert config.max_batch_size
+def check_config_and_inference():
+    config = InferenceConfig()
+    assert config.max_batch_size == 8
     sequence = Sequence(
         request_id=1,
         prompt="abc",
-        token_id=[1, 2, 3],
+        input_token_id=[1, 2, 3],
         block_size=16,
         sample_params=None,
-        block_table_index=1,
+        block_table=None,
+        eos_token_id=2,
+        max_output_len=256,
     )

     sequence2 = Sequence(
         request_id=2,
         prompt="bcd",
-        token_id=[4, 5, 6],
+        input_token_id=[4, 5, 6],
         block_size=16,
         sample_params=None,
-        block_table_index=2,
+        block_table=None,
+        eos_token_id=2,
+        max_output_len=256,
     )

+    sequence3 = Sequence(
+        request_id=3,
+        prompt="efg",
+        input_token_id=[7, 8, 9],
+        block_size=16,
+        sample_params=None,
+        block_table=None,
+        eos_token_id=2,
+        max_output_len=256,
+    )
+
     assert sequence.get_sentence_len() == 3
@@ -29,15 +48,34 @@ def test_config_and_inferenceData():
     assert sequence.check_finish() == False

     batch = BatchInfo.init_batch([sequence])
-    assert batch.block_table[sequence.request_id] == sequence.block_table_index
-    sequence.status = RequsetStatus.COMPLETED
-    batch.fliter_batch()
-    assert batch.block_table == {}
-    batch.add_seqs([sequence2])
-    assert batch.block_table[sequence2.request_id] == sequence2.block_table_index
+    batch.add_seqs([sequence2, sequence3])
+    batch.add_seqs([sequence])
+
+    assert batch.is_empty() == False
+    assert batch.get_batch_size() == 3
+    batch.update_batch_tokens([1, 2, 3])
+    seq = batch.abort_seq(sequence)
+    seq2 = batch.fliter_batch()[0]
+    assert batch.get_batch_size() == 1
+    assert seq.get_output_len() == 1
+    assert seq.output_token_id == [1]
+    assert seq2.get_output_len() == 1
+    assert seq2.output_token_id == [2]
+
     batch.clear_batch()
     assert batch.block_table == {}
+    assert batch.is_empty() == True
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    check_config_and_inference()
+
+
+@pytest.mark.dist
+def test_config_and_inference():
+    spawn(run_dist, 1)


 if __name__ == "__main__":
-    test_config_and_inferenceData()
+    test_config_and_inference()
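
Taken together, the updated test traces the intended lifecycle of the new inference structs. The standalone sketch below replays that flow with comments. It uses only names that appear in the test above; the prompt and token values are illustrative, and the behavior noted for fliter_batch (it removes and returns finished sequences) is inferred from the test's assertions rather than from documented semantics:

import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.struct import BatchInfo, Sequence
from colossalai.testing import spawn


def run_lifecycle(rank, world_size, port):
    # The test launches a single-process distributed context before touching the structs.
    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")

    config = InferenceConfig()           # defaults include max_batch_size == 8
    assert config.max_batch_size == 8

    seq = Sequence(
        request_id=0,
        prompt="hello",
        input_token_id=[10, 11, 12],     # renamed from token_id in this commit
        block_size=16,
        sample_params=None,
        block_table=None,                # replaces the old block_table_index field
        eos_token_id=2,
        max_output_len=256,
    )

    batch = BatchInfo.init_batch([seq])  # build a batch from live sequences
    batch.update_batch_tokens([2])       # one generated token per sequence; 2 hits eos here
    finished = batch.fliter_batch()      # (sic) drains sequences that have finished
    assert finished[0].output_token_id == [2]

    batch.clear_batch()                  # reset all batch bookkeeping
    assert batch.block_table == {}
    assert batch.is_empty() == True


if __name__ == "__main__":
    spawn(run_lifecycle, 1)              # same harness as the test: one process

The spawn/launch pair mirrors the test exactly; the dist marker and single-process spawn let CI exercise the structs inside the same distributed context the engine will use.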