[Inference]Add Streaming LLM (#5745)
* Add Streaming LLM
* add some parameters to llama_generation.py
* verify streamingllm config
* add test_streamingllm.py
* modified according to review comments
* add Citation
* change _block_tables tolist
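For context, StreamingLLM (Xiao et al., "Efficient Streaming Language Models with Attention Sinks") bounds the KV cache by keeping a small prefix of "attention sink" tokens plus a rolling window of the most recent tokens, so generation length is no longer tied to cache capacity. Below is a minimal usage sketch of the knobs this commit wires through the engine; the attribute names (`enable_streamingllm`, `start_token_size`, `generated_token_size`) come from the diff, but passing them as `InferenceConfig` constructor kwargs and the values shown are assumptions for illustration only.

```python
# Hypothetical usage sketch (not taken verbatim from the PR): enabling
# StreamingLLM-style KV-cache bounding via the config fields used in this diff.
from colossalai.inference.config import InferenceConfig

# Assumption: these fields are accepted as constructor kwargs; the names mirror
# the attributes referenced in the diff. Values are illustrative only.
config = InferenceConfig(
    max_batch_size=8,
    max_input_len=1024,
    max_output_len=512,
    enable_streamingllm=True,   # turn on sink + sliding-window cache management
    start_token_size=4,         # number of leading "attention sink" tokens to keep
    generated_token_size=512,   # size of the rolling window of recent tokens
)
```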
```diff
@@ -667,6 +667,11 @@ class InferenceEngine:
             elif max_length is not None:
                 max_new_tokens = max_length - len(prompts_token_ids[i])
 
+            if not self.inference_config.enable_streamingllm:
+                assert (
+                    self.inference_config.max_output_len >= max_new_tokens
+                ), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}."
+
             sequence = Sequence(
                 request_id,
                 prompt,
```
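The first hunk relaxes the output-length check: when StreamingLLM is disabled, `max_new_tokens` must still fit within `max_output_len`, but with StreamingLLM enabled the cache footprint is bounded by the sink and window sizes rather than the total generated length, so the assertion is skipped. A rough standalone illustration of that guard, with the config represented by a plain dataclass (names mirror the diff; this is not the engine's actual code path):

```python
from dataclasses import dataclass

@dataclass
class _Cfg:
    # Stand-in for the InferenceConfig fields referenced in the diff.
    max_output_len: int
    enable_streamingllm: bool

def check_output_budget(cfg: _Cfg, max_new_tokens: int) -> None:
    """Mirror of the added guard: only enforce the static output-length
    limit when StreamingLLM is not managing the cache."""
    if not cfg.enable_streamingllm:
        assert (
            cfg.max_output_len >= max_new_tokens
        ), f"max_new_tokens={max_new_tokens} must be less than max_output_len={cfg.max_output_len}."

check_output_budget(_Cfg(max_output_len=512, enable_streamingllm=False), 256)   # passes
check_output_budget(_Cfg(max_output_len=512, enable_streamingllm=True), 4096)   # also passes
```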
```diff
@@ -754,6 +759,13 @@ class InferenceEngine:
         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
+
+        if self.inference_config.enable_streamingllm:
+            updated_block_ids = batch.streamingllm_update_batch(
+                self.inference_config.start_token_size, self.inference_config.generated_token_size
+            )
+            self.request_handler.streamingllm_free_block_tables(updated_block_ids)
+
         next_tokens = search_tokens(
             self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
         )
```
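The second hunk runs after the forward pass in each engine step: `streamingllm_update_batch` trims every sequence's cache metadata down to the attention-sink prefix plus the recent window, and the block ids it returns are handed to the request handler so the corresponding paged-KV blocks can be reused. A self-contained sketch of the kind of block arithmetic this implies, under the assumption that the cache is paged in fixed-size blocks and that only blocks holding neither sink tokens nor window tokens are released (the real method names and return format live in the PR's batch and request-handler code):

```python
# Hypothetical sketch of the eviction arithmetic behind
# streamingllm_update_batch / streamingllm_free_block_tables (assumed semantics:
# keep the first `start_token_size` tokens plus the last `generated_token_size`
# tokens, and release any paged-KV block that holds none of the kept tokens).
from typing import List

def blocks_to_free(
    seq_len: int, block_size: int, start_token_size: int, generated_token_size: int
) -> List[int]:
    num_blocks = (seq_len + block_size - 1) // block_size
    window_start = max(seq_len - generated_token_size, 0)  # first token of the recent window
    freed = []
    for b in range(num_blocks):
        first_tok, last_tok = b * block_size, min((b + 1) * block_size, seq_len) - 1
        holds_sink = first_tok < start_token_size
        holds_window = last_tok >= window_start
        if not (holds_sink or holds_window):
            freed.append(b)
    return freed

# With a 16-token block size, 4 sink tokens, and a 64-token window,
# a 200-token sequence keeps block 0 (sinks) and the trailing window blocks;
# the middle blocks are reported as freeable.
print(blocks_to_free(seq_len=200, block_size=16, start_token_size=4, generated_token_size=64))
```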