[Inference] Add Streaming LLM (#5745)

* Add Streaming LLM

* add some parameters to llama_generation.py

* verify streamingllm config

* add test_streamingllm.py

* revise according to review comments

* add Citation

* change _block_tables tolist
Author: yuehuayingxueluo
Date: 2024-06-05 10:51:19 +08:00 (committed by GitHub)
Parent: ee6fd38373
Commit: b45000f839
8 changed files with 276 additions and 12 deletions


@@ -157,6 +157,9 @@ class RequestHandler:
             fd_interm_tensor=fd_inter_tensor,
             dtype=self.dtype,
             device=device,
+            enable_streamingllm=inference_config.enable_streamingllm,
+            start_token_size=inference_config.start_token_size,
+            generated_token_size=inference_config.generated_token_size,
         )
         self.prefill_bb = BatchBucket(
             num_heads=model_config.num_attention_heads // inference_config.tp_size,
@@ -168,6 +171,9 @@ class RequestHandler:
             fd_interm_tensor=fd_inter_tensor,
             dtype=self.dtype,
             device=device,
+            enable_streamingllm=inference_config.enable_streamingllm,
+            start_token_size=inference_config.start_token_size,
+            generated_token_size=inference_config.generated_token_size,
         )
 
     def _init_cache(self, model_config):
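The two hunks above pass three new StreamingLLM fields from the inference config into both `BatchBucket` instances. Below is a minimal sketch of how a user might turn the feature on; only `enable_streamingllm`, `start_token_size`, and `generated_token_size` come from this diff, while the import path and the `max_batch_size` field are assumptions.

```python
# Sketch only: enabling StreamingLLM through the inference config.
# The import path and max_batch_size are assumptions, not part of this PR;
# the three StreamingLLM fields are the ones referenced in the diff above.
from colossalai.inference.config import InferenceConfig

inference_config = InferenceConfig(
    max_batch_size=8,            # assumed existing field
    enable_streamingllm=True,    # keep a sliding KV cache instead of the full history
    start_token_size=4,          # "attention sink" tokens kept at the sequence start
    generated_token_size=512,    # rolling window of the most recent generated tokens
)
```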
@@ -350,6 +356,12 @@ class RequestHandler:
 
         return finished_seqs
 
+    def streamingllm_free_block_tables(self, updated_block_ids: List[int]):
+        """
+        Free the block that needs to be swapped out.
+        """
+        self.cache_manager.streamingllm_free_block_tables(updated_block_ids)
+
 
 class RPCRequestHandler(RequestHandler):
     """