[Inference] Optimize and Refactor Inference Batching/Scheduling (#5367)

* add kvcache manager funcs for batching

* add batch bucket for batching

* revise RunningList struct in handler

* add kvcache/batch funcs for compatibility

* use new batching methods

* fix indexing bugs

* revise abort logic

* use cpu seq lengths/block tables

* rm unused attr in Sequence

* fix type conversion/default arg

* add and revise pytests

* revise pytests, rm unused tests

* rm unused statements

* fix pop finished indexing issue

* fix: use index in batch when retrieving inputs/update seqs

* use dict instead of odict in batch struct

* arg type hinting

* fix make compress

* refine comments

* fix: pop_n_seqs to pop the first n seqs

* add check in request handler

* remove redundant conversion

* fix test for request handler

* fix pop method in batch bucket

* fix prefill adding
This commit is contained in:
Yuanheng Zhao
2024-02-19 17:18:20 +08:00
committed by GitHub
parent 8c69debdc7
commit b21aac5bae
11 changed files with 902 additions and 112 deletions

View File

@@ -42,7 +42,7 @@ class InferenceEngine:
def __init__(
self,
model: nn.Module,
tokenizer: [Union[PreTrainedTokenizer, PreTrainedTokenizerFast]],
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
inference_config: InferenceConfig,
verbose: bool = False,
model_policy: Policy = None,
@@ -254,20 +254,12 @@ class InferenceEngine:
else:
prompt = prompts[i]
max_blocks_per_sequence = (
self.inference_config.max_input_len
+ self.inference_config.max_output_len
+ self.inference_config.block_size
- 1
) // self.inference_config.block_size
block_table = torch.full([max_blocks_per_sequence], -1, device=self.device)
sequence = Sequence(
request_id,
prompt,
prompts_token_ids[i],
block_size,
None,
block_table,
self.tokenizer.eos_token_id,
self.tokenizer.pad_token_id,
self.inference_config.max_output_len,