mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[Inference] Add the logic of the inference engine (#5173)
* add infer_struct and infer_config
* update codes
* change InferConfig
* Add hf_model_config to the engine
* rm _get_hf_model_config
* update codes
* made adjustments according to the feedback from the reviewer
* update codes
* add ci test for config and struct
* Add the logic of the inference engine
* update engine and test
* Recover cache_manager.py
* add logger
* fix conflict
* update codes
* update codes
* update model and tokenizer
* fix add the logic about shardformer
* change kvcache_manager docstring
* add policy
* fix ci bug in test_kvcache_manager.py
* remove codes related to tokenizer and move model_policy
* fix code style
* add ordered_set to requirements-infer.txt
* Delete extra empty lines
* add ordered_set to requirements-test.txt
This commit is contained in:
committed by
FrankLeeeee
parent
93aeacca34
commit
8daee26989
@@ -1,5 +1,7 @@
 from typing import List
+
+from colossalai.inference.struct import BatchInfo, Sequence


 class RequestHandler:
     """
@@ -7,14 +9,17 @@ class RequestHandler:
     During generation process, we call schedule function each iteration to update current batch.

     Args:
-        cache_config: Configuration for initialize and manage kv cache.
+        inference_config: Store the configuration information related to inference.
+        model_config: The huggingface model config.
     """

-    def __init__(self, cache_config) -> None:
-        self.cache_config = cache_config
+    def __init__(self, inference_config, model_config) -> None:
+        self.inference_config = inference_config
+        self.model_config = model_config
         self._init_cache()
-        self.waiting_list: List["Reqseq"] = []
-        self.running_list: List["Reqseq"] = []
+        self.waiting_list: List["Sequence"] = []
+        self.running_list: List["Sequence"] = []
+        self.batch = BatchInfo.init_batch()

     def _init_cache(self):
         """
@@ -25,12 +30,17 @@ class RequestHandler:
         """
         The main logic of request handler.
         """
+        # The code below is only used for testing engine and will be modified.
+        if self.waiting_list:
+            self.running_list = self.waiting_list
+            self.batch.add_seqs(self.running_list)
+        return self.batch

-    def add_sequence(self, reqseq: "Reqseq"):
+    def add_sequence(self, req_seq: "Sequence"):
         """
         Add the request to waiting list.
         """
-        self.waiting_list.append(reqseq)
+        self.waiting_list.append(req_seq)

     def abort_sequence(self, seq_id: str):
         """
@@ -39,10 +49,23 @@ class RequestHandler:
         self._find_sequence(seq_id)
         return

-    def _find_sequence(self, seq_id: str) -> "Reqseq":
+    def _find_sequence(self, seq_id: str) -> "Sequence":
         """
         Find the request by seq_id.
         """

     def check_unfinished_seqs(self) -> bool:
-        return self.waiting_list or self.running_list
+        return len(self.waiting_list) != 0 or len(self.running_list) != 0
+
+    def update(self):
+        """
+        Update the waiting list and running list.
+        """
+
+        # The code below is only used for testing engine and will be modified.
+        self.waiting_list = []
+        self.running_list = []
+        finished_sequences = list(self.batch.sequences_set)
+
+        self.batch.clear_batch()
+        return finished_sequences
Reference in New Issue
Block a user