[Inference] Optimize and Refactor Inference Batching/Scheduling (#5367)

* add kvcache manager funcs for batching

* add batch bucket for batching

* revise RunningList struct in handler

* add kvcache/batch funcs for compatibility

* use new batching methods

* fix indexing bugs

* revise abort logic

* use cpu seq lengths/block tables

* rm unused attr in Sequence

* fix type conversion/default arg

* add and revise pytests

* revise pytests, rm unused tests

* rm unused statements

* fix pop finished indexing issue

* fix: use index in batch when retrieving inputs/update seqs

* use dict instead of odict in batch struct

* arg type hinting

* fix make compress

* refine comments

* fix: pop_n_seqs to pop the first n seqs

* add check in request handler

* remove redundant conversion

* fix test for request handler

* fix pop method in batch bucket

* fix prefill adding
This commit is contained in:
Yuanheng Zhao
2024-02-19 17:18:20 +08:00
committed by GitHub
parent 8c69debdc7
commit b21aac5bae
11 changed files with 902 additions and 112 deletions

View File

@@ -1,5 +1,4 @@
import pytest
import torch
from transformers.models.llama import LlamaConfig
import colossalai
@@ -22,17 +21,35 @@ def check_running_list():
eos_token_id=0,
pad_token_id=0,
sample_params=None,
block_table=1,
)
seq2 = Sequence(
request_id=2,
prompt="abc",
input_token_id=[1, 2, 3],
block_size=16,
eos_token_id=0,
pad_token_id=0,
sample_params=None,
)
running_list.append(seq1)
running_list.append(seq2)
assert running_list.ready_for_prefill()
assert running_list.decoding == [] and running_list.prefill[0] == seq1
assert len(running_list.decoding) == 0
assert len(running_list.prefill) > 0 and running_list.prefill[0] == seq1
seq = running_list.find_seq(seq1.request_id)
assert seq == seq1
running_list.mark_prefill_running()
for seq in running_list.prefill:
assert seq.status == RequestStatus.RUNNING
running_list.move_prefill_to_decoding([seq1.request_id, seq2.request_id])
assert len(running_list.prefill) == 0
assert len(running_list.decoding) > 0 and running_list.decoding[0] == seq1
running_list.remove(seq1)
running_list.remove(seq2)
assert running_list.is_empty()
@@ -59,7 +76,6 @@ def check_request_handler():
eos_token_id=0,
pad_token_id=0,
sample_params=None,
block_table=torch.tensor([-1, -1]),
)
request_handler.add_sequence(seq1)
# the priority should be 1