Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 12:01:39 +00:00)
[Inference] Optimize and Refactor Inference Batching/Scheduling (#5367)
* add kvcache manager funcs for batching
* add batch bucket for batching
* revise RunningList struct in handler
* add kvcache/batch funcs for compatibility
* use new batching methods
* fix indexing bugs
* revise abort logic
* use cpu seq lengths/block tables
* rm unused attr in Sequence
* fix type conversion/default arg
* add and revise pytests
* revise pytests, rm unused tests
* rm unused statements
* fix pop finished indexing issue
* fix: use index in batch when retrieving inputs/update seqs
* use dict instead of odict in batch struct
* arg type hinting
* fix make compress
* refine comments
* fix: pop_n_seqs to pop the first n seqs
* add check in request handler
* remove redundant conversion
* fix test for request handler
* fix pop method in batch bucket
* fix prefill adding
@@ -1,15 +1,16 @@
-from typing import List
+from typing import Dict, List, Union

 import torch
 from transformers.configuration_utils import PretrainedConfig
 from transformers.generation import GenerationConfig

+from colossalai.inference.batch_bucket import BatchBucket
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.kv_cache import KVCacheManager
 from colossalai.inference.logit_processors import logit_processor
 from colossalai.inference.sampler import *
-from colossalai.inference.struct import BatchInfo, RequestStatus, Sequence
+from colossalai.inference.struct import RequestStatus, Sequence
 from colossalai.logging import get_dist_logger

 __all__ = ["RunningList", "RequestHandler"]
@@ -24,45 +25,79 @@ class RunningList:

     Args:
         prefill_ratio: (float) A ratio for determing whether to perform prefill or not.
-        prefill: (List) List that contains default inputs, defaults to [].
+        _prefill (OrderedDict[Sequence]): Mapping of sequence uid -> Sequence.
+        _decoding (OrderedDict[Sequence]): Mapping of sequence uid -> Sequence.
     """

-    def __init__(self, prefill_ratio: str, prefill: List[Sequence] = None):
+    def __init__(self, prefill_ratio: int, prefill: List[Sequence] = None) -> None:
         self.prefill_ratio = prefill_ratio
-        self.decoding: List[Sequence] = []
-        self.prefill: List[Sequence] = prefill if prefill is not None else []
+        self._decoding: Dict[int, Sequence] = dict()
+        self._prefill: Dict[int, Sequence] = (
+            dict({seq.request_id: seq for seq in self._prefill}) if prefill is not None else dict()
+        )
+
+    @property
+    def decoding(self):
+        return list(self._decoding.values())
+
+    @property
+    def prefill(self):
+        return list(self._prefill.values())
+
+    @property
+    def prefill_seq_num(self):
+        return len(self._prefill)
+
+    @property
+    def decoding_seq_num(self):
+        return len(self._decoding)
+
+    @property
+    def total_seq_num(self):
+        return self.prefill_seq_num + self.decoding_seq_num

     def append(self, seq: Sequence):
-        # add seq to prefilling list first.
-        self.prefill.append(seq)
+        assert (seq.request_id not in self._prefill) and (
+            seq.request_id not in self._decoding
+        ), f"Sequence uid {seq.request_id} already exists."
+        self._prefill[seq.request_id] = seq

-    def find_seq(self, request_id):
-        for seq in self.decoding:
-            if request_id == seq.request_id:
-                return seq
-        for seq in self.prefill:
-            if request_id == seq.request_id:
-                return seq
-        return None
+    def extend(self, seqs: List[Sequence]):
+        for seq in seqs:
+            self._prefill[seq.request_id] = seq

-    def remove(self, seq: Sequence):
-        if seq in self.decoding:
-            self.decoding.remove(seq)
-        elif seq in self.prefill:
-            self.prefill.remove(seq)
+    def find_seq(self, request_id) -> Union[Sequence, None]:
+        seq = None
+        if request_id in self._decoding:
+            seq = self._decoding[request_id]
+        elif request_id in self._prefill:
+            seq = self._prefill[request_id]
+        return seq
+
+    def remove(self, seq: Sequence) -> None:
+        if seq.request_id in self._decoding:
+            self._decoding.pop(seq.request_id)
+        elif seq.request_id in self._prefill:
+            self._prefill.pop(seq.request_id)
         else:
-            raise ValueError(f"sequence {seq.request_id} is not in running list")
+            raise ValueError(f"Sequence {seq.request_id} is not in running list")

     def ready_for_prefill(self):
-        if not self.decoding:
-            return len(self.prefill) > 0
-        return len(self.prefill) / len(self.decoding) >= self.prefill_ratio
+        if not self._decoding:
+            return len(self._prefill) > 0
+        return len(self._prefill) / len(self._decoding) >= self.prefill_ratio

     def is_empty(self):
-        return not self.decoding and not self.prefill
+        return not self._decoding and not self._prefill

-    def total_seq_num(self):
-        return len(self.decoding) + len(self.prefill)
+    def mark_prefill_running(self) -> None:
+        for seq_id in self._prefill:
+            self._prefill[seq_id].mark_running()
+
+    def move_prefill_to_decoding(self, seq_ids: List[int]) -> None:
+        for seq_id in seq_ids:
+            assert seq_id in self._prefill, f"Sequence {seq_id} is not in prefill list"
+            self._decoding[seq_id] = self._prefill.pop(seq_id)


 class RequestHandler:
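For orientation, the refactored RunningList keys both pools by request_id, so membership tests, removals, and the prefill-to-decoding move become dict lookups instead of list scans. Below is a minimal usage sketch, not part of the commit: it assumes the RunningList defined above is in scope and substitutes a tiny stand-in for colossalai.inference.struct.Sequence, of which only request_id and mark_running() are exercised here.

from dataclasses import dataclass

@dataclass
class _StubSequence:
    # Hypothetical stand-in for Sequence; only the attributes RunningList touches here.
    request_id: int
    status: str = "WAITING"

    def mark_running(self):
        self.status = "RUNNING"

running = RunningList(prefill_ratio=1.2)           # assumes RunningList from the diff is importable/in scope
for rid in range(4):
    running.append(_StubSequence(request_id=rid))  # new sequences always enter the prefill dict

if running.ready_for_prefill():                    # no decoding sequences yet, so this returns True
    running.mark_prefill_running()
    prefill_ids = [seq.request_id for seq in running.prefill]
    # ... run the prefill forward pass for these sequences ...
    running.move_prefill_to_decoding(prefill_ids)  # O(1) dict moves keyed by request_id

assert running.prefill_seq_num == 0 and running.decoding_seq_num == 4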
@@ -110,25 +145,27 @@ class RequestHandler:

         # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
         # which may cause bugs and this issue should be fixed later.
-        self.running_batch = BatchInfo(
-            max_batch_size=self.max_batch_size,
-            kv_max_split_num=kv_max_split_num,
+        self.running_bb = BatchBucket(
             num_heads=model_config.num_attention_heads,
             head_dim=head_dim,
-            is_prompts=False,
-            device=device,
+            max_batch_size=self.max_batch_size,
+            max_length=inference_config.max_input_len + inference_config.max_output_len,
+            block_size=inference_config.block_size,
+            kv_max_split_num=kv_max_split_num,
+            fd_interm_tensor=fd_inter_tensor,
             dtype=self.dtype,
-            fd_inter_tensor=fd_inter_tensor,
+            device=device,
         )
-        self.prefill_batch = BatchInfo(
-            max_batch_size=self.max_batch_size,
-            kv_max_split_num=kv_max_split_num,
+        self.prefill_bb = BatchBucket(
             num_heads=model_config.num_attention_heads,
             head_dim=head_dim,
-            is_prompts=True,
-            device=device,
+            max_batch_size=self.max_batch_size,
+            max_length=inference_config.max_input_len + inference_config.max_output_len,
+            block_size=inference_config.block_size,
+            kv_max_split_num=kv_max_split_num,
+            fd_interm_tensor=fd_inter_tensor,
             dtype=self.dtype,
-            fd_inter_tensor=fd_inter_tensor,
+            device=device,
         )

     def _init_cache(self, model_config):
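The two buckets are now sized directly from the inference config: max_length bounds how long any sequence can grow and block_size is the KV-cache block granularity. A back-of-the-envelope sketch of how many cache blocks a single sequence can occupy at most; the numeric values are hypothetical stand-ins, not from the commit.

import math

# Hypothetical values standing in for inference_config.max_input_len / max_output_len / block_size.
max_input_len, max_output_len, block_size = 256, 128, 16

max_length = max_input_len + max_output_len               # same expression as the max_length argument above
max_blocks_per_seq = math.ceil(max_length / block_size)   # ceil(384 / 16) = 24 KV-cache blocks per sequence
print(max_blocks_per_seq)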
@@ -159,40 +196,39 @@ class RequestHandler:
                             remove_list.append(seq)
                             break

-                        # stop feeding new sequence into running list to assure
-                        if self.cache_manager.num_available_blocks <= self.running_list.total_seq_num():
-                            break
-
-                        # Try to allocate cache blocks for the sequence.
-                        if (
-                            self.cache_manager.check_allocation(seq)
-                            and (len(self.running_list.prefill) + len(self.running_list.decoding))
-                            < self.max_batch_size  # There some bugs in continous batching, so we disable it here.
-                        ):
-                            # If succeed, add the sequence to running list.
-                            remove_list.append(seq)
-                            self.running_list.append(seq)
-                            self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.sentence_len)
+                    num_seqs_to_add = min(len(lst), self.max_batch_size - self.running_list.total_seq_num)
+                    remove_list.extend(lst[:num_seqs_to_add])
+                    self.running_list.extend(lst[:num_seqs_to_add])
+
                     for seq in remove_list:
                         lst.remove(seq)

             if self.running_list.ready_for_prefill():
-                for seq in self.running_list.prefill:
-                    seq.mark_running()
-                self.prefill_batch.add_seqs(self.running_list.prefill)
-                return self.prefill_batch
+                num_seqs_to_add = min(self.running_list.prefill_seq_num, self.running_bb.available_batch_size)
+
+                for seq in self.running_list.prefill[:num_seqs_to_add]:
+                    seq.mark_running()
+                # allocate blocks for the prefill batch
+                self.prefill_bb.add_seqs(
+                    self.running_list.prefill[:num_seqs_to_add],
+                    alloc_block_tables_fn=self.cache_manager.allocate_context_from_block_tables,
+                )
+
+                return self.prefill_bb

-        if not self.running_batch.is_empty:
-            for seq in self.running_batch.sequences_set:
-                recycle = self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
-                if recycle:
+        if not self.running_bb.is_empty:
+            seqs_ids_to_recycle = self.cache_manager.allocate_tokens_from_block_tables(
+                self.running_bb.block_tables, self.running_bb.seq_lengths, self.running_bb.current_batch_size
+            )
+            if seqs_ids_to_recycle:
+                seqs_to_recycle = self.running_bb.pop_seqs(seqs_ids_to_recycle)
+                for seq in seqs_to_recycle:
                     seq.recycle()
-                    self.running_batch.del_seq(seq)
                     self.running_list.remove(seq)
                     self.waiting_list[-1].append(seq)
                     # the recycled sequences are handled with highest priority.

-        return self.running_batch
+        return self.running_bb

     def add_sequence(self, req: Sequence):
         """
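Condensed, the new scheduling policy is: admit waiting sequences up to max_batch_size, prefer a prefill step whenever the prefill/decoding ratio allows it, and otherwise extend the decoding batch by one step, recycling any sequence that can no longer get a cache block. The sketch below restates that control flow outside the handler; the method and attribute names are the ones visible in the diff, but bundling them into a free function with these parameters is purely illustrative.

def schedule_step(running_list, prefill_bucket, running_bucket, cache_manager, waiting_list):
    # 1. Prefill wins when enough new sequences have accumulated relative to decoding ones.
    if running_list.ready_for_prefill():
        n = min(running_list.prefill_seq_num, running_bucket.available_batch_size)
        seqs = running_list.prefill[:n]
        for seq in seqs:
            seq.mark_running()
        # Allocation of context blocks is delegated to the cache manager via a callback.
        prefill_bucket.add_seqs(seqs, alloc_block_tables_fn=cache_manager.allocate_context_from_block_tables)
        return prefill_bucket

    # 2. Otherwise grow every running sequence's KV cache by one step, in batch.
    if not running_bucket.is_empty:
        to_recycle = cache_manager.allocate_tokens_from_block_tables(
            running_bucket.block_tables, running_bucket.seq_lengths, running_bucket.current_batch_size
        )
        if to_recycle:
            # Sequences that could not get a block go back to the waiting list with top priority.
            for seq in running_bucket.pop_seqs(to_recycle):
                seq.recycle()
                running_list.remove(seq)
                waiting_list[-1].append(seq)
    return running_bucket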
@@ -213,7 +249,7 @@ class RequestHandler:
             seq.mark_aborted()
             self.waiting_list[priority].remove(seq)
         elif seq.status.is_running():
-            self.cache_manager.free_block_table(seq.block_table)
+            self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table)
             self.running_list.remove(seq)
         else:
             try:
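Aborting a running sequence now goes through the bucket: pop_seq_update_batch removes the sequence from the batch and hands its block table to the callback, which is how its KV-cache blocks get released. A toy illustration of that callback pattern follows; ToyBucket is hypothetical and only mimics the shape of the call, it is not the real BatchBucket or KVCacheManager API.

from typing import Callable, Dict, List

class ToyBucket:
    def __init__(self) -> None:
        self._block_tables: Dict[int, List[int]] = {}

    def add(self, request_id: int, blocks: List[int]) -> None:
        self._block_tables[request_id] = blocks

    def pop_seq_update_batch(self, request_id: int, free_fn: Callable[[List[int]], None]) -> None:
        blocks = self._block_tables.pop(request_id)
        free_fn(blocks)  # the caller decides how blocks are actually released

freed: List[int] = []
bucket = ToyBucket()
bucket.add(7, [0, 1, 2])
bucket.pop_seq_update_batch(7, freed.extend)  # stands in for cache_manager.free_block_table
assert freed == [0, 1, 2]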
@@ -242,7 +278,7 @@ class RequestHandler:
             else:
                 sample_tokens = greedy_sample(generation_config, logprobs)
         else:
-            sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_batch.is_empty)
+            sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_bb.is_empty)

         return sample_tokens

@@ -273,27 +309,25 @@ class RequestHandler:

         # sample the next tokens
         sample_tokens = self._sample(probs, logprobs, generation_config)
-        if not self.prefill_batch.is_empty:
-            self.prefill_batch.update_batch_tokens(sample_tokens)
+        if not self.prefill_bb.is_empty:
+            self.prefill_bb.append_batch_tokens(sample_tokens)
         else:
-            self.running_batch.update_batch_tokens(sample_tokens)
+            self.running_bb.append_batch_tokens(sample_tokens)

     def update(self):
         """
         Update current running list and done list
         """
-        if not self.prefill_batch.is_empty:
-            self.running_list.decoding.extend(self.running_list.prefill)
-            self.running_batch.add_seqs(self.running_list.prefill)
-            self.running_list.prefill.clear()
-            self.prefill_batch.clear_batch()
+        if not self.prefill_bb.is_empty:
+            self.running_list.move_prefill_to_decoding(self.prefill_bb.seqs_ids)
+            self.running_bb.merge(self.prefill_bb)
+            # clear the prefill batch without assigning a free_block_tables_fn
+            # since we want to reuse the memory recorded on the block tables
+            self.prefill_bb.clear(free_block_tables_fn=None)

-        finish_seqs = self.running_batch.fliter_batch()
-
-        for seq in finish_seqs:
+        finished_seqs, _ = self.running_bb.pop_finished(self.cache_manager.free_block_table)
+        for seq in finished_seqs:
             self.running_list.remove(seq)
-            self.cache_manager.free_block_table(seq.block_table)
+        self.done_list.extend(finished_seqs)

-        self.done_list.extend(finish_seqs)
-
-        return finish_seqs
+        return finished_seqs
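In update(), sequences that just finished prefill are promoted in one shot: their ids move from the prefill dict to the decoding dict, and merge() carries their block tables into the running bucket, which is why the prefill bucket is cleared with free_block_tables_fn=None (freeing here would release cache blocks the merged sequences still use). A condensed sketch of that flow, using only calls that appear in the diff; the handler argument is illustrative.

def update_step(handler):
    # Promote freshly prefilled sequences into the decoding batch.
    if not handler.prefill_bb.is_empty:
        handler.running_list.move_prefill_to_decoding(handler.prefill_bb.seqs_ids)
        handler.running_bb.merge(handler.prefill_bb)          # block tables are carried over
        handler.prefill_bb.clear(free_block_tables_fn=None)   # do NOT free the blocks just merged

    # Drop finished sequences, releasing their cache blocks via the callback.
    finished, _ = handler.running_bb.pop_finished(handler.cache_manager.free_block_table)
    for seq in finished:
        handler.running_list.remove(seq)
    handler.done_list.extend(finished)
    return finished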