[Inference] add logit processor and request handler (#5166)

* add logit processor and request handler

* add

* add

* add

* fix

* add search tokens and update func

* finish request handler

* add running list test

* fix test

* fix some bug

* add

* add

* fix bugs

* fix some bugs

* fix bug

* fix

* fix

* add copy fun

* del useless attn

* fix request status

---------

Co-authored-by: CjhHa1 <cjh18671720497outlook.com>
This commit is contained in:
Jianghai
2023-12-25 12:15:15 +08:00
committed by FrankLeeeee
parent 8daee26989
commit 0e616462a7
10 changed files with 463 additions and 66 deletions

View File

@@ -1,6 +1,6 @@
import enum
from dataclasses import dataclass
from typing import List, Union
from typing import Any, List, Union
import torch
from ordered_set import OrderedSet
@@ -21,8 +21,7 @@ class RequestStatus(enum.Enum):
# running status
WAITING = enum.auto()
PREFILL = enum.auto()
TOKEN = enum.auto()
RUNNING = enum.auto()
ABORTED = enum.auto()
# completion status
@@ -40,10 +39,7 @@ class RequestStatus(enum.Enum):
@staticmethod
def is_running(status: "RequestStatus") -> bool:
return status in [
RequestStatus.PREFILL,
RequestStatus.TOKEN,
]
return status == RequestStatus.RUNNING
@staticmethod
def is_waiting(status: "RequestStatus") -> bool:
@@ -69,7 +65,7 @@ class Sequence:
prompt: str
input_token_id: List[int]
block_size: int
sample_params: any # SampleParams needs to be imported later.
sample_params: Any # SampleParams needs to be imported later.
block_table: torch.Tensor
eos_token_id: int
max_output_len: int = 256
@@ -78,21 +74,31 @@ class Sequence:
self.output_token_id = []
self.status = RequestStatus.WAITING
def get_sentence_len(self) -> None:
@property
def prompt_len(self) -> int:
"""
Get length of prompts
"""
return len(self.input_token_id)
@property
def sentence_len(self) -> int:
"""
Get length of current sentence.
"""
return len(self.input_token_id) + len(self.output_token_id)
def get_input_len(self) -> None:
@property
def input_len(self) -> int:
"""
Get length of input sentence.
"""
return len(self.input_token_id)
def get_output_len(self) -> None:
@property
def output_len(self) -> int:
"""
Get output length of current sentence.
Get length of output sentence.
"""
return len(self.output_token_id)
@@ -116,12 +122,32 @@ class Sequence:
def __hash__(self):
return hash(self.request_id)
def mark_running(self) -> None:
"""
Set status for prefill reqs.
"""
assert self.status == RequestStatus.WAITING, "Sequence is not in WAITTING STATUS"
self.status = RequestStatus.RUNNING
def mark_finished(self) -> None:
"""
Set status for finished reqs.
"""
self.status = RequestStatus.COMPLETED
def mark_aborted(self) -> None:
"""
Set status for aborted reqs.
"""
self.status = RequestStatus.ABORTED
def __repr__(self) -> str:
return (
f"Request ID(request_id={self.request_id}, "
f"prompt={self.prompt}, "
f"status={self.status.name}, "
f"sample_params={self.sample_params}"
f"sample_params={self.sample_params}, "
f"logical block number={len(self.block_table_index)}"
)
@@ -131,7 +157,8 @@ class BatchInfo:
Information to be passed and used for a batch of sequences.
"""
sequences_set: OrderedSet["Sequence"]
sequences_set: OrderedSet["Sequence"] = None
is_prompts: bool = True
@classmethod
def init_batch(cls, seqs: List["Sequence"] = None) -> "BatchInfo":
@@ -214,6 +241,7 @@ class BatchInfo:
continue
self.sequences_set.add(seq)
@property
def is_empty(self) -> None:
"""
Check whether sequences_set is empty.