[Inference] Optimize and Refactor Inference Batching/Scheduling (#5367)
* add kvcache manager funcs for batching
* add batch bucket for batching
* revise RunningList struct in handler
* add kvcache/batch funcs for compatibility
* use new batching methods
* fix indexing bugs
* revise abort logic
* use cpu seq lengths/block tables
* rm unused attr in Sequence
* fix type conversion/default arg
* add and revise pytests
* revise pytests, rm unused tests
* rm unused statements
* fix pop finished indexing issue
* fix: use index in batch when retrieving inputs/update seqs
* use dict instead of odict in batch struct
* arg type hinting
* fix make compress
* refine comments
* fix: pop_n_seqs to pop the first n seqs
* add check in request handler
* remove redundant conversion
* fix test for request handler
* fix pop method in batch bucket
* fix prefill adding
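Several of these items (the batch bucket, CPU-side sequence lengths and block tables) come down to keeping per-sequence bookkeeping in a fixed-capacity batch container. The following is only a minimal sketch of that idea; `SimpleBatchBucket`, its field names, and its methods are illustrative assumptions, not the `BatchBucket` introduced by this PR.

```python
# Hypothetical sketch of a fixed-capacity batch container that keeps
# per-sequence lengths and block tables in CPU tensors, in the spirit of the
# "batch bucket" / "use cpu seq lengths/block tables" items above.
import torch


class SimpleBatchBucket:
    def __init__(self, max_batch_size: int, max_blocks_per_seq: int):
        self.max_batch_size = max_batch_size
        # Bookkeeping stays on CPU; only model inputs would be moved to the GPU.
        self.seq_lengths = torch.zeros(max_batch_size, dtype=torch.int32)
        self.block_tables = torch.full(
            (max_batch_size, max_blocks_per_seq), -1, dtype=torch.int32
        )
        self._slot_of = {}  # request_id -> slot index in the batch
        self._free_slots = list(range(max_batch_size))

    def is_full(self) -> bool:
        return not self._free_slots

    def add_seq(self, request_id: int, num_tokens: int) -> int:
        """Claim a slot for a new sequence and record its current length."""
        slot = self._free_slots.pop()
        self._slot_of[request_id] = slot
        self.seq_lengths[slot] = num_tokens
        return slot

    def pop_seq(self, request_id: int) -> None:
        """Release a finished/aborted sequence's slot and reset its block table."""
        slot = self._slot_of.pop(request_id)
        self.seq_lengths[slot] = 0
        self.block_tables[slot].fill_(-1)
        self._free_slots.append(slot)
```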
@@ -1,5 +1,4 @@
import pytest
import torch
from transformers.models.llama import LlamaConfig

import colossalai
@@ -22,17 +21,35 @@ def check_running_list():
        eos_token_id=0,
        pad_token_id=0,
        sample_params=None,
        block_table=1,
    )

    seq2 = Sequence(
        request_id=2,
        prompt="abc",
        input_token_id=[1, 2, 3],
        block_size=16,
        eos_token_id=0,
        pad_token_id=0,
        sample_params=None,
    )
    running_list.append(seq1)
    running_list.append(seq2)
    assert running_list.ready_for_prefill()
    assert running_list.decoding == [] and running_list.prefill[0] == seq1
    assert len(running_list.decoding) == 0
    assert len(running_list.prefill) > 0 and running_list.prefill[0] == seq1

    seq = running_list.find_seq(seq1.request_id)
    assert seq == seq1

    running_list.mark_prefill_running()
    for seq in running_list.prefill:
        assert seq.status == RequestStatus.RUNNING

    running_list.move_prefill_to_decoding([seq1.request_id, seq2.request_id])
    assert len(running_list.prefill) == 0
    assert len(running_list.decoding) > 0 and running_list.decoding[0] == seq1

    running_list.remove(seq1)
    running_list.remove(seq2)
    assert running_list.is_empty()
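The assertions above pin down the RunningList contract exercised by this test: sequences are appended for prefill, marked RUNNING, migrated to decoding by request id, and removed once finished. Below is a minimal sketch that would satisfy those assertions, assuming a simple two-queue layout; it is not the handler's actual implementation, and the `RequestStatus` values are placeholders.

```python
# Minimal sketch (not the ColossalAI implementation) of a RunningList with the
# behavior the test expects: a prefill queue and a decoding queue, with
# sequences migrated between them by request id.
from enum import Enum


class RequestStatus(Enum):
    WAITING = 0
    RUNNING = 1


class RunningList:
    def __init__(self):
        self.prefill = []   # sequences waiting for / undergoing prefill
        self.decoding = []  # sequences in the decoding phase

    def append(self, seq):
        self.prefill.append(seq)

    def ready_for_prefill(self):
        # A real scheduler would also check kv-cache block availability.
        return len(self.prefill) > 0

    def find_seq(self, request_id):
        for seq in self.prefill + self.decoding:
            if seq.request_id == request_id:
                return seq
        return None

    def mark_prefill_running(self):
        for seq in self.prefill:
            seq.status = RequestStatus.RUNNING

    def move_prefill_to_decoding(self, request_ids):
        moved = [seq for seq in self.prefill if seq.request_id in request_ids]
        self.prefill = [seq for seq in self.prefill if seq.request_id not in request_ids]
        self.decoding.extend(moved)

    def remove(self, seq):
        if seq in self.prefill:
            self.prefill.remove(seq)
        elif seq in self.decoding:
            self.decoding.remove(seq)

    def is_empty(self):
        return not self.prefill and not self.decoding
```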
@@ -59,7 +76,6 @@ def check_request_handler():
        eos_token_id=0,
        pad_token_id=0,
        sample_params=None,
        block_table=torch.tensor([-1, -1]),
    )
    request_handler.add_sequence(seq1)
    # the priority should be 1
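The trailing comment suggests that `add_sequence` assigns the new request to a priority bucket based on its prompt. Purely as an illustration of that idea, here is a sketch of length-based bucketing; the thresholds, field names, and return value are guesses, not this PR's logic.

```python
# Hypothetical sketch: place incoming sequences into waiting buckets keyed by
# prompt length; the bucket index acts as the request's priority. The length
# limits here are made up for illustration only.
class TinyRequestHandler:
    def __init__(self, length_limits=(64, 512, 4096)):
        self.length_limits = length_limits
        self.waiting = [[] for _ in length_limits]  # one waiting list per priority

    def add_sequence(self, seq) -> int:
        prompt_len = len(seq.input_token_id)
        for priority, limit in enumerate(self.length_limits):
            if prompt_len <= limit:
                self.waiting[priority].append(seq)
                return priority
        raise RuntimeError("prompt exceeds the largest supported length bucket")
```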