[Inference] Optimize and Refactor Inference Batching/Scheduling (#5367)

* add kvcache manager funcs for batching

* add batch bucket for batching

* revise RunningList struct in handler

* add kvcache/batch funcs for compatibility

* use new batching methods

* fix indexing bugs

* revise abort logic

* use cpu seq lengths/block tables

* rm unused attr in Sequence

* fix type conversion/default arg

* add and revise pytests

* revise pytests, rm unused tests

* rm unused statements

* fix pop finished indexing issue

* fix: use index in batch when retrieving inputs/update seqs

* use dict instead of odict in batch struct

* arg type hinting

* fix make compress

* refine comments

* fix: pop_n_seqs to pop the first n seqs

* add check in request handler

* remove redundant conversion

* fix test for request handler

* fix pop method in batch bucket

* fix prefill adding
Yuanheng Zhao
2024-02-19 17:18:20 +08:00
committed by GitHub
parent 8c69debdc7
commit b21aac5bae
11 changed files with 902 additions and 112 deletions
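
Taken together, these commits replace the list-backed RunningList and the BatchInfo prefill/decoding batches with dict-backed bookkeeping and two BatchBucket objects, prefill_bb and running_bb. The sketch below is a hypothetical rendering of one generation step against the refactored handler, not the engine's actual code: toy_step, model_fn, and sample_fn are illustrative stand-ins, and schedule() is an assumed name for the scheduling entry point (the diff shows its body but not its signature), while prefill_bb, running_bb, is_empty, append_batch_tokens, and update() are taken from the diff below.

# Hypothetical sketch of the two-bucket flow; all dependencies are passed in
# so the snippet stands alone.
def toy_step(handler, model_fn, sample_fn):
    batch = handler.schedule()  # assumed entry point; returns prefill_bb or running_bb
    logits = model_fn(batch)    # caller-supplied forward pass
    next_tokens = sample_fn(logits)
    # A prefill step writes the sampled tokens into the prefill bucket;
    # a decoding step appends them to the running bucket.
    if not handler.prefill_bb.is_empty:
        handler.prefill_bb.append_batch_tokens(next_tokens)
    else:
        handler.running_bb.append_batch_tokens(next_tokens)
    # update() merges prefill_bb into running_bb (keeping its KV cache blocks)
    # and pops finished sequences, freeing their block tables.
    return handler.update()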

colossalai/inference/core/request_handler.py

@@ -1,15 +1,16 @@
from typing import List
from typing import Dict, List, Union
import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.generation import GenerationConfig
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import *
from colossalai.inference.struct import BatchInfo, RequestStatus, Sequence
from colossalai.inference.struct import RequestStatus, Sequence
from colossalai.logging import get_dist_logger
__all__ = ["RunningList", "RequestHandler"]
@@ -24,45 +25,79 @@ class RunningList:
Args:
prefill_ratio: (float) A ratio for determining whether to perform prefill or not.
prefill: (List) List that contains default inputs, defaults to [].
_prefill (Dict[int, Sequence]): Mapping of sequence uid -> Sequence.
_decoding (Dict[int, Sequence]): Mapping of sequence uid -> Sequence.
"""
def __init__(self, prefill_ratio: str, prefill: List[Sequence] = None):
def __init__(self, prefill_ratio: int, prefill: List[Sequence] = None) -> None:
self.prefill_ratio = prefill_ratio
self.decoding: List[Sequence] = []
self.prefill: List[Sequence] = prefill if prefill is not None else []
self._decoding: Dict[int, Sequence] = dict()
self._prefill: Dict[int, Sequence] = (
dict({seq.request_id: seq for seq in prefill}) if prefill is not None else dict()
)
@property
def decoding(self):
return list(self._decoding.values())
@property
def prefill(self):
return list(self._prefill.values())
@property
def prefill_seq_num(self):
return len(self._prefill)
@property
def decoding_seq_num(self):
return len(self._decoding)
@property
def total_seq_num(self):
return self.prefill_seq_num + self.decoding_seq_num
def append(self, seq: Sequence):
# add seq to prefilling list first.
self.prefill.append(seq)
assert (seq.request_id not in self._prefill) and (
seq.request_id not in self._decoding
), f"Sequence uid {seq.request_id} already exists."
self._prefill[seq.request_id] = seq
def find_seq(self, request_id):
for seq in self.decoding:
if request_id == seq.request_id:
return seq
for seq in self.prefill:
if request_id == seq.request_id:
return seq
return None
def extend(self, seqs: List[Sequence]):
for seq in seqs:
self._prefill[seq.request_id] = seq
def remove(self, seq: Sequence):
if seq in self.decoding:
self.decoding.remove(seq)
elif seq in self.prefill:
self.prefill.remove(seq)
def find_seq(self, request_id) -> Union[Sequence, None]:
seq = None
if request_id in self._decoding:
seq = self._decoding[request_id]
elif request_id in self._prefill:
seq = self._prefill[request_id]
return seq
def remove(self, seq: Sequence) -> None:
if seq.request_id in self._decoding:
self._decoding.pop(seq.request_id)
elif seq.request_id in self._prefill:
self._prefill.pop(seq.request_id)
else:
raise ValueError(f"sequence {seq.request_id} is not in running list")
raise ValueError(f"Sequence {seq.request_id} is not in running list")
def ready_for_prefill(self):
if not self.decoding:
return len(self.prefill) > 0
return len(self.prefill) / len(self.decoding) >= self.prefill_ratio
if not self._decoding:
return len(self._prefill) > 0
return len(self._prefill) / len(self._decoding) >= self.prefill_ratio
def is_empty(self):
return not self.decoding and not self.prefill
return not self._decoding and not self._prefill
def total_seq_num(self):
return len(self.decoding) + len(self.prefill)
def mark_prefill_running(self) -> None:
for seq_id in self._prefill:
self._prefill[seq_id].mark_running()
def move_prefill_to_decoding(self, seq_ids: List[int]) -> None:
for seq_id in seq_ids:
assert seq_id in self._prefill, f"Sequence {seq_id} is not in prefill list"
self._decoding[seq_id] = self._prefill.pop(seq_id)
class RequestHandler:
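
As a quick illustration of the dict-backed RunningList defined above, here is a minimal usage sketch. The import path and the stub sequence are assumptions made for this example only; the real Sequence in colossalai.inference.struct carries much more state, but RunningList only needs request_id and mark_running() here.

from dataclasses import dataclass

from colossalai.inference.core.request_handler import RunningList  # assumed module path


@dataclass
class _StubSeq:  # illustrative stand-in for colossalai.inference.struct.Sequence
    request_id: int

    def mark_running(self) -> None:
        pass


running_list = RunningList(prefill_ratio=1.2)
for rid in (0, 1, 2):
    running_list.append(_StubSeq(rid))  # new sequences land in the prefill dict

# With no decoding sequences yet, any pending prefill sequence is enough.
assert running_list.ready_for_prefill()

running_list.mark_prefill_running()  # flip status before the prompt forward pass
running_list.move_prefill_to_decoding([0, 1, 2])
assert running_list.decoding_seq_num == 3 and running_list.prefill_seq_num == 0

# With prefill_ratio=1.2 and 3 decoding sequences, at least 4 pending prefill
# sequences are needed before ready_for_prefill() returns True again (4 / 3 >= 1.2).
running_list.append(_StubSeq(3))
assert not running_list.ready_for_prefill()  # 1 / 3 < 1.2

seq = running_list.find_seq(0)  # lookup by uid instead of a linear scan
running_list.remove(seq)

Compared with the old list-based version, membership checks, lookups, and removals are now keyed by request_id rather than scanning both the prefill and decoding lists.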
@@ -110,25 +145,27 @@ class RequestHandler:
# TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
# which may cause bugs; this issue should be fixed later.
self.running_batch = BatchInfo(
max_batch_size=self.max_batch_size,
kv_max_split_num=kv_max_split_num,
self.running_bb = BatchBucket(
num_heads=model_config.num_attention_heads,
head_dim=head_dim,
is_prompts=False,
device=device,
max_batch_size=self.max_batch_size,
max_length=inference_config.max_input_len + inference_config.max_output_len,
block_size=inference_config.block_size,
kv_max_split_num=kv_max_split_num,
fd_interm_tensor=fd_inter_tensor,
dtype=self.dtype,
fd_inter_tensor=fd_inter_tensor,
device=device,
)
self.prefill_batch = BatchInfo(
max_batch_size=self.max_batch_size,
kv_max_split_num=kv_max_split_num,
self.prefill_bb = BatchBucket(
num_heads=model_config.num_attention_heads,
head_dim=head_dim,
is_prompts=True,
device=device,
max_batch_size=self.max_batch_size,
max_length=inference_config.max_input_len + inference_config.max_output_len,
block_size=inference_config.block_size,
kv_max_split_num=kv_max_split_num,
fd_interm_tensor=fd_inter_tensor,
dtype=self.dtype,
fd_inter_tensor=fd_inter_tensor,
device=device,
)
def _init_cache(self, model_config):
@@ -159,40 +196,39 @@ class RequestHandler:
remove_list.append(seq)
break
# stop feeding new sequences into the running list to ensure enough cache blocks remain for the running sequences
if self.cache_manager.num_available_blocks <= self.running_list.total_seq_num:
break
num_seqs_to_add = min(len(lst), self.max_batch_size - self.running_list.total_seq_num)
remove_list.extend(lst[:num_seqs_to_add])
self.running_list.extend(lst[:num_seqs_to_add])
# Try to allocate cache blocks for the sequence.
if (
self.cache_manager.check_allocation(seq)
and (len(self.running_list.prefill) + len(self.running_list.decoding))
< self.max_batch_size # There are some bugs in continuous batching, so we disable it here.
):
# If succeed, add the sequence to running list.
remove_list.append(seq)
self.running_list.append(seq)
self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.sentence_len)
for seq in remove_list:
lst.remove(seq)
if self.running_list.ready_for_prefill():
for seq in self.running_list.prefill:
seq.mark_running()
self.prefill_batch.add_seqs(self.running_list.prefill)
return self.prefill_batch
num_seqs_to_add = min(self.running_list.prefill_seq_num, self.running_bb.available_batch_size)
if not self.running_batch.is_empty:
for seq in self.running_batch.sequences_set:
recycle = self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
if recycle:
for seq in self.running_list.prefill[:num_seqs_to_add]:
seq.mark_running()
# allocate blocks for the prefill batch
self.prefill_bb.add_seqs(
self.running_list.prefill[:num_seqs_to_add],
alloc_block_tables_fn=self.cache_manager.allocate_context_from_block_tables,
)
return self.prefill_bb
if not self.running_bb.is_empty:
seqs_ids_to_recycle = self.cache_manager.allocate_tokens_from_block_tables(
self.running_bb.block_tables, self.running_bb.seq_lengths, self.running_bb.current_batch_size
)
if seqs_ids_to_recycle:
seqs_to_recycle = self.running_bb.pop_seqs(seqs_ids_to_recycle)
for seq in seqs_to_recycle:
seq.recycle()
self.running_batch.del_seq(seq)
self.running_list.remove(seq)
self.waiting_list[-1].append(seq)
# the recycled sequences are handled with highest priority.
return self.running_batch
return self.running_bb
def add_sequence(self, req: Sequence):
"""
@@ -213,7 +249,7 @@ class RequestHandler:
seq.mark_aborted()
self.waiting_list[priority].remove(seq)
elif seq.status.is_running():
self.cache_manager.free_block_table(seq.block_table)
self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table)
self.running_list.remove(seq)
else:
try:
@@ -242,7 +278,7 @@ class RequestHandler:
else:
sample_tokens = greedy_sample(generation_config, logprobs)
else:
sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_batch.is_empty)
sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_bb.is_empty)
return sample_tokens
@@ -273,27 +309,25 @@ class RequestHandler:
# sample the next tokens
sample_tokens = self._sample(probs, logprobs, generation_config)
if not self.prefill_batch.is_empty:
self.prefill_batch.update_batch_tokens(sample_tokens)
if not self.prefill_bb.is_empty:
self.prefill_bb.append_batch_tokens(sample_tokens)
else:
self.running_batch.update_batch_tokens(sample_tokens)
self.running_bb.append_batch_tokens(sample_tokens)
def update(self):
"""
Update current running list and done list
"""
if not self.prefill_batch.is_empty:
self.running_list.decoding.extend(self.running_list.prefill)
self.running_batch.add_seqs(self.running_list.prefill)
self.running_list.prefill.clear()
self.prefill_batch.clear_batch()
if not self.prefill_bb.is_empty:
self.running_list.move_prefill_to_decoding(self.prefill_bb.seqs_ids)
self.running_bb.merge(self.prefill_bb)
# clear the prefill batch without assigning a free_block_tables_fn
# since we want to reuse the memory recorded on the block tables
self.prefill_bb.clear(free_block_tables_fn=None)
finish_seqs = self.running_batch.fliter_batch()
for seq in finish_seqs:
finished_seqs, _ = self.running_bb.pop_finished(self.cache_manager.free_block_table)
for seq in finished_seqs:
self.running_list.remove(seq)
self.cache_manager.free_block_table(seq.block_table)
self.done_list.extend(finished_seqs)
self.done_list.extend(finish_seqs)
return finish_seqs
return finished_seqs
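
One design choice worth noting in the hunks above: BatchBucket never calls the KVCacheManager directly. Allocation and release of block tables are injected as callbacks (alloc_block_tables_fn on add_seqs, free_block_table passed to pop_seq_update_batch and pop_finished, free_block_tables_fn on clear), which is why update() can clear the prefill bucket with free_block_tables_fn=None and keep the prefilled KV blocks alive after the merge. The toy classes below only demonstrate that callback pattern; they are not the real BatchBucket or KVCacheManager APIs.

from typing import Callable, Dict, List, Optional


class ToyCacheManager:
    """Stand-in for KVCacheManager: hands out integer 'block tables'."""

    def __init__(self) -> None:
        self._next = 0

    def allocate(self, request_id: int) -> int:
        # request_id is unused in this toy; the real allocator works on block tables
        self._next += 1
        return self._next

    def free(self, block_table: int) -> None:
        print(f"freed block table {block_table}")


class ToyBucket:
    """Stand-in for BatchBucket: tracks ownership, delegates memory decisions."""

    def __init__(self) -> None:
        self.block_tables: Dict[int, int] = {}

    def add_seqs(self, request_ids: List[int], alloc_block_tables_fn: Callable[[int], int]) -> None:
        for rid in request_ids:
            self.block_tables[rid] = alloc_block_tables_fn(rid)

    def pop_finished(self, free_block_table_fn: Callable[[int], None]) -> List[int]:
        finished = list(self.block_tables)  # toy: treat every sequence as finished
        for rid in finished:
            free_block_table_fn(self.block_tables.pop(rid))
        return finished

    def clear(self, free_block_tables_fn: Optional[Callable[[int], None]] = None) -> None:
        if free_block_tables_fn is not None:
            for bt in self.block_tables.values():
                free_block_tables_fn(bt)
        self.block_tables.clear()  # with None, the tables stay alive wherever they were merged


cache = ToyCacheManager()
prefill_bucket, running_bucket = ToyBucket(), ToyBucket()
prefill_bucket.add_seqs([0, 1], alloc_block_tables_fn=cache.allocate)

# "Merge" prefill into running, then clear the prefill bucket WITHOUT freeing:
# the running bucket now owns the same block tables, mirroring update() above.
running_bucket.block_tables.update(prefill_bucket.block_tables)
prefill_bucket.clear(free_block_tables_fn=None)

running_bucket.pop_finished(cache.free)  # finished sequences release their blocks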