mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
[Inference] Add the logic of the inference engine (#5173)
* add infer_struct and infer_config * update codes * change InferConfig * Add hf_model_config to the engine * rm _get_hf_model_config * update codes * made adjustments according to the feedback from the reviewer. * update codes * add ci test for config and struct * Add the logic of the inference engine * update engine and test * Recover cache_manager.py * add logger * fix conflict * update codes * update codes * update model and tokenizer * fix add the logic about shardformer * change kvcache_manager docstring * add policy * fix ci bug in test_kvcache_manager.py * remove codes related o tokenizer and move model_policy * fix code style * add ordered_set to requirements-infer.txt * Delete extra empty lines * add ordered_set to requirements-test.txt
This commit is contained in:
committed by
FrankLeeeee
parent
93aeacca34
commit
8daee26989
44
tests/test_infer/test_inference_engine.py
Executable file
44
tests/test_infer/test_inference_engine.py
Executable file
@@ -0,0 +1,44 @@
|
||||
import pytest
|
||||
import transformers
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.core.engine import InferenceEngine
|
||||
from colossalai.testing import spawn
|
||||
|
||||
|
||||
def check_inference_engine():
|
||||
model = transformers.LlamaForCausalLM(
|
||||
transformers.LlamaConfig(
|
||||
vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
|
||||
)
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
inference_config = InferenceConfig()
|
||||
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
|
||||
|
||||
inputs = [
|
||||
"介绍一下北京",
|
||||
"介绍一下武汉",
|
||||
]
|
||||
|
||||
inference_engine.add_request(prompts=inputs)
|
||||
outputs = inference_engine.generate(None)
|
||||
|
||||
for s1, s2 in zip(inputs, outputs):
|
||||
assert s1 == s2
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
check_inference_engine()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_inference_engine():
|
||||
spawn(run_dist, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_inference_engine()
|
Reference in New Issue
Block a user