[Inference] Add async and sync API server using FastAPI (#5396)

* add api server

* fix

* add

* add completion service and fix bug

* add generation config

* revise shardformer

* fix bugs

* add docstrings and fix some bugs

* fix bugs and add choices for prompt template

This commit is contained in:
Jianghai
2024-03-01 14:47:36 +08:00
committed by CjhHa1
parent d482922035
commit 69cd7e069d
13 changed files with 789 additions and 25 deletions

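The commit adds FastAPI-based serving on top of the inference engine. For context only, here is a minimal sketch of what an async completion endpoint along these lines can look like; the route path, request schema, and the generate() hook are illustrative assumptions, not the code added in this commit.

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class CompletionRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 128

async def generate(prompt: str, max_new_tokens: int) -> str:
    # Placeholder engine hook so the sketch stays self-contained; a real server
    # would submit the prompt to the inference engine and await the generated text.
    return prompt

@app.post("/v1/completion")
async def completion(req: CompletionRequest) -> dict:
    # Async handler: FastAPI awaits the coroutine, keeping the event loop free
    # while the engine works; a plain sync (def) handler would run in a threadpool.
    text = await generate(req.prompt, req.max_new_tokens)
    return {"text": text}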

@@ -263,24 +263,27 @@ class RequestHandler:
             ), f"Sequence {req.request_id} exceeds input length limit"
         self.waiting_list[req.input_len * 3 // (self.inference_config.max_input_len + 1)].append(req)
 
-    def abort_sequence(self, request_id: str):
+    def abort_sequence(self, request_id: int):
         """
         Abort the request.
         """
-        seq, priority = self._find_sequence(request_id)
-        if seq.status == RequestStatus.WAITING:
-            seq.mark_aborted()
-            self.waiting_list[priority].remove(seq)
-        elif seq.status.is_running():
-            self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table)
-            self.running_list.remove(seq)
-        else:
-            try:
-                self.done_list.remove(seq)
-            except:
-                return
+        result = self._find_sequence(request_id)
+        if result is not None:
+            seq, priority = result
+            if seq.status == RequestStatus.WAITING:
+                seq.mark_aborted()
+                self.waiting_list[priority].remove(seq)
+            elif seq.status.is_running():
+                self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table)
+                self.running_list.remove(seq)
+            else:
+                try:
+                    self.done_list.remove(seq)
+                except:
+                    return
+        return
 
-    def _find_sequence(self, request_id: str) -> Sequence:
+    def _find_sequence(self, request_id: int) -> Sequence:
         """
         Find the request by request_id.
         """
@@ -324,6 +327,9 @@ class RequestHandler:
     def check_unfinished_seqs(self) -> bool:
         return self._has_waiting() or not self.running_list.is_empty()
 
+    def current_requests_in_batch(self) -> int:
+        return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size
+
     def search_tokens(self, generation_config: GenerationConfig, logits):
         """
         Sample tokens for finished requests.
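search_tokens receives the transformers GenerationConfig together with the model logits and selects next-token ids. The following is a simplified stand-in for that kind of step, not the ColossalAI implementation; the function name and the subset of GenerationConfig fields used are assumptions.

import torch
from transformers import GenerationConfig

def pick_next_tokens(generation_config: GenerationConfig, logits: torch.Tensor) -> torch.Tensor:
    # logits: [batch_size, vocab_size] at the current decoding position.
    if not generation_config.do_sample:
        # Greedy decoding: take the highest-probability token per sequence.
        return torch.argmax(logits, dim=-1)
    # Temperature-scaled sampling; top-k / top-p filtering would be applied
    # to the logits before the softmax in a fuller implementation.
    scaled = logits / max(float(generation_config.temperature), 1e-5)
    probs = torch.softmax(scaled, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)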