[Inference] Add the logic of the inference engine (#5173)

* add infer_struct and infer_config

* update codes

* change InferConfig

* Add hf_model_config to the engine

* rm _get_hf_model_config

* update codes

* adjust according to reviewer feedback

* update codes

* add ci test for config and struct

* Add the logic of the inference engine

* update engine and test

* Recover cache_manager.py

* add logger

* fix conflict

* update codes

* update codes

* update model and tokenizer

* fix: add the logic about shardformer

* change kvcache_manager docstring

* add policy

* fix ci bug in test_kvcache_manager.py

* remove code related to tokenizer and move model_policy

* fix code style

* add ordered_set to requirements-infer.txt

* Delete extra empty lines

* add ordered_set to requirements-test.txt
Author: yuehuayingxueluo
Date: 2023-12-18 10:40:47 +08:00
Committed by: FrankLeeeee
Parent: 93aeacca34
Commit: 8daee26989
13 changed files with 555 additions and 172 deletions


@@ -1,68 +1,82 @@
 import enum
 from dataclasses import dataclass
-from typing import Dict, List, Set
+from typing import List, Union
+
+import torch
+from ordered_set import OrderedSet
+
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger(__name__)
 
 """
 The abstraction of request and sequence are defined here.
 """
 
-class RequsetStatus(enum.Enum):
-    """The status of Sentences"""
+class RequestStatus(enum.Enum):
+    """
+    The status of Sentences
+    """
+
+    # running status
     WAITING = enum.auto()
-    RUNNING = enum.auto()
+    PREFILL = enum.auto()
+    TOKEN = enum.auto()
     ABORTED = enum.auto()
+
+    # completion status
     OVERLENGTH = enum.auto()
     COMPLETED = enum.auto()
     LENGTH_CAPPED = enum.auto()
 
     @staticmethod
-    def is_finished(status: "RequsetStatus") -> bool:
+    def is_finished(status: "RequestStatus") -> bool:
         return status in [
-            RequsetStatus.OVERLENGTH,
-            RequsetStatus.COMPLETED,
-            RequsetStatus.LENGTH_CAPPED,
+            RequestStatus.OVERLENGTH,
+            RequestStatus.COMPLETED,
+            RequestStatus.LENGTH_CAPPED,
         ]
 
     @staticmethod
-    def is_running(status: "RequsetStatus") -> bool:
-        return status == RequsetStatus.RUNNING
+    def is_running(status: "RequestStatus") -> bool:
+        return status in [
+            RequestStatus.PREFILL,
+            RequestStatus.TOKEN,
+        ]
 
     @staticmethod
-    def is_waiting(status: "RequsetStatus") -> bool:
-        return status == RequsetStatus.WAITING
+    def is_waiting(status: "RequestStatus") -> bool:
+        return status == RequestStatus.WAITING
 
 
 @dataclass
 class Sequence:
     """Store information of input sequence.
 
     Args:
-        request_id: The ID of input sequence.
-        prompt: The prompt of input sequence.
-        token_id: The tokens ID of input sequence.
-        block_size: The block size of input sequence.
-        sample_params: The sample_params of input sequence.
-        block_table_index: The index of input sequence in block_table.
+        request_id (int): The ID of input sequence.
+        prompt (str): The prompt of input sequence.
+        input_token_id (List[int]): The tokens ID of input sequence.
+        block_size (int): The block size of input sequence.
+        sample_params (SampleParams): The sample_params of input sequence.
+        block_table (torch.Tensor): The index of input sequence in block_table.
+        eos_token_id (int): The eos token id for this inference process.
+        max_output_len (int): Maximum output length.
     """
 
-    def __init__(
-        self,
-        request_id: int,
-        prompt: str,
-        token_id: List[int],
-        block_size: int,
-        sample_params,  # SampleParams needs to be imported later.
-        block_table_index: int,
-    ):
-        self.request_id = request_id
-        self.prompt = prompt
-        self.input_token_id = token_id
-        self.blokc_size = block_size
-        self.sample_params = sample_params
+    request_id: int
+    prompt: str
+    input_token_id: List[int]
+    block_size: int
+    sample_params: any  # SampleParams needs to be imported later.
+    block_table: torch.Tensor
+    eos_token_id: int
+    max_output_len: int = 256
+
+    def __post_init__(self):
         self.output_token_id = []
-        self.status = RequsetStatus.WAITING
-        self.block_table_index = block_table_index
+        self.status = RequestStatus.WAITING
 
     def get_sentence_len(self) -> None:
         """
@@ -84,17 +98,30 @@ class Sequence:
     def check_finish(self) -> bool:
         """
-        Check whether inference is over.
+        Check whether the inference is finished.
+
+        Returns:
+            bool: Whether the inference is finished.
         """
-        return RequsetStatus.is_finished(self.status)
+        if RequestStatus.is_finished(self.status):
+            return True
+
+        if self.output_token_id:
+            if self.output_token_id[-1] == self.eos_token_id or len(self.output_token_id) == self.max_output_len:
+                self.status = RequestStatus.COMPLETED
+                return True
+
+        return False
+
+    def __hash__(self):
+        return hash(self.request_id)
 
     def __repr__(self) -> str:
         return (
             f"Request ID(request_id={self.request_id}, "
             f"prompt={self.prompt}, "
             f"status={self.status.name}, "
-            f"sample_params={self.sample_params}, "
-            f"logical block number={len(self._logical_blocks)}"
+            f"sample_params={self.sample_params}"
         )
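
After this hunk, check_finish() is no longer a thin wrapper around the status flag: it also inspects the generated tokens and flips the status to COMPLETED as a side effect when the last token is EOS or max_output_len is reached. Continuing the sketch from above:

seq.output_token_id = [42]
assert seq.check_finish() is False  # no EOS yet and still under max_output_len

seq.output_token_id.append(seq.eos_token_id)
assert seq.check_finish() is True  # last token is EOS ...
assert seq.status == RequestStatus.COMPLETED  # ... and the status was updated in place
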
@@ -104,34 +131,38 @@ class BatchInfo:
     Information to be passed and used for a batch of sequences.
     """
 
-    sequences_set: Set[Sequence]
-    block_table: Dict[int, int] = None
+    sequences_set: OrderedSet["Sequence"]
 
     @classmethod
-    def init_batch(cls, seqs: List[Sequence]) -> "BatchInfo":
+    def init_batch(cls, seqs: List["Sequence"] = None) -> "BatchInfo":
         """
         Initializes inference batches by input sentence list.
 
         Args:
-            seqs (List[Sequence]): List of input sequence.
+            seqs (List["Sequence"]): List of input sequence.
         """
-        sequences_set = set()
-        block_table = {}
-        for seq in seqs:
-            if seq in sequences_set:
-                assert (
-                    seq.request_id in block_table.keys()
-                ), "The sequence has been added to sequences_set, but it has not been added to block_table."
-                continue
-
-            assert (
-                seq.request_id not in block_table.keys()
-            ), "The sequence has not been added to sequences_set, but it is already in block_table."
-
-            sequences_set.add(seq)
-            block_table[seq.request_id] = seq.block_table_index
-
-        return cls(sequences_set=sequences_set, block_table=block_table)
+        sequences_set = OrderedSet()
+
+        if seqs is not None:
+            if not isinstance(seqs, list):
+                seqs = [seqs]
+            for seq in seqs:
+                if seq in sequences_set:
+                    logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
+                    continue
+
+                sequences_set.add(seq)
+
+        return cls(sequences_set=sequences_set)
+
+    def get_block_table_tensor(self):
+        tesnor_list = []
+        for seq in self.sequences_set:
+            block_table = seq.block_table
+            assert block_table, f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
+            tesnor_list.append(seq.block_table)
+        return torch.concat(tesnor_list)
@@ -139,35 +170,76 @@ class BatchInfo:
         """
         for seq in self.sequences_set:
             if not seq.check_finish():
-                seq.status = RequsetStatus.ABORTED
+                seq.status = RequestStatus.ABORTED
         self.sequences_set.clear()
-        self.block_table.clear()
 
-    def fliter_batch(self) -> None:
+    def fliter_batch(self) -> List["Sequence"]:
         """
         Remove completed sentences from a batch.
-        """
-        for seq in self.sequences_set.copy():
-            if seq.check_finish():
-                self.sequences_set.remove(seq)
-                del self.block_table[seq.request_id]
 
-    def add_seqs(self, seqs: List[Sequence]) -> None:
+        Returns:
+            List["Sequence"]: List of finished sequences.
+        """
+        finish_seqs = []
+        for seq in self.sequences_set:
+            if seq.check_finish():
+                finish_seqs.append(seq)
+        for finish_seq in finish_seqs:
+            self.sequences_set.discard(finish_seq)
+        return finish_seqs
+
+    def abort_seq(self, seq: "Sequence") -> "Sequence":
+        """
+        Remove sequence from the batch.
+        """
+        if not seq.check_finish():
+            seq.status = RequestStatus.ABORTED
+        self.sequences_set.discard(seq)
+        return seq
+
+    def add_seqs(self, seqs: List["Sequence"]) -> None:
         """
         Add new sequence to batch
 
         Args:
-            seqs (List[Sequence]): The list of new sequences.
+            seqs (List["Sequence"]): The list of new sequences.
         """
+        if not isinstance(seqs, list):
+            seqs = [seqs]
+
         for seq in seqs:
             if seq in self.sequences_set:
-                print("The sequence is already in sequences_set.")
-                assert (
-                    seq.request_id in self.block_table
-                ), "The sequence has been added to sequences_set, but it has not been added to block_table."
+                logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
                 continue
-            assert (
-                seq.request_id not in self.block_table
-            ), "The sequence has not been added to sequences_set, but it is already in block_table."
             self.sequences_set.add(seq)
-            self.block_table[seq.request_id] = seq.block_table_index
 
     def is_empty(self) -> None:
         """
         Check whether sequences_set is empty.
         """
         return not self.sequences_set
+
+    def update_batch_tokens(self, tokens: Union[List[int], List[List[int]]]) -> None:
+        """
+        Add an output token for each sentence in the batch.
+
+        Args:
+            tokens (List[int]): A batch of tokens
+        """
+        assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size."
+
+        for seq, token in zip(self.sequences_set, tokens):
+            if not isinstance(token, list):
+                if not isinstance(token, int):
+                    raise TypeError(f"The token type must be List[int] or int, but get {type(token)}.")
+                token = [token]
+            seq.output_token_id += token
+            seq.check_finish()
+
+    def get_batch_size(self) -> int:
+        """
+        Get batch_size of this batch
+        """
+        return len(self.sequences_set)
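
update_batch_tokens() and fliter_batch() together close the decoding loop introduced by this commit: one token (or token list) is appended per sequence, check_finish() runs as a side effect, and finished sequences are handed back to the caller rather than dropped silently. A final sketch, continuing from the batch above:

# One int per sequence; a List[int] per sequence is accepted as well.
batch.update_batch_tokens([seq.eos_token_id])

# fliter_batch (sic) now returns the finished sequences to the caller.
finished = batch.fliter_batch()
assert finished[0].request_id == seq.request_id
assert batch.is_empty()
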