[Inference] Add the logic of the inference engine (#5173)

* add infer_struct and infer_config (a hedged config sketch follows the diff at the end of this page)

* update code

* change InferConfig

* Add hf_model_config to the engine

* rm _get_hf_model_config

* update code

* make adjustments according to reviewer feedback

* update code

* add ci test for config and struct

* Add the logic of the inference engine

* update engine and test

* Recover cache_manager.py

* add logger

* fix conflict

* update code

* update code

* update model and tokenizer

* fix: add the logic for shardformer

* change kvcache_manager docstring (a hedged allocator sketch follows this list)

* add policy

* fix ci bug in test_kvcache_manager.py

* remove code related to the tokenizer and move model_policy

* fix code style

* add ordered_set to requirements-infer.txt (see the ordered-set sketch after this list)

* Delete extra empty lines

* add ordered_set to requirements-test.txt
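
The kvcache_manager itself is not part of the diff shown below, so the following is a point of reference only: a minimal sketch of the block-style allocation such a manager typically performs. Every name in it (CacheBlock, KVCacheManager, num_blocks, block_size) is hypothetical and not the real colossalai API.

    from dataclasses import dataclass
    from typing import List


    @dataclass
    class CacheBlock:
        """A fixed-size slab of KV cache slots assigned to one sequence segment."""

        block_id: int
        block_size: int
        used_slots: int = 0

        @property
        def is_full(self) -> bool:
            return self.used_slots >= self.block_size


    class KVCacheManager:
        """Hands out cache blocks to running sequences and reclaims them on completion."""

        def __init__(self, num_blocks: int, block_size: int) -> None:
            self._free: List[CacheBlock] = [
                CacheBlock(block_id=i, block_size=block_size) for i in range(num_blocks)
            ]

        def allocate(self) -> CacheBlock:
            # Raising here lets a scheduler hold new requests until blocks free up.
            if not self._free:
                raise RuntimeError("KV cache exhausted; wait for a running sequence to finish")
            return self._free.pop()

        def free(self, block: CacheBlock) -> None:
            block.used_slots = 0
            self._free.append(block)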
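
The two requirements bullets add the ordered_set package (PyPI name ordered-set). A plausible reason, stated as an assumption since the scheduler code is not in this diff, is that request tracking wants set semantics with deterministic insertion-order iteration. A minimal illustration:

    from ordered_set import OrderedSet  # package added to requirements in this commit

    waiting = OrderedSet()
    waiting.add("request-1")
    waiting.add("request-2")
    waiting.add("request-1")  # duplicate adds are ignored, as with a plain set

    assert list(waiting) == ["request-1", "request-2"]  # insertion order is preserved
    waiting.remove("request-1")  # removal keeps the remaining order intact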
Author: yuehuayingxueluo
Date: 2023-12-18 10:40:47 +08:00
Committed by: FrankLeeeee
Parent: 93aeacca34
Commit: 8daee26989
13 changed files with 555 additions and 172 deletions


@@ -0,0 +1,44 @@
import pytest
import transformers
from transformers import AutoTokenizer

import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import spawn


def check_inference_engine():
    # A tiny randomly initialized Llama keeps the test light.
    model = transformers.LlamaForCausalLM(
        transformers.LlamaConfig(
            vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
        )
    )
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    inference_config = InferenceConfig()
    inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

    inputs = [
        "介绍一下北京",  # "Introduce Beijing"
        "介绍一下武汉",  # "Introduce Wuhan"
    ]

    inference_engine.add_request(prompts=inputs)
    # generate(None) consumes the requests queued above instead of taking new prompts.
    outputs = inference_engine.generate(None)

    # At this stage the engine is expected to hand the prompts back unchanged.
    for s1, s2 in zip(inputs, outputs):
        assert s1 == s2


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
    check_inference_engine()


@pytest.mark.dist
def test_inference_engine():
    spawn(run_dist, 1)


if __name__ == "__main__":
    test_inference_engine()
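
Because of the __main__ guard, the file can be run directly with Python as well as through pytest. One last point of reference: the test constructs InferenceConfig() with defaults, but the config's fields are not shown in this diff. The dataclass below is a purely hypothetical sketch of what such a config might carry; every field name and default is an assumption, not the real colossalai.inference.config.InferenceConfig.

    from dataclasses import dataclass

    import torch


    @dataclass
    class ToyInferenceConfig:
        """Hypothetical stand-in for an inference engine config (all fields assumed)."""

        max_batch_size: int = 8             # sequences scheduled per step (assumed)
        max_input_len: int = 256            # cap on prompt length (assumed)
        max_output_len: int = 256           # cap on generated length (assumed)
        dtype: torch.dtype = torch.float16  # dtype for weights and KV cache (assumed)


    # Defaults make it usable with no arguments, matching InferenceConfig() in the test above.
    config = ToyInferenceConfig()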