[Inference] Add the logic of the inference engine (#5173)

* add infer_struct and infer_config

* update code

* change InferConfig

* Add hf_model_config to the engine

* rm _get_hf_model_config

* update code

* make adjustments according to reviewer feedback

* update code

* add ci test for config and struct

* Add the logic of the inference engine

* update engine and test

* Recover cache_manager.py

* add logger

* fix conflict

* update code

* update code

* update model and tokenizer

* fix: add the logic related to shardformer

* change kvcache_manager docstring

* add policy

* fix ci bug in test_kvcache_manager.py

* remove code related to tokenizer and move model_policy

* fix code style

* add ordered_set to requirements-infer.txt

* Delete extra empty lines

* add ordered_set to requirements-test.txt
Author: yuehuayingxueluo
Date: 2023-12-18 10:40:47 +08:00
Committed by: FrankLeeeee
Parent: 93aeacca34
Commit: 8daee26989
13 changed files with 555 additions and 172 deletions
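The diff below reworks RequestHandler to take the new inference config plus the Hugging Face model config instead of a bare cache config. As a rough illustration of that wiring (a sketch only: the import paths, the model id, and the InferenceConfig fields are assumptions, not taken from this commit), the engine side might look like:

# Hedged sketch: constructing the reworked RequestHandler.
# Only RequestHandler(inference_config, model_config) comes from the diff below;
# the import paths, the model id, and the InferenceConfig fields are assumed.
from transformers import AutoConfig

from colossalai.inference.config import InferenceConfig  # assumed module path
from colossalai.inference.core.request_handler import RequestHandler  # assumed module path

# Hugging Face model config, as referenced by the "Add hf_model_config to the engine" commit.
model_config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")

# Illustrative fields only; the real InferenceConfig may differ.
inference_config = InferenceConfig(max_batch_size=8, max_input_len=256, max_output_len=128)

handler = RequestHandler(inference_config, model_config)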


@@ -1,5 +1,7 @@
 from typing import List
 
+from colossalai.inference.struct import BatchInfo, Sequence
+
 
 class RequestHandler:
     """
@@ -7,14 +9,17 @@ class RequestHandler:
     During generation process, we call schedule function each iteration to update current batch.
 
     Args:
-        cache_config: Configuration for initialize and manage kv cache.
+        inference_config: Store the configuration information related to inference.
+        model_config: The huggingface model config.
     """
 
-    def __init__(self, cache_config) -> None:
-        self.cache_config = cache_config
+    def __init__(self, inference_config, model_config) -> None:
+        self.inference_config = inference_config
+        self.model_config = model_config
         self._init_cache()
 
-        self.waiting_list: List["Reqseq"] = []
-        self.running_list: List["Reqseq"] = []
+        self.waiting_list: List["Sequence"] = []
+        self.running_list: List["Sequence"] = []
+        self.batch = BatchInfo.init_batch()
 
     def _init_cache(self):
         """
@@ -25,12 +30,17 @@ class RequestHandler:
         """
         The main logic of request handler.
         """
+        # The code below is only used for testing engine and will be modified.
+        if self.waiting_list:
+            self.running_list = self.waiting_list
+            self.batch.add_seqs(self.running_list)
+        return self.batch
 
-    def add_sequence(self, reqseq: "Reqseq"):
+    def add_sequence(self, req_seq: "Sequence"):
         """
         Add the request to waiting list.
         """
-        self.waiting_list.append(reqseq)
+        self.waiting_list.append(req_seq)
 
     def abort_sequence(self, seq_id: str):
         """
@@ -39,10 +49,23 @@ class RequestHandler:
         self._find_sequence(seq_id)
         return
 
-    def _find_sequence(self, seq_id: str) -> "Reqseq":
+    def _find_sequence(self, seq_id: str) -> "Sequence":
         """
         Find the request by seq_id.
         """
 
     def check_unfinished_seqs(self) -> bool:
-        return self.waiting_list or self.running_list
+        return len(self.waiting_list) != 0 or len(self.running_list) != 0
+
+    def update(self):
+        """
+        Update the waiting list and running list.
+        """
+        # The code below is only used for testing engine and will be modified.
+        self.waiting_list = []
+        self.running_list = []
+        finished_sequences = list(self.batch.sequences_set)
+        self.batch.clear_batch()
+        return finished_sequences
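Taken together, the new schedule/update logic gives the handler a simple lifecycle: requests are queued with add_sequence, promoted into the running batch by schedule (called once per iteration, per the class docstring), and drained back out by update once the batch is cleared. A hedged usage sketch of that loop, using only the methods shown above (how a Sequence is built and how the model consumes a BatchInfo are placeholders, not part of this diff):

# Hedged sketch of the request lifecycle implied by the diff above.
def serve_requests(handler, sequences, run_model_step):
    """Drive the RequestHandler until every queued sequence has been returned."""
    # Queue every incoming request on the waiting list.
    for seq in sequences:
        handler.add_sequence(seq)

    finished = []
    while handler.check_unfinished_seqs():
        # schedule() moves waiting sequences into the current BatchInfo
        # (in this testing version it simply promotes the whole waiting list).
        batch = handler.schedule()

        # Placeholder for the engine's forward/decoding step on the batch.
        run_model_step(batch)

        # update() clears the batch and returns the sequences it held.
        finished.extend(handler.update())

    return finished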