Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 12:30:42 +00:00
[Inference] add logit processor and request handler (#5166)
* add logit processor and request handler
* add
* add
* add
* fix
* add search tokens and update func
* finish request handler
* add running list test
* fix test
* fix some bug
* add
* add
* fix bugs
* fix some bugs
* fix bug
* fix
* fix
* add copy fun
* del useless attn
* fix request status

---------

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
@@ -1,6 +1,6 @@
import enum
from dataclasses import dataclass
-from typing import List, Union
+from typing import Any, List, Union

import torch
from ordered_set import OrderedSet
@@ -21,8 +21,7 @@ class RequestStatus(enum.Enum):

    # running status
    WAITING = enum.auto()
-    PREFILL = enum.auto()
-    TOKEN = enum.auto()
+    RUNNING = enum.auto()
    ABORTED = enum.auto()

    # completion status
@@ -40,10 +39,7 @@ class RequestStatus(enum.Enum):

    @staticmethod
    def is_running(status: "RequestStatus") -> bool:
-        return status in [
-            RequestStatus.PREFILL,
-            RequestStatus.TOKEN,
-        ]
+        return status == RequestStatus.RUNNING

    @staticmethod
    def is_waiting(status: "RequestStatus") -> bool:
@@ -69,7 +65,7 @@ class Sequence:
    prompt: str
    input_token_id: List[int]
    block_size: int
-    sample_params: any  # SampleParams needs to be imported later.
+    sample_params: Any  # SampleParams needs to be imported later.
    block_table: torch.Tensor
    eos_token_id: int
    max_output_len: int = 256
@@ -78,21 +74,31 @@ class Sequence:
        self.output_token_id = []
        self.status = RequestStatus.WAITING

-    def get_sentence_len(self) -> None:
+    @property
+    def prompt_len(self) -> int:
+        """
+        Get length of prompts
+        """
+        return len(self.input_token_id)
+
+    @property
+    def sentence_len(self) -> int:
        """
        Get length of current sentence.
        """
        return len(self.input_token_id) + len(self.output_token_id)

-    def get_input_len(self) -> None:
+    @property
+    def input_len(self) -> int:
        """
        Get length of input sentence.
        """
        return len(self.input_token_id)

-    def get_output_len(self) -> None:
+    @property
+    def output_len(self) -> int:
        """
-        Get output length of current sentence.
+        Get length of output sentence.
        """
        return len(self.output_token_id)
@@ -116,12 +122,32 @@ class Sequence:
    def __hash__(self):
        return hash(self.request_id)

+    def mark_running(self) -> None:
+        """
+        Set status for prefill reqs.
+        """
+        assert self.status == RequestStatus.WAITING, "Sequence is not in WAITTING STATUS"
+        self.status = RequestStatus.RUNNING
+
+    def mark_finished(self) -> None:
+        """
+        Set status for finished reqs.
+        """
+        self.status = RequestStatus.COMPLETED
+
+    def mark_aborted(self) -> None:
+        """
+        Set status for aborted reqs.
+        """
+        self.status = RequestStatus.ABORTED
+
    def __repr__(self) -> str:
        return (
            f"Request ID(request_id={self.request_id}, "
            f"prompt={self.prompt}, "
            f"status={self.status.name}, "
-            f"sample_params={self.sample_params}"
+            f"sample_params={self.sample_params}, "
+            f"logical block number={len(self.block_table_index)}"
        )
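The Sequence hunks above replace ad-hoc length getters with properties and add explicit WAITING -> RUNNING -> COMPLETED/ABORTED transitions. The minimal stand-in below is only a sketch of how those pieces behave together; MiniSequence and its trimmed field list are illustrative assumptions, not the upstream class, which also carries block_size, block_table, sample_params, eos_token_id, and max_output_len.

import enum
from dataclasses import dataclass, field
from typing import List


class RequestStatus(enum.Enum):
    WAITING = enum.auto()
    RUNNING = enum.auto()
    COMPLETED = enum.auto()
    ABORTED = enum.auto()


@dataclass
class MiniSequence:
    # Trimmed-down stand-in for the Sequence dataclass in this diff.
    request_id: int
    prompt: str
    input_token_id: List[int]
    output_token_id: List[int] = field(default_factory=list)
    status: RequestStatus = RequestStatus.WAITING

    @property
    def input_len(self) -> int:
        return len(self.input_token_id)

    @property
    def output_len(self) -> int:
        return len(self.output_token_id)

    @property
    def sentence_len(self) -> int:
        # prompt tokens plus tokens generated so far
        return self.input_len + self.output_len

    def mark_running(self) -> None:
        # WAITING -> RUNNING, mirroring the transition added above
        assert self.status == RequestStatus.WAITING
        self.status = RequestStatus.RUNNING

    def mark_finished(self) -> None:
        self.status = RequestStatus.COMPLETED


seq = MiniSequence(request_id=0, prompt="hello", input_token_id=[1, 2, 3])
seq.mark_running()
seq.output_token_id.append(42)
assert (seq.input_len, seq.output_len, seq.sentence_len) == (3, 1, 4)
seq.mark_finished()

The remaining hunks touch the BatchInfo container: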
@@ -131,7 +157,8 @@ class BatchInfo:
    Information to be passed and used for a batch of sequences.
    """

-    sequences_set: OrderedSet["Sequence"]
+    sequences_set: OrderedSet["Sequence"] = None
+    is_prompts: bool = True

    @classmethod
    def init_batch(cls, seqs: List["Sequence"] = None) -> "BatchInfo":
@@ -214,6 +241,7 @@ class BatchInfo:
                continue
            self.sequences_set.add(seq)

+    @property
    def is_empty(self) -> None:
        """
        Check whether sequences_set is empty.
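The BatchInfo hunks give sequences_set a default, add an is_prompts flag, skip duplicate sequences when adding, and expose is_empty as a property. Here is a self-contained sketch of that pattern using the same ordered_set dependency the file imports; MiniBatchInfo and the simplified init_batch/add_seqs bodies are assumptions for illustration, not the upstream implementation.

from dataclasses import dataclass
from typing import List

from ordered_set import OrderedSet  # same ordered_set package imported at the top of the diffed file


@dataclass
class MiniBatchInfo:
    # Trimmed-down stand-in for BatchInfo in this diff.
    sequences_set: OrderedSet = None
    is_prompts: bool = True

    @classmethod
    def init_batch(cls, seqs: List = None) -> "MiniBatchInfo":
        batch = cls(sequences_set=OrderedSet())
        if seqs:
            batch.add_seqs(seqs)
        return batch

    def add_seqs(self, seqs: List) -> None:
        for seq in seqs:
            if seq in self.sequences_set:
                # skip sequences already scheduled in this batch
                continue
            self.sequences_set.add(seq)

    @property
    def is_empty(self) -> bool:
        return not self.sequences_set


batch = MiniBatchInfo.init_batch()
assert batch.is_empty and batch.is_prompts
batch.add_seqs(["req-0", "req-1", "req-0"])  # duplicate entry is ignored
assert len(batch.sequences_set) == 2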