[Inference] Optimize and Refactor Inference Batching/Scheduling (#5367)

* add kvcache manager funcs for batching

* add batch bucket for batching

* revise RunningList struct in handler

* add kvcache/batch funcs for compatibility

* use new batching methods

* fix indexing bugs

* revise abort logic

* use cpu seq lengths/block tables

* rm unused attr in Sequence

* fix type conversion/default arg

* add and revise pytests

* revise pytests, rm unused tests

* rm unused statements

* fix pop finished indexing issue

* fix: use index in batch when retrieving inputs/update seqs

* use dict instead of odict in batch struct

* arg type hinting

* fix make compress

* refine comments

* fix: pop_n_seqs to pop the first n seqs

* add check in request handler

* remove redundant conversion

* fix test for request handler

* fix pop method in batch bucket

* fix prefill adding
This commit is contained in:
Yuanheng Zhao
2024-02-19 17:18:20 +08:00
committed by GitHub
parent 8c69debdc7
commit b21aac5bae
11 changed files with 902 additions and 112 deletions

View File

@@ -12,8 +12,8 @@ from transformers.models.llama.modeling_llama import (
LlamaModel,
)
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.struct import BatchInfo
from colossalai.kernel.triton import (
context_attention_unpadded,
copy_kv_to_blocked_cache,
@@ -34,7 +34,7 @@ except ImportError:
def llama_causal_lm_forward(
self: LlamaForCausalLM,
batch: BatchInfo = None,
batch: BatchBucket = None,
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
):
@@ -59,7 +59,7 @@ def llama_causal_lm_forward(
def llama_model_forward(
self: LlamaModel,
batch: BatchInfo = None,
batch: BatchBucket = None,
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
):
@@ -73,7 +73,7 @@ def llama_model_forward(
input_ids = batch.get_1D_inputs()
block_tables = batch.get_block_table_tensor()
sequence_lengths = batch.get_sequence_lengths()
batch_size = len(sequence_lengths)
batch_size = batch.current_batch_size
kv_seq_len = sequence_lengths.max().item()
hidden_states = self.embed_tokens(input_ids)