ColossalAI/colossalai/inference/struct.py
Yuanheng Zhao 55cc7f3df7
[Fix] Fix Inference Example, Tests, and Requirements (#5688)
* clean requirements

* modify example inference struct

* add test ci scripts

* mark test_infer as submodule

* rm deprecated cls & deps

* import of HAS_FLASH_ATTN

* prune inference tests to be run

* prune triton kernel tests

* increment pytest timeout mins

* revert import path in openmoe
2024-05-08 11:30:15 +08:00

172 lines
4.6 KiB
Python

import enum
from dataclasses import dataclass
from typing import Any, List
from colossalai.logging import get_dist_logger
# Module-level logger obtained from ColossalAI's distributed logging helper.
logger = get_dist_logger(__name__)
"""
The abstractions of request and sequence are defined here.
"""
class RequestStatus(enum.Enum):
    """
    Enumeration of the statuses a request (sequence) can hold.
    """

    # statuses used while a request is being handled
    WAITING = enum.auto()
    RUNNING = enum.auto()
    ABORTED = enum.auto()

    # statuses describing how a request completed
    OVERLENGTH = enum.auto()
    COMPLETED = enum.auto()
    LENGTH_CAPPED = enum.auto()

    # status of a request that has been recycled
    RECYCLED = enum.auto()

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        """Tell whether ``status`` is one of the completion statuses.

        Only OVERLENGTH, COMPLETED and LENGTH_CAPPED count as finished;
        ABORTED and RECYCLED do not.
        """
        completion_statuses = (
            RequestStatus.OVERLENGTH,
            RequestStatus.COMPLETED,
            RequestStatus.LENGTH_CAPPED,
        )
        return status in completion_statuses

    @staticmethod
    def is_running(status: "RequestStatus") -> bool:
        """Tell whether ``status`` equals RUNNING."""
        running = RequestStatus.RUNNING
        return status == running

    @staticmethod
    def is_waiting(status: "RequestStatus") -> bool:
        """Tell whether ``status`` equals WAITING."""
        waiting = RequestStatus.WAITING
        return status == waiting
@dataclass
class Sequence:
    """Store information of input sequence.

    Args:
        request_id (int): The ID of input sequence.
        prompt (str): The prompt of input sequence.
        input_token_id (List[int]): The tokens ID of input sequence.
        block_size (int): The block size of input sequence.
        sample_params (SampleParams): The sample_params of input sequence.
        eos_token_id (int): The eos token id for this inference process.
        pad_token_id (int): The pad token id for this inference process.
        max_output_len (int): Maximum output length. Defaults to 256.

    Attributes set in ``__post_init__`` (not dataclass fields):
        output_token_id (List[int]): Generated token IDs, initially empty.
        status (RequestStatus): Current status, initially ``WAITING``.
    """

    request_id: int
    prompt: str
    input_token_id: List[int]
    block_size: int
    sample_params: Any  # SampleParams needs to be imported later.
    eos_token_id: int
    pad_token_id: int
    max_output_len: int = 256

    def __post_init__(self):
        # Per-instance mutable state; assigned here so it is neither a
        # constructor argument nor part of the generated ``__eq__``.
        self.output_token_id: List[int] = []
        self.status = RequestStatus.WAITING

    @property
    def sentence_len(self) -> int:
        """
        Get length of current sentence (input tokens plus generated tokens).
        """
        return len(self.input_token_id) + len(self.output_token_id)

    @property
    def input_len(self) -> int:
        """
        Get length of input sentence.
        """
        return len(self.input_token_id)

    @property
    def output_len(self) -> int:
        """
        Get length of output sentence.
        """
        return len(self.output_token_id)

    def check_finish(self) -> bool:
        """
        Check whether the inference is finished.

        Besides reporting, this updates ``status`` to ``COMPLETED`` when the
        last generated token is ``eos_token_id`` or when ``output_len`` has
        reached ``max_output_len``.

        Returns:
            bool: Whether the inference is finished.
        """
        if RequestStatus.is_finished(self.status):
            return True

        if self.output_token_id:
            hit_eos = self.output_token_id[-1] == self.eos_token_id
            if hit_eos or self.output_len >= self.max_output_len:
                self.status = RequestStatus.COMPLETED
                return True

        return False

    def revoke_finished_status(self) -> None:
        """
        Revoke the finished status of the sequence.
        This is only used by speculative decoding for now.
        """
        if RequestStatus.is_finished(self.status):
            self.status = RequestStatus.RUNNING

    def __hash__(self):
        # Defined explicitly so instances stay hashable: @dataclass keeps a
        # user-provided __hash__ instead of setting it to None.
        return hash(self.request_id)

    def mark_running(self) -> None:
        """
        Set status for prefill reqs.

        Raises:
            AssertionError: If the sequence is neither WAITING nor RECYCLED.
        """
        # A membership test is required here: ``status == WAITING or RECYCLED``
        # would always be truthy because an enum member is truthy on its own.
        assert self.status in (
            RequestStatus.WAITING,
            RequestStatus.RECYCLED,
        ), "Sequence is not in WAITTING/RECYCLED STATUS"
        self.status = RequestStatus.RUNNING

    def mark_finished(self) -> None:
        """
        Set status for finished reqs.
        """
        self.status = RequestStatus.COMPLETED

    def mark_aborted(self) -> None:
        """
        Set status for aborted reqs.
        """
        self.status = RequestStatus.ABORTED

    def recycle(self) -> None:
        """
        Recycle a running sequence to the waiting list.

        Raises:
            AssertionError: If the sequence is already finished or aborted.
        """
        # NOTE: check_finish() may itself flip the status to COMPLETED before
        # the assertion fails.
        assert (
            not self.check_finish() and not self.status == RequestStatus.ABORTED
        ), "The running sequence is already done but it still in running list"
        self.status = RequestStatus.RECYCLED

    def __repr__(self) -> str:
        return (
            f"(request_id={self.request_id}, "
            f"prompt={self.prompt}, "
            f"status={self.status.name}, "
            f"sample_params={self.sample_params}, "
            f"input_len={self.input_len},"
            f"output_len={self.output_len})"
        )
def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
assert len(x) <= max_len
return [pad] * (max_len - len(x)) + x