mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
[Inference] Add the logic of the inference engine (#5173)
* add infer_struct and infer_config * update codes * change InferConfig * Add hf_model_config to the engine * rm _get_hf_model_config * update codes * made adjustments according to the feedback from the reviewer. * update codes * add ci test for config and struct * Add the logic of the inference engine * update engine and test * Recover cache_manager.py * add logger * fix conflict * update codes * update codes * update model and tokenizer * fix add the logic about shardformer * change kvcache_manager docstring * add policy * fix ci bug in test_kvcache_manager.py * remove codes related o tokenizer and move model_policy * fix code style * add ordered_set to requirements-infer.txt * Delete extra empty lines * add ordered_set to requirements-test.txt
This commit is contained in:
committed by
FrankLeeeee
parent
93aeacca34
commit
8daee26989
18
tests/test_infer/test_kvcache_manager.py
Normal file → Executable file
18
tests/test_infer/test_kvcache_manager.py
Normal file → Executable file
@@ -1,12 +1,14 @@
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers.models.llama import LlamaConfig
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.kv_cache import CacheBlock, KVCacheManager
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.testing import parameterize, spawn
|
||||
|
||||
|
||||
@parameterize(
|
||||
@@ -64,7 +66,7 @@ def test_logical_blocks(test_config):
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_cache_manager(test_config):
|
||||
def check_cache_manager(test_config):
|
||||
disable_existing_loggers()
|
||||
|
||||
assert test_config["max_batch_size"] > 1
|
||||
@@ -78,7 +80,7 @@ def test_cache_manager(test_config):
|
||||
max_input_length = test_config["max_input_len"]
|
||||
max_output_length = test_config["max_output_len"]
|
||||
|
||||
inference_config = InferenceConfig(model="", **test_config)
|
||||
inference_config = InferenceConfig(**test_config)
|
||||
model_config = LlamaConfig(
|
||||
hidden_size=hidden_size,
|
||||
num_hidden_layers=num_layers,
|
||||
@@ -147,6 +149,16 @@ def test_cache_manager(test_config):
|
||||
assert cache_manager.get_num_available_blocks() == num_blocks
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
check_cache_manager()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_cache_manager():
|
||||
spawn(run_dist, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_logical_blocks()
|
||||
test_cache_manager()
|
||||
|
Reference in New Issue
Block a user