Mirror of https://github.com/hpcaitech/ColossalAI.git
[Inference] Optimize and Refactor Inference Batching/Scheduling (#5367)
* add kvcache manager funcs for batching
* add batch bucket for batching
* revise RunningList struct in handler
* add kvcache/batch funcs for compatibility
* use new batching methods
* fix indexing bugs
* revise abort logic
* use cpu seq lengths/block tables
* rm unused attr in Sequence
* fix type conversion/default arg
* add and revise pytests
* revise pytests, rm unused tests
* rm unused statements
* fix pop finished indexing issue
* fix: use index in batch when retrieving inputs/update seqs
* use dict instead of odict in batch struct
* arg type hinting
* fix make compress
* refine comments
* fix: pop_n_seqs to pop the first n seqs
* add check in request handler
* remove redundant conversion
* fix test for request handler
* fix pop method in batch bucket
* fix prefill adding
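The central new structure is the BatchBucket that replaces BatchInfo as the batch container. Below is a minimal sketch of the idea, assuming hypothetical names (SequenceStub, MiniBatchBucket) that are not the actual colossalai.inference.batch_bucket API:

from typing import Dict, List


class SequenceStub:
    # Hypothetical stand-in for the library's Sequence, for illustration only.
    def __init__(self, request_id: int, input_len: int):
        self.request_id = request_id
        self.input_len = input_len


class MiniBatchBucket:
    # Hypothetical stand-in for BatchBucket; names and behavior are assumed.

    def __init__(self, max_batch_size: int):
        self.max_batch_size = max_batch_size
        # A plain dict keeps insertion order on Python 3.7+, which is why the
        # commit message can say "use dict instead of odict in batch struct".
        self._sequences: Dict[int, SequenceStub] = {}

    @property
    def current_batch_size(self) -> int:
        return len(self._sequences)

    def add_seq(self, seq: SequenceStub) -> bool:
        # Refuse new sequences once the bucket is full.
        if self.current_batch_size >= self.max_batch_size:
            return False
        self._sequences[seq.request_id] = seq
        return True

    def pop_n_seqs(self, n: int) -> List[SequenceStub]:
        # Pop the *first* n sequences in insertion order, mirroring the fix
        # "pop_n_seqs to pop the first n seqs" from the commit message.
        ids = list(self._sequences.keys())[:n]
        return [self._sequences.pop(i) for i in ids]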
@@ -12,8 +12,8 @@ from transformers.models.llama.modeling_llama import (
     LlamaModel,
 )
 
+from colossalai.inference.batch_bucket import BatchBucket
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
-from colossalai.inference.struct import BatchInfo
 from colossalai.kernel.triton import (
     context_attention_unpadded,
     copy_kv_to_blocked_cache,
@@ -34,7 +34,7 @@ except ImportError:
 
 def llama_causal_lm_forward(
     self: LlamaForCausalLM,
-    batch: BatchInfo = None,
+    batch: BatchBucket = None,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
 ):
@@ -59,7 +59,7 @@ def llama_model_forward(
 
 def llama_model_forward(
     self: LlamaModel,
-    batch: BatchInfo = None,
+    batch: BatchBucket = None,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
 ):
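Both entry points change the same way: the batch parameter is retyped from BatchInfo to BatchBucket, so the causal-LM wrapper can hand the scheduler's bucket straight to the inner model forward. A pass-through sketch under assumed names (make_causal_lm_forward is illustrative, not the library's API):

from typing import Callable, List, Optional


def make_causal_lm_forward(model_forward: Callable, lm_head: Callable) -> Callable:
    # Build a causal-LM forward that passes the very same bucket and cache
    # lists down to the underlying model forward, then applies the LM head.
    def causal_lm_forward(batch=None, k_caches: Optional[List] = None, v_caches: Optional[List] = None):
        hidden_states = model_forward(batch=batch, k_caches=k_caches, v_caches=v_caches)
        return lm_head(hidden_states)

    return causal_lm_forward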
@@ -73,7 +73,7 @@ def llama_model_forward(
     input_ids = batch.get_1D_inputs()
     block_tables = batch.get_block_table_tensor()
     sequence_lengths = batch.get_sequence_lengths()
-    batch_size = len(sequence_lengths)
+    batch_size = batch.current_batch_size
     kv_seq_len = sequence_lengths.max().item()
 
     hidden_states = self.embed_tokens(input_ids)
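The last hunk stops deriving batch_size from the sequence-lengths tensor and instead reads a counter kept on the bucket itself, in line with the "use cpu seq lengths/block tables" item above. A self-contained illustration with assumed names (TinyBucket is hypothetical, not the library's class):

import torch


class TinyBucket:
    # Hypothetical bucket: lengths live in CPU-side bookkeeping, and the
    # batch size is a property of the container, not of any tensor's shape.
    def __init__(self, lengths):
        self._lengths = torch.tensor(lengths, dtype=torch.int32)

    @property
    def current_batch_size(self) -> int:
        return self._lengths.numel()

    def get_sequence_lengths(self) -> torch.Tensor:
        return self._lengths


bucket = TinyBucket([18, 5, 32])
batch_size = bucket.current_batch_size                    # 3
kv_seq_len = bucket.get_sequence_lengths().max().item()   # 32, as in the hunk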