Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 11:32:10 +00:00
[Inference] ADD async and sync Api server using FastAPI (#5396)
* add api server
* fix
* add
* add completion service and fix bug
* add generation config
* revise shardformer
* fix bugs
* add docstrings and fix some bugs
* fix bugs and add choices for prompt template
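For orientation, below is a minimal sketch of the kind of completion endpoint this commit adds, with both a synchronous JSON response and an async streaming path. It is an assumption-laden illustration, not the PR's code: StubEngine, the /completion route, and the "stream" flag are hypothetical stand-ins for ColossalAI's actual engine wiring.

    import asyncio

    from fastapi import FastAPI, Request
    from fastapi.responses import JSONResponse, StreamingResponse

    app = FastAPI()


    class StubEngine:
        """Hypothetical stand-in for the inference engine behind the server."""

        async def generate(self, prompt: str) -> str:
            await asyncio.sleep(0)  # placeholder for a real forward pass
            return prompt + " [generated]"

        async def generate_stream(self, prompt: str):
            # Yield whitespace-separated "tokens" one at a time.
            for token in (prompt + " [generated]").split():
                await asyncio.sleep(0)
                yield token + " "


    engine = StubEngine()


    @app.post("/completion")
    async def create_completion(request: Request):
        payload = await request.json()
        prompt = payload.get("prompt", "")
        if payload.get("stream", False):
            # Async path: stream tokens back as they are produced.
            return StreamingResponse(engine.generate_stream(prompt), media_type="text/event-stream")
        # Sync path: wait for the full completion before responding.
        return JSONResponse({"text": await engine.generate(prompt)})

Saved as server.py, this runs under `uvicorn server:app`; POST {"prompt": "...", "stream": true} to exercise the streaming path.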
@@ -263,24 +263,27 @@ class RequestHandler:
             ), f"Sequence {req.request_id} exceeds input length limit"
         self.waiting_list[req.input_len * 3 // (self.inference_config.max_input_len + 1)].append(req)
 
-    def abort_sequence(self, request_id: str):
+    def abort_sequence(self, request_id: int):
         """
         Abort the request.
         """
-        seq, priority = self._find_sequence(request_id)
-        if seq.status == RequestStatus.WAITING:
-            seq.mark_aborted()
-            self.waiting_list[priority].remove(seq)
-        elif seq.status.is_running():
-            self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table)
-            self.running_list.remove(seq)
-        else:
-            try:
-                self.done_list.remove(seq)
-            except:
-                return
+        result = self._find_sequence(request_id)
+        if result is not None:
+            seq, priority = result
+            if seq.status == RequestStatus.WAITING:
+                seq.mark_aborted()
+                self.waiting_list[priority].remove(seq)
+            elif seq.status.is_running():
+                self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table)
+                self.running_list.remove(seq)
+            else:
+                try:
+                    self.done_list.remove(seq)
+                except:
+                    return
+        return
 
-    def _find_sequence(self, request_id: str) -> Sequence:
+    def _find_sequence(self, request_id: int) -> Sequence:
         """
         Find the request by request_id.
         """
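Why the rewrite in this hunk matters: with an HTTP front end, an abort can arrive after the sequence has already finished and been evicted, in which case _find_sequence returns None and the old tuple unpack raised TypeError (note that the annotated return type, Sequence, never reflected the None case). A self-contained repro of the difference; find_sequence below is a hypothetical stand-in, not the real method:

    # Minimal illustration of the crash this hunk fixes; find_sequence is a
    # stand-in for RequestHandler._find_sequence.
    from typing import Optional, Tuple

    def find_sequence(request_id: int) -> Optional[Tuple[str, int]]:
        return None  # the sequence already finished and was evicted

    # Old pattern: tuple-unpacking None raises TypeError.
    try:
        seq, priority = find_sequence(42)
    except TypeError:
        print("aborting an already-finished request used to crash here")

    # New pattern: treat a missing sequence as a no-op.
    result = find_sequence(42)
    if result is not None:
        seq, priority = result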
@@ -324,6 +327,9 @@ class RequestHandler:
     def check_unfinished_seqs(self) -> bool:
         return self._has_waiting() or not self.running_list.is_empty()
 
+    def current_requests_in_batch(self) -> int:
+        return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size
+
     def search_tokens(self, generation_config: GenerationConfig, logits):
         """
         Sample tokens for finished requests.
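The new current_requests_in_batch() gives the serving layer a cheap occupancy signal (prefill plus decode batch sizes). A hedged sketch of how an API server might surface it; the /health route and StubHandler are illustrative assumptions, not part of this commit:

    from fastapi import FastAPI


    class StubHandler:
        """Stand-in exposing the same counter as RequestHandler."""

        def current_requests_in_batch(self) -> int:
            return 0  # a real handler sums prefill and decode batch sizes


    app = FastAPI()
    request_handler = StubHandler()


    @app.get("/health")
    def health() -> dict:
        return {"requests_in_batch": request_handler.current_requests_in_batch()}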