[Inference] Add Streaming LLM (#5745)

* Add Streaming LLM

* add some parameters to llama_generation.py

* verify streamingllm config

* add test_streamingllm.py

* modified according to review comments

* add Citation

* change _block_tables to use tolist()

Author: yuehuayingxueluo
Date: 2024-06-05 10:51:19 +08:00
Committed by: GitHub
Parent: ee6fd38373
Commit: b45000f839
8 changed files with 276 additions and 12 deletions

@@ -667,6 +667,11 @@ class InferenceEngine:
elif max_length is not None:
    max_new_tokens = max_length - len(prompts_token_ids[i])
if not self.inference_config.enable_streamingllm:
    assert (
        self.inference_config.max_output_len >= max_new_tokens
    ), f"max_new_tokens={max_new_tokens} must be no larger than max_output_len={self.inference_config.max_output_len}."
sequence = Sequence(
    request_id,
    prompt,
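
Note that the max_output_len cap is only enforced when StreamingLLM is disabled, since the rolling KV cache bounds memory on its own. For context, a minimal sketch of how the new switches might be supplied through the inference config: the field names (enable_streamingllm, start_token_size, generated_token_size, max_output_len) come from this diff, but the constructor usage and the default values shown here are assumptions.

    # Sketch only: enabling StreamingLLM through the inference config.
    # Field names are taken from this diff; the exact constructor signature is assumed.
    from colossalai.inference.config import InferenceConfig

    inference_config = InferenceConfig(
        max_output_len=256,        # checked against max_new_tokens only when StreamingLLM is off
        enable_streamingllm=True,  # switch on the attention-sink + sliding-window KV cache
        start_token_size=4,        # leading "sink" tokens that are never evicted
        generated_token_size=512,  # rolling window of recent tokens kept in the cache
    )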
@@ -754,6 +759,13 @@ class InferenceEngine:
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
if self.inference_config.pad_input:
    logits = logits[:, -1, :]
if self.inference_config.enable_streamingllm:
    updated_block_ids = batch.streamingllm_update_batch(
        self.inference_config.start_token_size, self.inference_config.generated_token_size
    )
    self.request_handler.streamingllm_free_block_tables(updated_block_ids)
next_tokens = search_tokens(
    self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
)
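
The decode path above trims the paged KV cache after every forward pass: streamingllm_update_batch reports which cache blocks fall outside the retained region, and the request handler frees them. As a rough, self-contained illustration of the retention policy only (not the actual block-table code), StreamingLLM keeps the first start_token_size "sink" tokens plus the most recent generated_token_size tokens and evicts everything in between; the helper below is hypothetical.

    # Hypothetical sketch of the StreamingLLM retention policy on token positions.
    # The real engine works on block tables and frees whole cache blocks instead.
    def streamingllm_keep_positions(seq_len: int, start_token_size: int, generated_token_size: int) -> list[int]:
        """Return the token positions that stay in the KV cache."""
        if seq_len <= start_token_size + generated_token_size:
            return list(range(seq_len))  # cache not full yet, nothing to evict
        sink = list(range(start_token_size))                            # attention-sink tokens
        window = list(range(seq_len - generated_token_size, seq_len))   # most recent tokens
        return sink + window

    # Example: with 4 sink tokens and a 512-token window on a 2048-token sequence,
    # positions 0-3 and 1536-2047 remain while the middle of the sequence is evicted.
    kept = streamingllm_keep_positions(2048, start_token_size=4, generated_token_size=512)
    assert 3 in kept and 2047 in kept and 100 not in kept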