mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-07 02:54:10 +00:00
[Inference] Optimize and Refactor Inference Batching/Scheduling (#5367)
* add kvcache manager funcs for batching
* add batch bucket for batching
* revise RunningList struct in handler
* add kvcache/batch funcs for compatibility
* use new batching methods
* fix indexing bugs
* revise abort logic
* use cpu seq lengths/block tables
* rm unused attr in Sequence
* fix type conversion/default arg
* add and revise pytests
* revise pytests, rm unused tests
* rm unused statements
* fix pop finished indexing issue
* fix: use index in batch when retrieving inputs/update seqs
* use dict instead of odict in batch struct
* arg type hinting
* fix make compress
* refine comments
* fix: pop_n_seqs to pop the first n seqs
* add check in request handler
* remove redundant conversion
* fix test for request handler
* fix pop method in batch bucket
* fix prefill adding
This commit is contained in:
parent
8c69debdc7
commit
b21aac5bae
colossalai/inference/batch_bucket.py (new file, 449 lines)

@@ -0,0 +1,449 @@
from typing import Callable, List, Optional, Tuple, Union

import torch

from colossalai.inference.struct import Sequence
from colossalai.utils import get_current_device


class BatchBucket:
    """Container for a batch of Sequences, which is used to manage the batch of sequences.

    Attrs:
        _sequences_dict (Dict[int, Sequence]): Map sequence uid to sequence struct
            seq_uid -> Sequence
        _sequences_indexes (Dict[int, int]): Map sequence uid to index in the batch
            seq_uid -> index in the batch (indexing used in sequence_lengths and block_tables)
        _sequence_lengths (torch.Tensor): Length of each sequence in the batch.
            The size of the tensor is (max_batch_size,)
        _block_tables (torch.Tensor): Block table of each sequence in the batch
            The size of the tensor is (max_batch_size, max_blocks_per_seq)
    """

    def __init__(
        self,
        num_heads,
        head_dim,
        max_batch_size,
        max_length,
        block_size,
        kv_max_split_num,
        fd_interm_tensor=None,
        device=None,
        dtype=torch.float16,
    ):
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_batch_size = max_batch_size
        self.max_length = max_length  # in + out len
        self.block_size = block_size
        self.kv_max_split_num = kv_max_split_num  # Hint used for flash decoding
        self.fd_interm_tensor = fd_interm_tensor
        self.device = device or get_current_device()
        self.dtype = dtype

        self._current_batch_size = 0
        self._sequences_dict = dict()
        self._sequences_indexes = dict()  # deque(maxlen=self.max_batch_size)
        self._sequence_lengths = torch.zeros((self.max_batch_size,), dtype=torch.int32)
        self._sequence_lengths_helper = torch.zeros_like(self._sequence_lengths)
        max_blocks_per_seq = (self.max_length + block_size - 1) // block_size
        self._block_tables = torch.full((self.max_batch_size, max_blocks_per_seq), -1, dtype=torch.int32)
        self._block_tables_helper = torch.full_like(self._block_tables, -1)

    @property
    def is_empty(self):
        return self._current_batch_size == 0

    @property
    def current_batch_size(self):
        return self._current_batch_size

    @property
    def available_batch_size(self):
        return self.max_batch_size - self._current_batch_size

    @property
    def block_tables(self):
        return self._block_tables

    @property
    def seq_lengths(self):
        return self._sequence_lengths

    @property
    def seqs_ids(self):
        return list(self._sequences_dict.keys())

    @property
    def seqs_li(self):
        return list(self._sequences_dict.values())

    @property
    def is_compact(self):
        assert len(self._sequences_dict) == len(self._sequences_indexes), "BatchBucket indexing is not consistent"
        return (
            len(self._sequences_dict)
            == torch.nonzero(self._sequence_lengths).view(-1).numel()
            == torch.nonzero(self._block_tables[:, 0] >= 0).numel()
        )

    def _make_compact(self) -> None:
        # Clean and compress the batch based on its sequences dict.
        # Namely, compress sequences to the front and clean the seq lengths and block tables tensors.
        # NOTE Prevent calling this method multiple times in a single step
        if self.is_compact:
            return
        valid_seq_ids = self._sequences_dict.keys()
        valid_num = len(valid_seq_ids)
        valid_indexes = [self._sequences_indexes[seq_id] for seq_id in valid_seq_ids]
        assert valid_num == len(self._sequences_indexes), "BatchBucket indexing is not consistent"
        self._sequence_lengths_helper[:valid_num] = self._sequence_lengths[valid_indexes]
        self._sequence_lengths[:] = self._sequence_lengths_helper[:]
        self._block_tables_helper[:valid_num, :] = self.block_tables[valid_indexes]
        self.block_tables[:] = self._block_tables_helper[:]
        new_idx = 0
        for seq_id in valid_seq_ids:
            self._sequences_indexes[seq_id] = new_idx
            new_idx += 1
        self._sequence_lengths_helper.fill_(0)
        self._block_tables_helper.fill_(-1)
        self._current_batch_size = valid_num

    def add_seq(
        self,
        seq: Sequence,
        alloc_block_table: torch.Tensor = None,
        alloc_block_table_fn: Callable[[torch.Tensor, int], None] = None,
    ) -> Union[torch.Tensor, None]:
        """Add a single sequence to the batch.
        User could opt to provide either a block table or a function to allocate block tables.

        Args:
            seq (Sequence): The sequence to be added to the batch
            alloc_block_table (torch.Tensor): The block table to be copied and used for the sequence
            alloc_block_table_fn (Callable[[torch.Tensor, int], None]): The function to allocate blocks for the sequence,
                which is expected to reserve blocks and update the status of the kv-cache manager.

        Returns:
            block_table (torch.Tensor): The block table of the added sequence, used for block allocation in the kv-cache manager.
                None if the sequence cannot be added.
        """
        block_table = None
        # TODO might consider sorting by length
        if self._current_batch_size < self.max_batch_size:
            self._sequences_dict[seq.request_id] = seq
            self._sequences_indexes[seq.request_id] = self._current_batch_size
            self._sequence_lengths[self._current_batch_size] = seq.sentence_len
            # NOTE the added seq still requires block table allocation by the kvcache manager
            block_table = self._block_tables[self._current_batch_size - 1]
            if alloc_block_table is not None:
                # copy block ids from the provided block table
                self._block_tables[self._current_batch_size - 1] = alloc_block_table
            elif alloc_block_table_fn:
                alloc_block_table_fn(block_table, self._sequence_lengths[self._current_batch_size - 1].item())
            self._current_batch_size += 1
        return block_table

    def add_seqs(
        self,
        seqs: List[Sequence],
        alloc_block_tables: torch.Tensor = None,
        alloc_block_tables_fn: Callable[[torch.Tensor, torch.Tensor], None] = None,
    ) -> Union[torch.Tensor, None]:
        """Add a list of sequences to the batch.
        User could opt to provide either block tables or a function to allocate block tables.

        Args:
            seqs (List[Sequence]): The sequences to be added to the batch
            alloc_block_tables (torch.Tensor): The block tables to be copied and used for the sequences
            alloc_block_tables_fn (Callable[[torch.Tensor, torch.Tensor], None]): The function to allocate blocks for multiple sequences,
                which is expected to reserve blocks and update the status of the kv-cache manager.

        Returns:
            block_tables (torch.Tensor): The block tables of the added sequences, used for block allocation in the kv-cache manager.
                None if the sequences cannot be added.
        """

        assert (
            alloc_block_tables is None or alloc_block_tables_fn is None
        ), "`alloc_block_tables` and `alloc_block_tables_fn` cannot be provided at the same time"

        num_seqs_to_add = min(self.max_batch_size - self._current_batch_size, len(seqs))
        block_tables = None
        if num_seqs_to_add > 0:
            for i, seq in enumerate(seqs[:num_seqs_to_add]):
                self._sequences_dict[seq.request_id] = seq
                self._sequences_indexes[seq.request_id] = self._current_batch_size + i
            # TODO external (rename): modify Sequence.sentence_len to seq_len
            self._sequence_lengths[
                self._current_batch_size : self._current_batch_size + num_seqs_to_add
            ] = torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32)
            # NOTE block tables to be updated by the kvcache manager
            block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add]
            if alloc_block_tables is not None:
                # copy block ids from the provided block tables
                self._block_tables[
                    self._current_batch_size : self._current_batch_size + num_seqs_to_add
                ] = alloc_block_tables
            elif alloc_block_tables_fn:
                alloc_block_tables_fn(
                    block_tables,
                    self._sequence_lengths[self._current_batch_size : self._current_batch_size + num_seqs_to_add],
                )

            self._current_batch_size += num_seqs_to_add
            seqs[:] = seqs[num_seqs_to_add:]

        return block_tables

    def pop_seq_update_batch(
        self, request_id: int, free_block_table_fn: Callable[[torch.Tensor], None] = None
    ) -> Tuple[Sequence, Union[torch.Tensor, None]]:
        """Pop a single sequence by id from the batch, and update the batch bucket status.

        Args:
            request_id (int): The uid of the sequence
            free_block_table_fn (Callable): The function to free the block table of a sequence,
                if not provided, then we have to release the block table manually after calling this method

        Returns:
            A tuple of: seq (Sequence): The target sequence
            and block_table (torch.Tensor): block table of the target sequence indicating corresponding blocks,
                None if the sequence is not found or free_block_table_fn is provided.
        """
        seq: Sequence = self._sequences_dict.get(request_id)
        block_table = None
        if seq is not None:
            assert request_id in self._sequences_indexes, "Inconsistency in BatchBucket indexing"
            self._sequences_dict.pop(request_id)
            seq_b_idx = self._sequences_indexes.get(request_id)

            if self.current_batch_size > 1:
                # replace seq length of the target seq with that of the last seq in the batch
                last_seq_b_idx = self.current_batch_size - 1
                last_seq_id = next(
                    (uid for uid, index in self._sequences_indexes.items() if index == last_seq_b_idx),
                    None,
                )
                assert last_seq_id is not None
                self._sequences_indexes[last_seq_id] = seq_b_idx
                self._sequence_lengths[seq_b_idx] = self._sequence_lengths[last_seq_b_idx]
                self._sequence_lengths[last_seq_b_idx].fill_(0)
                # free the block table of the seq, or return a copy of the block table (to be processed outside)
                if free_block_table_fn:
                    free_block_table_fn(self._block_tables[seq_b_idx])
                else:
                    block_table = self._block_tables[seq_b_idx].detach().clone()
                # replace block table of the target seq with that of the last seq in the batch
                self._block_tables[seq_b_idx] = self._block_tables[last_seq_b_idx]
                self._block_tables[last_seq_b_idx].fill_(-1)
            else:
                if free_block_table_fn:
                    free_block_table_fn(self._block_tables[0])
                else:
                    block_table = self._block_tables[0].detach().clone()
                self._sequence_lengths[0].fill_(0)
                self._block_tables[0].fill_(-1)
            self._sequences_indexes.pop(request_id)
            self._current_batch_size -= 1

        return seq, block_table

    def pop_seqs(
        self, request_ids: List[int], free_block_table_fn: Callable[[torch.Tensor], None] = None
    ) -> Tuple[List[Sequence], List[torch.Tensor]]:
        """Iteratively pop a list of sequences by uid.

        Args:
            request_ids (List[int]): The uids of the sequences
            free_block_table_fn (Callable): The function to free the block table of a sequence,
                if not provided, then we have to release the block tables manually after calling this method
        Returns:
            A tuple of: seqs (List[Sequence]): The target sequences
            and block_tables (List[torch.Tensor]): block tables of the target sequences indicating corresponding blocks
        """
        seqs = []
        block_tables = []
        for request_id in request_ids:
            seq, block_table = self.pop_seq_update_batch(request_id, free_block_table_fn)
            if seq is not None:
                seqs.append(seq)
            if block_table is not None:
                block_tables.append(block_table)
        return seqs, block_tables

    def pop_n_seqs(
        self, n: int, free_block_table_fn: Callable[[torch.Tensor], None] = None
    ) -> Tuple[List[Sequence], List[torch.Tensor]]:
        """Pop the first n sequences in the batch (FIFO).
        If n is greater than the current batch size, pop all the sequences in the batch.

        Args:
            n (int): The number of sequences to pop out
            free_block_table_fn (Callable): The function to free the block table of a single sequence
        Returns:
            A tuple of: seqs (List[Sequence]): The target sequences,
            and block_tables (List[torch.Tensor]): block tables of the target sequences indicating corresponding blocks
        """
        # NOTE Prevent calling this method multiple times in a single step
        seqs = []
        block_tables = []
        n = min(n, self.current_batch_size)
        seq_ids = list(self._sequences_dict.keys())[:n]
        for seq_id in seq_ids:
            seq = self._sequences_dict.pop(seq_id)
            seq_b_idx = self._sequences_indexes.pop(seq_id)
            if free_block_table_fn:
                free_block_table_fn(self.block_tables[seq_b_idx])
            else:
                block_tables.append(self.block_tables[seq_b_idx].detach().clone())
            seqs.append(seq)
        if not self.is_compact:
            self._make_compact()
        return seqs, block_tables

    def pop_finished(
        self, free_block_table_fn: Callable[[torch.Tensor], None] = None
    ) -> Tuple[List[Sequence], List[torch.Tensor]]:
        """Pop finished sequences in the batch and a list of block tables of the finished sequences,
        if free_block_table_fn is not provided.

        Args:
            free_block_table_fn (Callable): The function to free the block table of a single sequence
        Returns:
            A tuple of: finished_seqs (List[Sequence]): The finished sequences,
            and finished_block_tables (List[torch.Tensor]): block tables of the finished sequences.
        """
        finished_seqs = []
        finished_block_tables = []
        for seq in self._sequences_dict.values():
            if seq.check_finish():
                finished_seqs.append(seq)
        # Use `pop_seq_update_batch` to update the batch status for just a few finished seqs;
        # otherwise, pop seqs directly and then call `_make_compact` to compress the batch.
        # For now, the performance difference is not significant, so we use the first method to pop seqs.
        # Precise evaluations to be done.
        for seq in finished_seqs:
            _, block_table = self.pop_seq_update_batch(seq.request_id, free_block_table_fn)
            if block_table is not None:
                finished_block_tables.append(block_table)

        return finished_seqs, finished_block_tables

    # TODO arg type does not support beam search sampling yet
    def append_batch_tokens(self, tokens: torch.Tensor) -> None:
        """Append a batch of tokens to the sequences in the batch"""
        assert self.current_batch_size == tokens.size(0), "Batch size mismatch"

        if self.current_batch_size > 0:
            tokens = tokens.tolist()
            for seq_id, seq in self._sequences_dict.items():
                index_in_b = self._sequences_indexes[seq_id]
                curr_tokens = tokens[index_in_b]
                if not isinstance(curr_tokens, list):
                    curr_tokens = [curr_tokens]
                seq.output_token_id += curr_tokens
                seq.check_finish()
            self._sequence_lengths[: self.current_batch_size] += 1

    def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]:
        """Clear all the sequences in the batch.

        free_block_tables_fn (Optional[Callable]): The function to free the block tables of all the sequences in a batch
        """
        seqs = list(self._sequences_dict.values())
        self._sequences_dict.clear()
        self._sequences_indexes.clear()
        if free_block_tables_fn:
            free_block_tables_fn(self.block_tables, self._current_batch_size)
        self._block_tables.fill_(-1)
        self._sequence_lengths.fill_(0)
        self._current_batch_size = 0
        return seqs

    def merge(self, other: "BatchBucket") -> List[int]:
        """Merge the sequences in the other batch into the current batch.
        Merge as many sequences as the current batch can hold, in case it does not have
        enough space for all the sequences in the other batch.

        Usage:
            > New incoming sequence added to the prefill batch
                prefill bb curr batch size < prefill_ratio * prefill bb max batch size
            > New incoming sequence added to the prefill batch
                prefill bb curr batch size == prefill_ratio * prefill bb max batch size
            > Pause Decoding
            > Prefill
            > Move sequences in prefill bb => decoding bb
            > Put back the out-of-volume sequences into the running pool

        Returns:
            unmerged_ids (List[int]): a list of sequence uids that are not merged into the current batch
        """
        unmerged_ids = []
        num_seqs_to_merge = min(self.available_batch_size, other.current_batch_size)
        if num_seqs_to_merge > 0:
            seqs, block_tables_li = other.pop_n_seqs(num_seqs_to_merge)
            block_tables = torch.stack(block_tables_li)
            self.add_seqs(seqs, alloc_block_tables=block_tables)
            unmerged_ids = other.seqs_ids
        return unmerged_ids

    ########## The following methods are expected to be used in modeling ###########

    # For compatibility.
    # NOTE: This is an assumption-based way to determine the stage of the batch.
    @property
    def is_prompts(self) -> bool:
        assert len(self._sequences_dict) > 0, "No sequence in the batch"
        first_seq = next(iter(self._sequences_dict.values()))
        if first_seq.output_len == 0:
            return True
        return False

    # For compatibility
    def get_1D_inputs(self) -> torch.Tensor:
        assert len(self._sequences_dict) > 0, "No sequence in the batch"
        first_seq = next(iter(self._sequences_dict.values()))  # not exactly the first sequence
        if first_seq.output_len == 0:
            # Assume prefill stage
            assert all(
                seq.output_len == 0 for seq in self._sequences_dict.values()
            ), "Sequence stage (Prefill/Decoding) must be the same in the batch"
            out_li = []
            num_tokens = torch.sum(self._sequence_lengths)
            out = torch.empty([num_tokens], dtype=torch.long)
            seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
            for seq_id in seq_ids:
                seq: Sequence = self._sequences_dict[seq_id]
                out_li.extend(seq.input_token_id)
            return torch.tensor(out_li, dtype=torch.long, device=self.device)
        else:
            # Assume decoding stage
            assert all(
                seq.output_len > 0 for seq in self._sequences_dict.values()
            ), "Sequence stage (Prefill/Decoding) must be the same in the batch"
            assert self.is_compact, "BatchBucket is not compact"
            out = torch.empty([self.current_batch_size], dtype=torch.long)
            for seq_id, index_in_b in self._sequences_indexes.items():
                seq: Sequence = self._sequences_dict[seq_id]
                out[index_in_b] = seq.output_token_id[-1]
            return out.to(device=self.device)

    # For compatibility
    def get_block_table_tensor(self) -> torch.Tensor:
        assert self.is_compact  # Debug usage
        block_table = self.block_tables[: self.current_batch_size]
        return block_table.to(device=self.device)

    # For compatibility
    def get_sequence_lengths(self) -> torch.Tensor:
        assert self.is_compact  # Debug usage
        sequence_lengths = self.seq_lengths[: self.current_batch_size]
        return sequence_lengths.to(device=self.device)

    # For compatibility
    @property
    def fd_inter_tensor(self) -> None:
        assert self.fd_interm_tensor is not None, "fd_interm_tensor is not provided"
        return self.fd_interm_tensor
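The BatchBucket above keeps per-slot sequence lengths and block tables in fixed-size tensors, swaps a removed slot with the last occupied one to stay compact, and leaves block allocation to the kv-cache manager. Below is a minimal usage sketch, not part of the commit; the literal sizes and token ids are placeholders for illustration, while the constructor arguments and Sequence fields follow the unit test added later in this commit.

import torch

from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.struct import Sequence

# Hypothetical sizes, chosen only for the example.
num_heads, head_dim, max_batch_size, max_length, block_size = 4, 32, 4, 40, 4

bb = BatchBucket(num_heads, head_dim, max_batch_size, max_length, block_size, kv_max_split_num=2)

seqs = [
    Sequence(
        request_id=i,
        prompt="",  # dummy prompt, as in the unit test below
        input_token_id=list(range(16 + i)),
        block_size=block_size,
        sample_params=None,
        eos_token_id=2,
        pad_token_id=2,
        max_output_len=8,
    )
    for i in range(2)
]

# Block tables come back uninitialized (-1); a KVCacheManager is expected to fill them.
block_tables = bb.add_seqs(seqs)

# After a model forward, append one sampled token per sequence and drop finished ones.
bb.append_batch_tokens(torch.tensor([5, 7]))
finished_seqs, finished_block_tables = bb.pop_finished()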
@@ -109,7 +109,7 @@ class InferenceConfig:
         ), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}"

         # check distributed
-        assert (
+        assert (not torch.distributed.is_initialized() and self.tp_size * self.pp_size == 1) or (
             self.tp_size * self.pp_size == dist.get_world_size()
         ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
         # check prompt template
@@ -42,7 +42,7 @@ class InferenceEngine:
     def __init__(
         self,
         model: nn.Module,
-        tokenizer: [Union[PreTrainedTokenizer, PreTrainedTokenizerFast]],
+        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
         inference_config: InferenceConfig,
         verbose: bool = False,
         model_policy: Policy = None,
@@ -254,20 +254,12 @@ class InferenceEngine:
             else:
                 prompt = prompts[i]

-            max_blocks_per_sequence = (
-                self.inference_config.max_input_len
-                + self.inference_config.max_output_len
-                + self.inference_config.block_size
-                - 1
-            ) // self.inference_config.block_size
-            block_table = torch.full([max_blocks_per_sequence], -1, device=self.device)
             sequence = Sequence(
                 request_id,
                 prompt,
                 prompts_token_ids[i],
                 block_size,
                 None,
-                block_table,
                 self.tokenizer.eos_token_id,
                 self.tokenizer.pad_token_id,
                 self.inference_config.max_output_len,
@@ -1,15 +1,16 @@
-from typing import List
+from typing import Dict, List, Union

 import torch
 from transformers.configuration_utils import PretrainedConfig
 from transformers.generation import GenerationConfig

+from colossalai.inference.batch_bucket import BatchBucket
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.kv_cache import KVCacheManager
 from colossalai.inference.logit_processors import logit_processor
 from colossalai.inference.sampler import *
-from colossalai.inference.struct import BatchInfo, RequestStatus, Sequence
+from colossalai.inference.struct import RequestStatus, Sequence
 from colossalai.logging import get_dist_logger

 __all__ = ["RunningList", "RequestHandler"]
@@ -24,45 +25,79 @@ class RunningList:

     Args:
         prefill_ratio: (float) A ratio for determining whether to perform prefill or not.
-        prefill: (List) List that contains default inputs, defaults to [].
+        _prefill (OrderedDict[Sequence]): Mapping of sequence uid -> Sequence.
+        _decoding (OrderedDict[Sequence]): Mapping of sequence uid -> Sequence.
     """

-    def __init__(self, prefill_ratio: str, prefill: List[Sequence] = None):
+    def __init__(self, prefill_ratio: int, prefill: List[Sequence] = None) -> None:
         self.prefill_ratio = prefill_ratio
-        self.decoding: List[Sequence] = []
-        self.prefill: List[Sequence] = prefill if prefill is not None else []
+        self._decoding: Dict[int, Sequence] = dict()
+        self._prefill: Dict[int, Sequence] = (
+            dict({seq.request_id: seq for seq in self._prefill}) if prefill is not None else dict()
+        )
+
+    @property
+    def decoding(self):
+        return list(self._decoding.values())
+
+    @property
+    def prefill(self):
+        return list(self._prefill.values())
+
+    @property
+    def prefill_seq_num(self):
+        return len(self._prefill)
+
+    @property
+    def decoding_seq_num(self):
+        return len(self._decoding)
+
+    @property
+    def total_seq_num(self):
+        return self.prefill_seq_num + self.decoding_seq_num

     def append(self, seq: Sequence):
-        # add seq to prefilling list first.
-        self.prefill.append(seq)
+        assert (seq.request_id not in self._prefill) and (
+            seq.request_id not in self._decoding
+        ), f"Sequence uid {seq.request_id} already exists."
+        self._prefill[seq.request_id] = seq

-    def find_seq(self, request_id):
-        for seq in self.decoding:
-            if request_id == seq.request_id:
-                return seq
-        for seq in self.prefill:
-            if request_id == seq.request_id:
-                return seq
-        return None
+    def extend(self, seqs: List[Sequence]):
+        for seq in seqs:
+            self._prefill[seq.request_id] = seq

-    def remove(self, seq: Sequence):
-        if seq in self.decoding:
-            self.decoding.remove(seq)
-        elif seq in self.prefill:
-            self.prefill.remove(seq)
+    def find_seq(self, request_id) -> Union[Sequence, None]:
+        seq = None
+        if request_id in self._decoding:
+            seq = self._decoding[request_id]
+        elif request_id in self._prefill:
+            seq = self._prefill[request_id]
+        return seq
+
+    def remove(self, seq: Sequence) -> None:
+        if seq.request_id in self._decoding:
+            self._decoding.pop(seq.request_id)
+        elif seq.request_id in self._prefill:
+            self._prefill.pop(seq.request_id)
         else:
-            raise ValueError(f"sequence {seq.request_id} is not in running list")
+            raise ValueError(f"Sequence {seq.request_id} is not in running list")

     def ready_for_prefill(self):
-        if not self.decoding:
-            return len(self.prefill) > 0
-        return len(self.prefill) / len(self.decoding) >= self.prefill_ratio
+        if not self._decoding:
+            return len(self._prefill) > 0
+        return len(self._prefill) / len(self._decoding) >= self.prefill_ratio

     def is_empty(self):
-        return not self.decoding and not self.prefill
+        return not self._decoding and not self._prefill

-    def total_seq_num(self):
-        return len(self.decoding) + len(self.prefill)
+    def mark_prefill_running(self) -> None:
+        for seq_id in self._prefill:
+            self._prefill[seq_id].mark_running()
+
+    def move_prefill_to_decoding(self, seq_ids: List[int]) -> None:
+        for seq_id in seq_ids:
+            assert seq_id in self._prefill, f"Sequence {seq_id} is not in prefill list"
+            self._decoding[seq_id] = self._prefill.pop(seq_id)


 class RequestHandler:
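The RunningList rewrite above replaces the two Python lists with uid-keyed dicts, so membership checks and removals are constant-time and sequences can be moved between stages by uid. A minimal sketch of the prefill-to-decoding flow follows; the import path and the stub sequence are assumptions for illustration (RunningList only touches request_id and mark_running here).

from dataclasses import dataclass

from colossalai.inference.core.request_handler import RunningList  # module path assumed


@dataclass
class _StubSeq:
    request_id: int
    running: bool = False

    def mark_running(self):
        self.running = True


running_list = RunningList(prefill_ratio=1.2)
for rid in range(3):
    running_list.append(_StubSeq(rid))  # new requests land in the prefill dict first

if running_list.ready_for_prefill():  # no decoding seqs yet, so prefill is ready
    running_list.mark_prefill_running()
    prefill_ids = [seq.request_id for seq in running_list.prefill]
    running_list.move_prefill_to_decoding(prefill_ids)

assert running_list.decoding_seq_num == 3 and running_list.prefill_seq_num == 0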
@@ -110,25 +145,27 @@ class RequestHandler:

         # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
         # which may cause bugs and this issue should be fixed later.
-        self.running_batch = BatchInfo(
-            max_batch_size=self.max_batch_size,
-            kv_max_split_num=kv_max_split_num,
+        self.running_bb = BatchBucket(
             num_heads=model_config.num_attention_heads,
             head_dim=head_dim,
-            is_prompts=False,
-            device=device,
+            max_batch_size=self.max_batch_size,
+            max_length=inference_config.max_input_len + inference_config.max_output_len,
+            block_size=inference_config.block_size,
+            kv_max_split_num=kv_max_split_num,
+            fd_interm_tensor=fd_inter_tensor,
             dtype=self.dtype,
-            fd_inter_tensor=fd_inter_tensor,
+            device=device,
         )
-        self.prefill_batch = BatchInfo(
-            max_batch_size=self.max_batch_size,
-            kv_max_split_num=kv_max_split_num,
+        self.prefill_bb = BatchBucket(
             num_heads=model_config.num_attention_heads,
             head_dim=head_dim,
-            is_prompts=True,
-            device=device,
+            max_batch_size=self.max_batch_size,
+            max_length=inference_config.max_input_len + inference_config.max_output_len,
+            block_size=inference_config.block_size,
+            kv_max_split_num=kv_max_split_num,
+            fd_interm_tensor=fd_inter_tensor,
             dtype=self.dtype,
-            fd_inter_tensor=fd_inter_tensor,
+            device=device,
         )

     def _init_cache(self, model_config):
@@ -159,40 +196,39 @@ class RequestHandler:
                     remove_list.append(seq)
                     break

-            # stop feeding new sequence into running list to assure
-            if self.cache_manager.num_available_blocks <= self.running_list.total_seq_num():
-                break
+            num_seqs_to_add = min(len(lst), self.max_batch_size - self.running_list.total_seq_num)
+            remove_list.extend(lst[:num_seqs_to_add])
+            self.running_list.extend(lst[:num_seqs_to_add])

-            # Try to allocate cache blocks for the sequence.
-            if (
-                self.cache_manager.check_allocation(seq)
-                and (len(self.running_list.prefill) + len(self.running_list.decoding))
-                < self.max_batch_size  # There some bugs in continous batching, so we disable it here.
-            ):
-                # If succeed, add the sequence to running list.
-                remove_list.append(seq)
-                self.running_list.append(seq)
-                self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.sentence_len)
             for seq in remove_list:
                 lst.remove(seq)

         if self.running_list.ready_for_prefill():
-            for seq in self.running_list.prefill:
-                seq.mark_running()
-            self.prefill_batch.add_seqs(self.running_list.prefill)
-            return self.prefill_batch
+            num_seqs_to_add = min(self.running_list.prefill_seq_num, self.running_bb.available_batch_size)
+
+            for seq in self.running_list.prefill[:num_seqs_to_add]:
+                seq.mark_running()
+            # allocate blocks for the prefill batch
+            self.prefill_bb.add_seqs(
+                self.running_list.prefill[:num_seqs_to_add],
+                alloc_block_tables_fn=self.cache_manager.allocate_context_from_block_tables,
+            )
+
+            return self.prefill_bb

-        if not self.running_batch.is_empty:
-            for seq in self.running_batch.sequences_set:
-                recycle = self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
-                if recycle:
+        if not self.running_bb.is_empty:
+            seqs_ids_to_recycle = self.cache_manager.allocate_tokens_from_block_tables(
+                self.running_bb.block_tables, self.running_bb.seq_lengths, self.running_bb.current_batch_size
+            )
+            if seqs_ids_to_recycle:
+                seqs_to_recycle = self.running_bb.pop_seqs(seqs_ids_to_recycle)
+                for seq in seqs_to_recycle:
                     seq.recycle()
-                    self.running_batch.del_seq(seq)
                     self.running_list.remove(seq)
                     self.waiting_list[-1].append(seq)
                     # the recycled sequences are handled with highest priority.

-        return self.running_batch
+        return self.running_bb

     def add_sequence(self, req: Sequence):
         """
@@ -213,7 +249,7 @@ class RequestHandler:
             seq.mark_aborted()
             self.waiting_list[priority].remove(seq)
         elif seq.status.is_running():
-            self.cache_manager.free_block_table(seq.block_table)
+            self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table)
             self.running_list.remove(seq)
         else:
             try:
@@ -242,7 +278,7 @@ class RequestHandler:
             else:
                 sample_tokens = greedy_sample(generation_config, logprobs)
         else:
-            sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_batch.is_empty)
+            sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_bb.is_empty)

         return sample_tokens

@@ -273,27 +309,25 @@ class RequestHandler:

         # sample the next tokens
         sample_tokens = self._sample(probs, logprobs, generation_config)
-        if not self.prefill_batch.is_empty:
-            self.prefill_batch.update_batch_tokens(sample_tokens)
+        if not self.prefill_bb.is_empty:
+            self.prefill_bb.append_batch_tokens(sample_tokens)
         else:
-            self.running_batch.update_batch_tokens(sample_tokens)
+            self.running_bb.append_batch_tokens(sample_tokens)

     def update(self):
         """
         Update current running list and done list
         """
-        if not self.prefill_batch.is_empty:
-            self.running_list.decoding.extend(self.running_list.prefill)
-            self.running_batch.add_seqs(self.running_list.prefill)
-            self.running_list.prefill.clear()
-            self.prefill_batch.clear_batch()
+        if not self.prefill_bb.is_empty:
+            self.running_list.move_prefill_to_decoding(self.prefill_bb.seqs_ids)
+            self.running_bb.merge(self.prefill_bb)
+            # clear the prefill batch without assigning a free_block_tables_fn
+            # since we want to reuse the memory recorded on the block tables
+            self.prefill_bb.clear(free_block_tables_fn=None)

-        finish_seqs = self.running_batch.fliter_batch()
-
-        for seq in finish_seqs:
+        finished_seqs, _ = self.running_bb.pop_finished(self.cache_manager.free_block_table)
+        for seq in finished_seqs:
             self.running_list.remove(seq)
-            self.cache_manager.free_block_table(seq.block_table)
-
-        self.done_list.extend(finish_seqs)
+        self.done_list.extend(finished_seqs)

-        return finish_seqs
+        return finished_seqs
@@ -63,7 +63,6 @@ class KVCacheManager:
         self.dtype = config.dtype
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
         self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
-        # For now we focus on MHA only, TODO add handling for MQA and GQA
         self.head_num = get_model_config_attr(model_config, "num_attention_heads")
         self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num
         assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
@@ -82,8 +81,8 @@ class KVCacheManager:

         # Physical cache allocation
         alloc_shape = (self.num_blocks, self.head_num, self.block_size, self.head_size)
-        if verbose:
-            self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
+        # if verbose:
+        #     self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
         self._kv_caches = self._init_device_caches(alloc_shape)
         self.total_physical_cache_size_in_bytes = (
             self.elem_size_in_bytes
@@ -112,6 +111,9 @@ class KVCacheManager:
         """Get the number of available cache blocks."""
         return self._available_blocks

+    def get_head_size(self):
+        return self.head_size
+
     def get_kv_cache(self):
         """Get k_cache and v_cache"""
         return self._kv_caches
@@ -148,7 +150,7 @@ class KVCacheManager:
         and updates the provided block table with the allocated block ids.

         Args:
-            block_table: A 1D tensor of shape [max_blocks_per_sequence] held by a sequence, storing mapping of token_position_id -> block_id.
+            block_table: A 1D tensor of shape [max_blocks_per_sequence], mapping of token_position_id -> block_id.
             context_len: The length of the processing sequence.
         """
         assert block_table.dim() == 1
@@ -193,12 +195,85 @@ class KVCacheManager:
             else:
                 self._allocate_on_block(block, block.block_size)

+    def allocate_context_from_block_tables(self, block_tables: torch.Tensor, context_lengths: torch.Tensor) -> None:
+        """Allocate logical cache blocks for a batch of sequences during prefill stage.
+
+        Args:
+            block_tables (torch.Tensor): [bsz, max_blocks_per_sequence]
+            context_lengths (torch.Tensor): [bsz]
+        """
+        assert block_tables.dim() == 2
+        assert block_tables.size(0) == context_lengths.size(0)
+        if not torch.all(block_tables < 0):
+            self.logger.error("Some slots on provided block table have been allocated.")
+        blocks_required = (context_lengths + self.block_size - 1) // self.block_size
+        num_blocks_required = torch.sum(blocks_required).item()
+        assert isinstance(num_blocks_required, int)
+        if num_blocks_required > self._available_blocks:
+            self.logger.warning(
+                f"Lacking blocks to allocate. Available blocks {self._available_blocks}; blocks asked {num_blocks_required}."
+            )
+            return
+
+        bsz = block_tables.size(0)
+        # Try contiguous allocation
+        torch.cumsum(self._block_states, dim=-1, out=self._block_states_cum[1:])
+        torch.subtract(
+            self._block_states_cum[num_blocks_required:],
+            self._block_states_cum[:-num_blocks_required],
+            out=self._block_finder[num_blocks_required - 1 :],
+        )
+        end_indexes = torch.nonzero(self._block_finder == num_blocks_required, as_tuple=False).view(-1)
+        if end_indexes.numel() > 0:
+            # contiguous cache exists
+            end_idx = end_indexes[0].item() + 1  # open interval
+            start_idx = end_idx - num_blocks_required  # closed interval
+            alloc_block_ids = torch.arange(start_idx, end_idx)
+            for i in range(bsz):
+                curr_required = blocks_required[i]
+                block_tables[i, :curr_required] = torch.arange(
+                    start_idx, start_idx + curr_required, device=block_tables.device
+                )
+                start_idx += curr_required
+        else:
+            # non-contiguous cache
+            available_block_ids = torch.nonzero(self._block_states > 0).view(-1)
+            alloc_block_ids = available_block_ids[:num_blocks_required]
+            alloc_block_ids = alloc_block_ids.to(dtype=block_tables.dtype, device=block_tables.device)
+            start_idx = 0
+            for i in range(bsz):
+                curr_required = blocks_required[i]
+                block_tables[i, :curr_required] = alloc_block_ids[start_idx : start_idx + curr_required]
+                start_idx += curr_required
+
+        # Update cache blocks
+        self._block_states[alloc_block_ids] = 0
+        self._available_blocks -= num_blocks_required
+        last_block_locs = torch.cumsum(blocks_required, dim=0) - 1
+        last_block_locs = last_block_locs.to(device=alloc_block_ids.device)
+
+        for i, block_id in enumerate(alloc_block_ids[last_block_locs]):
+            block: CacheBlock = self._cache_blocks[block_id]
+            block.add_ref()
+            self._allocate_on_block(
+                block,
+                block.block_size
+                if context_lengths[i] % block.block_size == 0
+                else context_lengths[i].item() % block.block_size,
+            )
+        for block_id in alloc_block_ids:
+            if block_id in alloc_block_ids[last_block_locs]:
+                continue
+            block: CacheBlock = self._cache_blocks[block_id]
+            block.add_ref()
+            self._allocate_on_block(block, block.block_size)
+
     def allocate_token_from_block_table(self, block_table: torch.Tensor, context_len: int) -> None:
         """Allocate the logical cache block for a single sequence during decoding stage,
         and updates the provided block table if a new cache block is needed.

         Args:
-            block_table: A 1D tensor of shape [max_blocks_per_sequence] held by a sequence, storing mapping of token_position_id -> block_id.
+            block_table: A 1D tensor of shape [max_blocks_per_sequence], mapping of token_position_id -> block_id.
             context_len: The length of the processing sequence (already-allocated length).
         """
         assert block_table.dim() == 1
@@ -207,12 +282,79 @@ class KVCacheManager:
         alloc_local_block_idx = context_len // self.block_size
         return self.allocate_single_block(block_table, alloc_local_block_idx)

+    def allocate_tokens_from_block_tables(
+        self, block_tables: torch.Tensor, context_lens: torch.Tensor, bsz: int = None
+    ) -> List[int]:
+        """Allocate logical cache blocks for a batch of sequences during decoding stage.
+
+        Usage:
+            allocate_context_from_block_tables
+            model forward (block tables & context lengths passed)
+            update context lengths
+            allocate_tokens_from_block_tables
+            model forward
+            update context lengths
+            allocate_tokens_from_block_tables
+            model forward
+            update context lengths
+            ...
+
+        Args:
+            block_tables (torch.Tensor): [bsz, max_blocks_per_sequence]
+            context_lens (torch.Tensor): [bsz]
+
+        Returns:
+            List[int]: list of sequence uids to be recycled
+        """
+        assert block_tables.dim() == 2
+        assert context_lens.dim() == 1
+
+        bsz = block_tables.size(0) if bsz is None else bsz
+
+        alloc_local_block_indexes = (context_lens[:bsz]) // self.block_size
+        block_global_ids = block_tables[torch.arange(0, bsz), alloc_local_block_indexes]
+        seqs_to_recycle = []
+        new_blocks_required = torch.sum(block_global_ids < 0).item()
+        seqs_req_new_blocks = torch.nonzero(block_global_ids < 0).squeeze()
+
+        if new_blocks_required > 0:
+            if new_blocks_required > self._available_blocks:
+                # TODO might want to revise the logic here
+                # Process the first (_available_blocks) sequences that require new blocks
+                # Put the rest of the sequences back to recycled
+                seqs_req_new_blocks, seqs_to_recycle = (
+                    seqs_req_new_blocks[: self._available_blocks],
+                    seqs_req_new_blocks[self._available_blocks :],
+                )
+                for seq_id in seqs_to_recycle:
+                    self.free_block_table(block_tables[seq_id])
+                new_blocks_required = self._available_blocks
+
+            # NOTE might want to add contiguous allocation logic here
+            free_block_ids = torch.nonzero(self._block_states > 0).view(-1)
+            alloc_block_ids = free_block_ids[:new_blocks_required].to(
+                dtype=block_tables.dtype, device=block_tables.device
+            )
+
+            for block_id in alloc_block_ids:
+                block: CacheBlock = self._cache_blocks[block_id]
+                block.add_ref()
+                self._block_states[block_id] = 0
+                self._available_blocks -= 1
+            block_tables[seqs_req_new_blocks, alloc_local_block_indexes[seqs_req_new_blocks]] = alloc_block_ids
+            block_global_ids = block_tables[torch.arange(0, bsz), alloc_local_block_indexes]
+
+        for block_id in block_global_ids:
+            self._allocate_on_block(self._cache_blocks[block_id], 1)
+
+        return seqs_to_recycle
+
     def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int) -> int:
         """Allocate space asked on a single block in the block table, specified by the provided position id,
         and updates the provided block table with the allocated block.

         Args:
-            block_table: A 1D tensor of shape [max_blocks_per_sequence] held by a sequence, storing mapping of token_position_id -> block_id.
+            block_table: A 1D tensor of shape [max_blocks_per_sequence], mapping of token_position_id -> block_id.
             block_local_idx: The index of the block in the block table.
             space_asked: i.e. The number of tokens to be assigned space for.
         Returns:
|
|||||||
def free_block_table(self, block_table: torch.Tensor) -> None:
|
def free_block_table(self, block_table: torch.Tensor) -> None:
|
||||||
"""Free the logical cache blocks for **a single sequence**."""
|
"""Free the logical cache blocks for **a single sequence**."""
|
||||||
assert block_table.dim() == 1
|
assert block_table.dim() == 1
|
||||||
for i in range(block_table.numel()):
|
for i, global_block_id in enumerate(block_table.tolist()):
|
||||||
global_block_id = block_table[i].item()
|
|
||||||
if global_block_id < 0:
|
if global_block_id < 0:
|
||||||
return
|
return
|
||||||
block: CacheBlock = self._cache_blocks[global_block_id]
|
block: CacheBlock = self._cache_blocks[global_block_id]
|
||||||
@@ -253,6 +394,15 @@ class KVCacheManager:
             # reset the block id in the block table (if we maintain a 2D tensors as block tables in Engine)
             block_table[i] = -1

+    def free_block_tables(self, block_tables: torch.Tensor, first_n: int = None) -> None:
+        """Release the logical cache blocks for a batch of sequences.
+        If `first_n` is provided, only the blocks for the first several sequences will be released.
+        """
+        assert block_tables.dim() == 2
+        first_n = block_tables.size(0) if first_n is None else first_n
+        for block_table in block_tables[:first_n]:
+            self.free_block_table(block_table)
+
     def clear_all(self) -> None:
         """Clear all the references and allocations on all the cache blocks."""
         for block in self._cache_blocks:
@@ -12,8 +12,8 @@ from transformers.models.llama.modeling_llama import (
     LlamaModel,
 )

+from colossalai.inference.batch_bucket import BatchBucket
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
-from colossalai.inference.struct import BatchInfo
 from colossalai.kernel.triton import (
     context_attention_unpadded,
     copy_kv_to_blocked_cache,
@@ -34,7 +34,7 @@ except ImportError:

 def llama_causal_lm_forward(
     self: LlamaForCausalLM,
-    batch: BatchInfo = None,
+    batch: BatchBucket = None,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
 ):
@@ -59,7 +59,7 @@ def llama_causal_lm_forward(

 def llama_model_forward(
     self: LlamaModel,
-    batch: BatchInfo = None,
+    batch: BatchBucket = None,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
 ):
@@ -73,7 +73,7 @@ def llama_model_forward(
     input_ids = batch.get_1D_inputs()
     block_tables = batch.get_block_table_tensor()
     sequence_lengths = batch.get_sequence_lengths()
-    batch_size = len(sequence_lengths)
+    batch_size = batch.current_batch_size
     kv_seq_len = sequence_lengths.max().item()

     hidden_states = self.embed_tokens(input_ids)
@@ -71,7 +71,6 @@ class Sequence:
     input_token_id: List[int]
     block_size: int
     sample_params: Any  # SampleParams needs to be imported later.
-    block_table: torch.Tensor
     eos_token_id: int
     pad_token_id: int
     max_output_len: int = 256
@@ -158,7 +157,6 @@ class Sequence:
             f"prompt={self.prompt}, "
             f"status={self.status.name}, "
             f"sample_params={self.sample_params}, "
-            f"logical_block_number={self.block_table.shape[0]},"
             f"input_len={self.input_len}),"
             f"output_len={self.output_len})"
         )
tests/test_infer/test_batch_bucket.py (new file, 140 lines)

@@ -0,0 +1,140 @@
import torch
from transformers.models.llama import LlamaConfig

from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.struct import Sequence
from colossalai.testing import parameterize


@parameterize(
    "test_config",
    [
        {
            "hidden_size": 128,
            "num_attention_heads": 4,
            "num_layers": 2,
            "block_size": 4,
            "max_batch_size": 4,
            "max_input_len": 32,
            "max_output_len": 8,
            "dtype": torch.float16,
            "tp_size": 1,
        }
    ],
)
def test_bucket(test_config):
    hidden_size = test_config.pop("hidden_size")
    num_heads = test_config.pop("num_attention_heads")
    num_layers = test_config.pop("num_layers")
    model_config = LlamaConfig(
        hidden_size=hidden_size,
        num_hidden_layers=num_layers,
        num_attention_heads=num_heads,
    )
    inference_config = InferenceConfig(**test_config)

    # Just for testing usage. Don't create multiple cache_manager on the same device.
    cache_manager = KVCacheManager(inference_config, model_config)
    cache_manager_copy = KVCacheManager(inference_config, model_config)

    seq_lens = [19, 20, 27]
    seq1 = Sequence(
        request_id=0,
        prompt="",  # Dummy for testing usage
        input_token_id=list(range(seq_lens[0])),
        block_size=4,
        sample_params=None,
        eos_token_id=2,
        pad_token_id=2,
        max_output_len=10,
    )
    seq2 = Sequence(
        request_id=1,
        prompt="",  # Dummy for testing usage
        input_token_id=list(range(seq_lens[1])),
        block_size=4,
        sample_params=None,
        eos_token_id=2,
        pad_token_id=2,
        max_output_len=10,
    )
    seq3 = Sequence(
        request_id=2,
        prompt="",  # Dummy for testing usage
        input_token_id=list(range(seq_lens[2])),
        block_size=4,
        sample_params=None,
        eos_token_id=2,
        pad_token_id=2,
        max_output_len=10,
    )

    block_size = test_config["block_size"]
    max_batch_size = test_config["max_batch_size"]
    max_length = test_config["max_input_len"] + test_config["max_output_len"]
    assert max_batch_size >= 2, "max_batch_size should be greater than 1"

    bb = BatchBucket(
        num_heads, cache_manager.get_head_size(), max_batch_size, max_length, block_size, kv_max_split_num=2
    )
    bb_copy = BatchBucket(
        num_heads, cache_manager.get_head_size(), max_batch_size, max_length, block_size, kv_max_split_num=2
    )
    block_tables = bb.add_seqs([seq1, seq2])
    assert block_tables.shape == (2, cache_manager.max_blocks_per_sequence)
    assert torch.all(block_tables < 0), "Initialized block_tables should be negative values"

    cache_manager.allocate_context_from_block_tables(block_tables, bb.seq_lengths[: bb.current_batch_size])
    bb_copy.add_seqs(
        [seq1, seq2], alloc_block_tables_fn=cache_manager_copy.allocate_context_from_block_tables
|
||||||
|
) # This is just for testing usage. Don't add the same sequence to different buckets.
|
||||||
|
|
||||||
|
assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * (
|
||||||
|
max_batch_size - bb.current_batch_size
|
||||||
|
)
|
||||||
|
assert torch.equal(bb.block_tables, bb_copy.block_tables)
|
||||||
|
|
||||||
|
bb.append_batch_tokens(torch.tensor([99, 99]))
|
||||||
|
assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * (
|
||||||
|
max_batch_size - bb.current_batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
cache_manager.allocate_tokens_from_block_tables(bb.block_tables, bb.seq_lengths, bsz=bb.current_batch_size)
|
||||||
|
assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * (
|
||||||
|
max_batch_size - bb.current_batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
bb.append_batch_tokens(torch.tensor([99, 99]))
|
||||||
|
|
||||||
|
cache_manager.allocate_tokens_from_block_tables(bb.block_tables, bb.seq_lengths, bsz=bb.current_batch_size)
|
||||||
|
assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * (
|
||||||
|
max_batch_size - bb.current_batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
bb.pop_seq_update_batch(0, free_block_table_fn=cache_manager.free_block_table)
|
||||||
|
assert bb.seq_lengths.tolist() == [bb.seqs_li[0].sentence_len] + [0] * (max_batch_size - bb.current_batch_size)
|
||||||
|
assert bb.is_compact
|
||||||
|
|
||||||
|
bb2 = BatchBucket(
|
||||||
|
num_heads, cache_manager.get_head_size(), max_batch_size, max_length, block_size, kv_max_split_num=2
|
||||||
|
)
|
||||||
|
block_tables = bb2.add_seqs([seq3])
|
||||||
|
cache_manager.allocate_context_from_block_tables(block_tables, bb2.seq_lengths[: bb2.current_batch_size])
|
||||||
|
unmerged_ids = bb.merge(bb2)
|
||||||
|
assert not unmerged_ids
|
||||||
|
assert bb.is_compact
|
||||||
|
assert bb2.is_compact
|
||||||
|
assert bb.current_batch_size == 2
|
||||||
|
assert bb2.current_batch_size == 0
|
||||||
|
|
||||||
|
bb.clear(cache_manager.free_block_tables)
|
||||||
|
assert bb.current_batch_size == 0
|
||||||
|
assert bb.is_compact
|
||||||
|
assert bb.seq_lengths.tolist() == [0] * max_batch_size
|
||||||
|
assert torch.all(bb.block_tables < 0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_bucket()
|
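For readers skimming the diff, here is a minimal, self-contained sketch of the length/batch-size bookkeeping the test above exercises. MiniBucket is a hypothetical stand-in written only for illustration; the real BatchBucket additionally manages block tables, KV-cache allocation callbacks, and the flash-decoding intermediates.

import torch


class MiniBucket:
    """Hypothetical, simplified stand-in for BatchBucket (illustration only)."""

    def __init__(self, max_batch_size: int):
        self.max_batch_size = max_batch_size
        # Fixed-size length tensor; unused slots stay at 0.
        self.seq_lengths = torch.zeros(max_batch_size, dtype=torch.int32)
        self.current_batch_size = 0

    def add_seqs(self, lengths):
        # Fill the next free slots with the prompt lengths of the new sequences.
        for length in lengths:
            self.seq_lengths[self.current_batch_size] = length
            self.current_batch_size += 1

    def append_batch_tokens(self):
        # One newly generated token per live sequence.
        self.seq_lengths[: self.current_batch_size] += 1

    def pop_seq(self, index: int):
        # Keep the live prefix compact by moving the last live entry into the hole.
        last = self.current_batch_size - 1
        self.seq_lengths[index] = self.seq_lengths[last]
        self.seq_lengths[last] = 0
        self.current_batch_size -= 1


bucket = MiniBucket(max_batch_size=4)
bucket.add_seqs([19, 20])
bucket.append_batch_tokens()
assert bucket.seq_lengths.tolist() == [20, 21, 0, 0]
bucket.pop_seq(0)
assert bucket.current_batch_size == 1
assert bucket.seq_lengths.tolist() == [21, 0, 0, 0]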
@@ -15,7 +15,6 @@ def check_config_and_inference():
         input_token_id=[1, 2, 3],
         block_size=16,
         sample_params=None,
-        block_table=None,
         eos_token_id=2,
         pad_token_id=2,
         max_output_len=256,
@@ -27,7 +26,6 @@ def check_config_and_inference():
         input_token_id=[4, 5, 6],
         block_size=16,
         sample_params=None,
-        block_table=None,
         eos_token_id=2,
         pad_token_id=2,
         max_output_len=256,
@@ -39,7 +37,6 @@ def check_config_and_inference():
         input_token_id=[7, 8, 9],
         block_size=16,
         sample_params=None,
-        block_table=None,
         eos_token_id=2,
         pad_token_id=2,
         max_output_len=256,
@@ -148,6 +148,20 @@ def check_cache_manager(test_config):
     cache_manager.clear_all()
     assert cache_manager.num_available_blocks == num_blocks
+
+    for cache_block in cache_manager._cache_blocks:
+        assert cache_block.available_space == block_size
+
+    # Mock batch operations (Prefill/Decoding updates)
+    context_lengths = torch.tensor([max_input_length, max_input_length - 1])
+    block_tables = torch.tensor(
+        [[-1 for _ in range(cache_manager.max_blocks_per_sequence)] for _ in range(2)], dtype=torch.int32
+    )
+    cache_manager.allocate_context_from_block_tables(block_tables, context_lengths)
+    cache_manager.allocate_tokens_from_block_tables(block_tables, context_lengths)
+    cache_manager.free_block_tables(block_tables)
+    for cache_block in cache_manager._cache_blocks:
+        assert cache_block.available_space == block_size


 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
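The added assertions follow an allocate-then-free lifecycle: context allocation fills block-table slots, token allocation may claim further blocks, and freeing resets the table to -1 and restores each block's available space. A toy allocator along those lines is sketched below (hypothetical and deliberately minimal; the real KVCacheManager also manages the physical per-layer KV tensors, and token allocation is omitted here).

import torch

block_size = 4
num_blocks = 8
available_space = [block_size] * num_blocks  # free capacity per physical block
free_ids = list(range(num_blocks))           # pool of unallocated block ids


def allocate_context(block_table: torch.Tensor, context_len: int) -> None:
    # Claim ceil(context_len / block_size) blocks and record them in the table.
    needed = (context_len + block_size - 1) // block_size
    for slot in range(needed):
        bid = free_ids.pop()
        block_table[slot] = bid
        used = min(block_size, context_len - slot * block_size)
        available_space[bid] -= used


def free_block_table(block_table: torch.Tensor) -> None:
    # Return every referenced block to the pool and reset the table to -1.
    for slot, bid in enumerate(block_table.tolist()):
        if bid >= 0:
            available_space[bid] = block_size
            free_ids.append(bid)
            block_table[slot] = -1


bt = torch.full((4,), -1, dtype=torch.int32)
allocate_context(bt, context_len=10)
assert (bt >= 0).sum() == 3  # ceil(10 / 4) blocks claimed
free_block_table(bt)
assert torch.all(bt < 0) and all(s == block_size for s in available_space)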
@@ -1,5 +1,4 @@
 import pytest
-import torch
 from transformers.models.llama import LlamaConfig

 import colossalai
@@ -22,17 +21,35 @@ def check_running_list():
         eos_token_id=0,
         pad_token_id=0,
         sample_params=None,
-        block_table=1,
     )
+    seq2 = Sequence(
+        request_id=2,
+        prompt="abc",
+        input_token_id=[1, 2, 3],
+        block_size=16,
+        eos_token_id=0,
+        pad_token_id=0,
+        sample_params=None,
+    )
     running_list.append(seq1)
+    running_list.append(seq2)
     assert running_list.ready_for_prefill()
-    assert running_list.decoding == [] and running_list.prefill[0] == seq1
+    assert len(running_list.decoding) == 0
+    assert len(running_list.prefill) > 0 and running_list.prefill[0] == seq1

     seq = running_list.find_seq(seq1.request_id)
     assert seq == seq1

+    running_list.mark_prefill_running()
+    for seq in running_list.prefill:
+        assert seq.status == RequestStatus.RUNNING
+
+    running_list.move_prefill_to_decoding([seq1.request_id, seq2.request_id])
+    assert len(running_list.prefill) == 0
+    assert len(running_list.decoding) > 0 and running_list.decoding[0] == seq1

     running_list.remove(seq1)
+    running_list.remove(seq2)
     assert running_list.is_empty()


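The prefill-to-decoding handoff that these assertions walk through can be summarized with a small standalone sketch. MiniRunningList is a hypothetical stand-in that only mirrors the call names used above (append, ready_for_prefill, mark_prefill_running, move_prefill_to_decoding, remove, is_empty); the handler's actual RunningList carries more state.

from enum import Enum, auto


class Status(Enum):
    WAITING = auto()
    RUNNING = auto()


class MiniRunningList:
    """Hypothetical stand-in mirroring the RunningList calls above (illustration only)."""

    def __init__(self):
        self.prefill = []
        self.decoding = []

    def append(self, seq):
        seq["status"] = Status.WAITING
        self.prefill.append(seq)

    def ready_for_prefill(self):
        return len(self.prefill) > 0

    def mark_prefill_running(self):
        for seq in self.prefill:
            seq["status"] = Status.RUNNING

    def move_prefill_to_decoding(self, request_ids):
        # Move the named requests from the prefill queue to the decoding queue.
        moved = [s for s in self.prefill if s["id"] in request_ids]
        self.prefill = [s for s in self.prefill if s["id"] not in request_ids]
        self.decoding.extend(moved)

    def remove(self, seq):
        if seq in self.prefill:
            self.prefill.remove(seq)
        else:
            self.decoding.remove(seq)

    def is_empty(self):
        return not self.prefill and not self.decoding


rl = MiniRunningList()
s1, s2 = {"id": 1}, {"id": 2}
rl.append(s1)
rl.append(s2)
assert rl.ready_for_prefill()
rl.mark_prefill_running()
assert all(s["status"] is Status.RUNNING for s in rl.prefill)
rl.move_prefill_to_decoding([1, 2])
assert len(rl.prefill) == 0 and rl.decoding[0] is s1
rl.remove(s1)
rl.remove(s2)
assert rl.is_empty()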
@@ -59,7 +76,6 @@ def check_request_handler():
         eos_token_id=0,
         pad_token_id=0,
         sample_params=None,
-        block_table=torch.tensor([-1, -1]),
     )
     request_handler.add_sequence(seq1)
     # the priority should be 1