Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-10 21:40:02 +00:00
[Inference] Add the logic of the inference engine (#5173)
* add infer_struct and infer_config
* update codes
* change InferConfig
* Add hf_model_config to the engine
* rm _get_hf_model_config
* update codes
* made adjustments according to the feedback from the reviewer.
* update codes
* add ci test for config and struct
* Add the logic of the inference engine
* update engine and test
* Recover cache_manager.py
* add logger
* fix conflict
* update codes
* update codes
* update model and tokenizer
* fix add the logic about shardformer
* change kvcache_manager docstring
* add policy
* fix ci bug in test_kvcache_manager.py
* remove codes related to tokenizer and move model_policy
* fix code style
* add ordered_set to requirements-infer.txt
* Delete extra empty lines
* add ordered_set to requirements-test.txt
Committed by: FrankLeeeee
Parent: 93aeacca34
Commit: 8daee26989
@@ -1,68 +1,82 @@
 import enum
 from dataclasses import dataclass
-from typing import Dict, List, Set
+from typing import List, Union
 
 import torch
+from ordered_set import OrderedSet
+
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger(__name__)
 
 """
 The abstraction of request and sequence are defined here.
 """
 
 
-class RequsetStatus(enum.Enum):
-    """The status of Sentences"""
+class RequestStatus(enum.Enum):
+    """
+    The status of Sentences
+    """
 
     # running status
     WAITING = enum.auto()
-    RUNNING = enum.auto()
+    PREFILL = enum.auto()
+    TOKEN = enum.auto()
     ABORTED = enum.auto()
 
     # completion status
     OVERLENGTH = enum.auto()
    COMPLETED = enum.auto()
     LENGTH_CAPPED = enum.auto()
 
     @staticmethod
-    def is_finished(status: "RequsetStatus") -> bool:
+    def is_finished(status: "RequestStatus") -> bool:
         return status in [
-            RequsetStatus.OVERLENGTH,
-            RequsetStatus.COMPLETED,
-            RequsetStatus.LENGTH_CAPPED,
+            RequestStatus.OVERLENGTH,
+            RequestStatus.COMPLETED,
+            RequestStatus.LENGTH_CAPPED,
         ]
 
     @staticmethod
-    def is_running(status: "RequsetStatus") -> bool:
-        return status == RequsetStatus.RUNNING
+    def is_running(status: "RequestStatus") -> bool:
+        return status in [
+            RequestStatus.PREFILL,
+            RequestStatus.TOKEN,
+        ]
 
     @staticmethod
-    def is_waiting(status: "RequsetStatus") -> bool:
-        return status == RequsetStatus.WAITING
+    def is_waiting(status: "RequestStatus") -> bool:
+        return status == RequestStatus.WAITING
 
 
 @dataclass
 class Sequence:
     """Store information of input sequence.
 
     Args:
-        request_id: The ID of input sequence.
-        prompt: The prompt of input sequence.
-        token_id: The tokens ID of input sequence.
-        block_size: The block size of input sequence.
-        sample_params: The sample_params of input sequence.
-        block_table_index: The index of input sequence in block_table.
+        request_id (int): The ID of input sequence.
+        prompt (str): The prompt of input sequence.
+        input_token_id (List[int]): The tokens ID of input sequence.
+        block_size (int): The block size of input sequence.
+        sample_params (SampleParams): The sample_params of input sequence.
+        block_table (torch.Tensor): The index of input sequence in block_table.
+        eos_token_id (int): The eos token id for this inference process.
+        max_output_len (int): Maximum output length.
     """
 
-    def __init__(
-        self,
-        request_id: int,
-        prompt: str,
-        token_id: List[int],
-        block_size: int,
-        sample_params, # SampleParams needs to be imported later.
-        block_table_index: int,
-    ):
-        self.request_id = request_id
-        self.prompt = prompt
-        self.input_token_id = token_id
-        self.blokc_size = block_size
-        self.sample_params = sample_params
+    request_id: int
+    prompt: str
+    input_token_id: List[int]
+    block_size: int
+    sample_params: any # SampleParams needs to be imported later.
+    block_table: torch.Tensor
+    eos_token_id: int
+    max_output_len: int = 256
+
+    def __post_init__(self):
         self.output_token_id = []
-        self.status = RequsetStatus.WAITING
-        self.block_table_index = block_table_index
+        self.status = RequestStatus.WAITING
 
     def get_sentence_len(self) -> None:
         """
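Below is a minimal usage sketch (not part of the commit) of the `Sequence` dataclass and `RequestStatus` enum introduced in the hunk above. The import path and every argument value are assumptions made for illustration, since the commit page does not show the file path, and `sample_params` is left as a placeholder because `SampleParams` is not defined in this diff.

```python
# Hypothetical sketch, not part of the commit: build a Sequence and inspect its status.
import torch

from colossalai.inference.struct import RequestStatus, Sequence  # assumed import path

seq = Sequence(
    request_id=0,
    prompt="Hello",
    input_token_id=[15043],         # example token ids
    block_size=16,
    sample_params=None,             # placeholder; SampleParams is not defined in this diff
    block_table=torch.tensor([0]),  # example block table entry
    eos_token_id=2,
    max_output_len=8,
)

# __post_init__ leaves a fresh sequence in the WAITING state with no output tokens.
assert seq.status == RequestStatus.WAITING
assert RequestStatus.is_waiting(seq.status)
assert not RequestStatus.is_finished(seq.status)
```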
@@ -84,17 +98,30 @@ class Sequence:
 
     def check_finish(self) -> bool:
         """
-        Check whether inference is over.
+        Check whether the inference is finished.
+
+        Returns:
+            bool: Whether the inference is finished.
         """
-        return RequsetStatus.is_finished(self.status)
+        if RequestStatus.is_finished(self.status):
+            return True
+
+        if self.output_token_id:
+            if self.output_token_id[-1] == self.eos_token_id or len(self.output_token_id) == self.max_output_len:
+                self.status = RequestStatus.COMPLETED
+                return True
+
+        return False
+
+    def __hash__(self):
+        return hash(self.request_id)
 
     def __repr__(self) -> str:
         return (
             f"Request ID(request_id={self.request_id}, "
             f"prompt={self.prompt}, "
             f"status={self.status.name}, "
-            f"sample_params={self.sample_params}, "
-            f"logical block number={len(self._logical_blocks)}"
+            f"sample_params={self.sample_params}"
         )
 
 
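Continuing that sketch, the reworked `check_finish()` above marks a sequence `COMPLETED` once the last generated token equals `eos_token_id` or the output reaches `max_output_len` (both values were chosen arbitrarily in the previous sketch).

```python
# Continuation of the hypothetical sketch above.
assert not seq.check_finish()   # no output tokens yet, still WAITING

seq.output_token_id.append(2)   # 2 was picked as eos_token_id above
assert seq.check_finish()       # last token == eos_token_id, so the sequence completes
assert seq.status == RequestStatus.COMPLETED
assert RequestStatus.is_finished(seq.status)
```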
@@ -104,34 +131,38 @@ class BatchInfo:
     Information to be passed and used for a batch of sequences.
     """
 
-    sequences_set: Set[Sequence]
-    block_table: Dict[int, int] = None
+    sequences_set: OrderedSet["Sequence"]
 
     @classmethod
-    def init_batch(cls, seqs: List[Sequence]) -> "BatchInfo":
+    def init_batch(cls, seqs: List["Sequence"] = None) -> "BatchInfo":
         """
         Initializes inference batches by input sentence list.
 
         Args:
-            seqs (List[Sequence]): List of input sequence.
+            seqs (List["Sequence"]): List of input sequence.
         """
-        sequences_set = set()
-        block_table = {}
-        for seq in seqs:
-            if seq in sequences_set:
-                assert (
-                    seq.request_id in block_table.keys()
-                ), "The sequence has been added to sequences_set, but it has not been added to block_table."
-                continue
-
-            assert (
-                seq.request_id not in block_table.keys()
-            ), "The sequence has not been added to sequences_set, but it is already in block_table."
-
-            sequences_set.add(seq)
-            block_table[seq.request_id] = seq.block_table_index
-
-        return cls(sequences_set=sequences_set, block_table=block_table)
+        sequences_set = OrderedSet()
+
+        if seqs is not None:
+            if not isinstance(seqs, list):
+                seqs = [seqs]
+            for seq in seqs:
+                if seq in sequences_set:
+                    logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
+                    continue
+
+                sequences_set.add(seq)
+
+        return cls(sequences_set=sequences_set)
+
+    def get_block_table_tensor(self):
+        tesnor_list = []
+        for seq in self.sequences_set:
+            block_table = seq.block_table
+            assert block_table, f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
+            tesnor_list.append(seq.block_table)
+        return torch.concat(tesnor_list)
 
     def clear_batch(self) -> None:
         """
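With this hunk, `BatchInfo` keeps its sequences in an `OrderedSet` and no longer maintains a separate block table; each `Sequence` carries its own `block_table` tensor instead. A hypothetical continuation of the sketch, showing that duplicates are now skipped with a logger warning rather than an assertion, and that `init_batch` also accepts no argument (the import path is again assumed):

```python
# Continuation of the hypothetical sketch above; the import path is assumed.
from colossalai.inference.struct import BatchInfo

batch = BatchInfo.init_batch([seq, seq])  # the duplicate entry only triggers a warning
assert len(batch.sequences_set) == 1
assert not batch.is_empty()

empty_batch = BatchInfo.init_batch()      # seqs defaults to None
assert empty_batch.is_empty()
```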
@@ -139,35 +170,76 @@ class BatchInfo:
         """
         for seq in self.sequences_set:
             if not seq.check_finish():
-                seq.status = RequsetStatus.ABORTED
+                seq.status = RequestStatus.ABORTED
         self.sequences_set.clear()
-        self.block_table.clear()
 
-    def fliter_batch(self) -> None:
+    def fliter_batch(self) -> List["Sequence"]:
         """
         Remove completed sentences from a batch.
-        """
-        for seq in self.sequences_set.copy():
-            if seq.check_finish():
-                self.sequences_set.remove(seq)
-                del self.block_table[seq.request_id]
 
-    def add_seqs(self, seqs: List[Sequence]) -> None:
+        Returns:
+            List["Sequence"]: List of finished sequences.
+        """
+        finish_seqs = []
+        for seq in self.sequences_set:
+            if seq.check_finish():
+                finish_seqs.append(seq)
+        for finish_seq in finish_seqs:
+            self.sequences_set.discard(finish_seq)
+        return finish_seqs
+
+    def abort_seq(self, seq: "Sequence") -> "Sequence":
+        """
+        Remove sequence from the batch.
+        """
+        if not seq.check_finish():
+            seq.status = RequestStatus.ABORTED
+        self.sequences_set.discard(seq)
+        return seq
+
+    def add_seqs(self, seqs: List["Sequence"]) -> None:
         """
         Add new sequence to batch
 
         Args:
-            seqs (List[Sequence]): The list of new sequences.
+            seqs (List["Sequence"]): The list of new sequences.
         """
+
+        if not isinstance(seqs, list):
+            seqs = [seqs]
+
         for seq in seqs:
             if seq in self.sequences_set:
-                print("The sequence is already in sequences_set.")
-                assert (
-                    seq.request_id in self.block_table
-                ), "The sequence has been added to sequences_set, but it has not been added to block_table."
+                logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
                 continue
-            assert (
-                seq.request_id not in self.block_table
-            ), "The sequence has not been added to sequences_set, but it is already in block_table."
             self.sequences_set.add(seq)
-            self.block_table[seq.request_id] = seq.block_table_index
 
     def is_empty(self) -> None:
         """
         Check whether sequences_set is empty.
         """
         return not self.sequences_set
+
+    def update_batch_tokens(self, tokens: Union[List[int], List[List[int]]]) -> None:
+        """
+        Add an output token for each sentence in the batch.
+
+        Args:
+            tokens (List[int]): A batch of tokens
+        """
+
+        assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size."
+
+        for seq, token in zip(self.sequences_set, tokens):
+            if not isinstance(token, list):
+                if not isinstance(token, int):
+                    raise TypeError(f"The token type must be List[int] or int, but get {type(token)}.")
+                token = [token]
+            seq.output_token_id += token
+            seq.check_finish()
+
+    def get_batch_size(self) -> int:
+        """
+        Get batch_size of this batch
+        """
+        return len(self.sequences_set)
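Finally, a hypothetical end-to-end step over the batch helpers added above: `update_batch_tokens` appends one generated token per sequence and re-runs `check_finish`, and `fliter_batch` (name spelled as in the diff) returns and removes the finished sequences. All values are again illustrative only.

```python
# Continuation of the hypothetical sketch above, with a fresh sequence.
seq2 = Sequence(
    request_id=1,
    prompt="Hi",
    input_token_id=[6151],
    block_size=16,
    sample_params=None,             # placeholder, as in the first sketch
    block_table=torch.tensor([1]),
    eos_token_id=2,
    max_output_len=8,
)

batch = BatchInfo.init_batch(seq2)  # a single Sequence is wrapped into a list internally
batch.update_batch_tokens([2])      # one token per sequence; 2 is seq2's eos_token_id
finished = batch.fliter_batch()     # returns and removes the finished sequences
assert finished == [seq2]
assert batch.is_empty()
```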