Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-10 05:20:33 +00:00)
[Inference] Optimize and Refactor Inference Batching/Scheduling (#5367)
* add kvcache manager funcs for batching
* add batch bucket for batching
* revise RunningList struct in handler
* add kvcache/batch funcs for compatibility
* use new batching methods
* fix indexing bugs
* revise abort logic
* use cpu seq lengths/block tables
* rm unused attr in Sequence
* fix type conversion/default arg
* add and revise pytests
* revise pytests, rm unused tests
* rm unused statements
* fix pop finished indexing issue
* fix: use index in batch when retrieving inputs/update seqs
* use dict instead of odict in batch struct
* arg type hinting
* fix make compress
* refine comments
* fix: pop_n_seqs to pop the first n seqs
* add check in request handler
* remove redundant conversion
* fix test for request handler
* fix pop method in batch bucket
* fix prefill adding
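The commit message describes a batch-bucket structure that keeps running sequences in a plain dict (insertion-ordered since Python 3.7) instead of an OrderedDict, and a pop_n_seqs helper that pops the first n sequences. Below is a minimal, hypothetical sketch of that idea; the names BatchBucket, add_seq, and pop_n_seqs are illustrative only and are not the actual ColossalAI API.

from typing import Any, Dict, List


class BatchBucket:
    """Illustrative sketch: holds running sequences keyed by request id.

    A plain dict is used instead of OrderedDict; dicts preserve insertion
    order since Python 3.7, which is enough for FIFO-style popping.
    """

    def __init__(self) -> None:
        self._seqs: Dict[int, Any] = {}

    def add_seq(self, request_id: int, seq: Any) -> None:
        # Newly admitted sequences are appended at the end of the dict.
        self._seqs[request_id] = seq

    def pop_n_seqs(self, n: int) -> List[Any]:
        # Pop the *first* n sequences (the oldest requests), matching the
        # "pop_n_seqs to pop the first n seqs" fix in the commit message.
        keys = list(self._seqs.keys())[:n]
        return [self._seqs.pop(k) for k in keys]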
@@ -42,7 +42,7 @@ class InferenceEngine:
     def __init__(
         self,
         model: nn.Module,
-        tokenizer: [Union[PreTrainedTokenizer, PreTrainedTokenizerFast]],
+        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
         inference_config: InferenceConfig,
         verbose: bool = False,
         model_policy: Policy = None,
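The first hunk is a type-annotation fix: [Union[PreTrainedTokenizer, PreTrainedTokenizerFast]] is a one-element list literal wrapping the typing object rather than a valid annotation, so type checkers cannot interpret it; dropping the brackets leaves the intended Union. A minimal illustration, assuming the transformers tokenizer classes are importable (the make_engine name is purely a stand-in, not part of the engine):

from typing import Union

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast


# Before: the annotation was a list literal containing a Union, e.g.
#     tokenizer: [Union[PreTrainedTokenizer, PreTrainedTokenizerFast]]
# After: the Union itself is used as the annotation.
def make_engine(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]) -> None:
    # Illustrative stand-in for InferenceEngine.__init__'s tokenizer parameter.
    assert isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast))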
@@ -254,20 +254,12 @@ class InferenceEngine:
             else:
                 prompt = prompts[i]

-            max_blocks_per_sequence = (
-                self.inference_config.max_input_len
-                + self.inference_config.max_output_len
-                + self.inference_config.block_size
-                - 1
-            ) // self.inference_config.block_size
-            block_table = torch.full([max_blocks_per_sequence], -1, device=self.device)
             sequence = Sequence(
                 request_id,
                 prompt,
                 prompts_token_ids[i],
                 block_size,
                 None,
-                block_table,
                 self.tokenizer.eos_token_id,
                 self.tokenizer.pad_token_id,
                 self.inference_config.max_output_len,
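The removed block pre-computed the maximum number of KV-cache blocks a sequence could ever need with an integer ceiling division, then filled a block table with -1 as the "no physical block assigned yet" marker; after this refactor that bookkeeping no longer happens in the engine. A standalone sketch of the arithmetic, using made-up config values rather than ColossalAI defaults:

import torch

# Assumed example values, not actual InferenceConfig defaults.
max_input_len = 512
max_output_len = 128
block_size = 16

# Ceiling division via integer math: ceil((input + output) / block_size).
max_blocks_per_sequence = (max_input_len + max_output_len + block_size - 1) // block_size
assert max_blocks_per_sequence == 40  # ceil(640 / 16)

# One slot per potential block; -1 marks an unallocated cache block.
block_table = torch.full([max_blocks_per_sequence], -1)
print(block_table.shape)  # torch.Size([40])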