[Inference]Add Streaming LLM (#5745)

* Add Streaming LLM

* add some parameters to llama_generation.py

* verify streamingllm config

* add test_streamingllm.py

* modify according to review comments

* add Citation

* change _block_tables to use tolist()
Authored by yuehuayingxueluo on 2024-06-05 10:51:19 +08:00, committed by GitHub
parent ee6fd38373
commit b45000f839
8 changed files with 276 additions and 12 deletions


@@ -78,10 +78,16 @@ class KVCacheManager:
         self.max_output_length = config.max_output_len
         # Cache block settings
         self.block_size = config.block_size
         # NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size
-        self.max_blocks_per_sequence = (
-            self.max_input_length + self.max_output_length + self.block_size - 1
-        ) // self.block_size
+        if config.enable_streamingllm:
+            self.max_blocks_per_sequence = (
+                config.start_token_size + config.generated_token_size + self.block_size - 1
+            ) // self.block_size + 1
+        else:
+            self.max_blocks_per_sequence = (
+                self.max_input_length + self.max_output_length + self.block_size - 1
+            ) // self.block_size
         self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width
         # Physical cache allocation
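
With StreamingLLM enabled, the per-sequence block budget is derived from the fixed attention window (the retained start tokens plus the rolling window of generated tokens) instead of max_input_length + max_output_length; the trailing `+ 1` presumably reserves one spare block so the window can slide before the oldest block is reclaimed. A minimal sketch of the arithmetic, with illustrative numbers that are not taken from this PR:

```python
def blocks_per_sequence(start_token_size: int, generated_token_size: int, block_size: int) -> int:
    # Ceiling division over the window size, plus one spare block for the
    # block that is rolled out as the window slides.
    return (start_token_size + generated_token_size + block_size - 1) // block_size + 1

# Illustrative values only: a 4-token attention sink, a 512-token rolling
# window, and 16-slot blocks yield (4 + 512 + 15) // 16 + 1 = 33 + 1 = 34.
assert blocks_per_sequence(4, 512, 16) == 34
```
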
@@ -446,6 +452,20 @@ class KVCacheManager:
         self._available_blocks = self.num_blocks
         self._block_states[:] = 1
 
+    def streamingllm_free_block_tables(self, updated_block_ids: List[int]):
+        """
+        Free the block that needs to be swapped out.
+        """
+        for global_block_id in updated_block_ids:
+            if global_block_id < 0:
+                return
+            block: CacheBlock = self._cache_blocks[global_block_id]
+            block.remove_ref()
+            if not block.has_ref():
+                block.allocated_size = 0
+                self._available_blocks += 1
+                self._block_states[global_block_id] = 1
+
     def get_physical_cache(self, layer_id: int, block_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
         """Get the tensor corresponding to the cache block with the prompted id for a specific layer."""
         return self._kv_caches[0][layer_id][block_idx], self._kv_caches[1][layer_id][block_idx]
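
streamingllm_free_block_tables releases blocks that have rolled out of the attention window: a block only returns to the free pool once its reference count drops to zero, and a negative id (which appears to mark an unused block-table slot) ends the scan early. A toy sketch of that ref-counting contract, using a simplified stand-in rather than the CacheBlock class from this repo:

```python
class ToyCacheBlock:
    """Simplified stand-in for CacheBlock: a ref count plus a fill level."""

    def __init__(self) -> None:
        self.ref_count = 0
        self.allocated_size = 0

    def add_ref(self) -> None:
        self.ref_count += 1

    def remove_ref(self) -> None:
        self.ref_count -= 1

    def has_ref(self) -> bool:
        return self.ref_count > 0


# A block swapped out of the streaming window is only recycled once every
# sequence sharing it has dropped its reference.
block = ToyCacheBlock()
block.add_ref()
block.add_ref()             # two sequences share the block
block.remove_ref()
assert block.has_ref()      # still referenced; keep its contents
block.remove_ref()
assert not block.has_ref()  # now eligible: reset size, mark its state free
```
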
@@ -533,10 +553,16 @@ class RPCKVCacheManager(KVCacheManager):
         self.max_output_length = config.max_output_len
         # Cache block settings
         self.block_size = config.block_size
         # NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size
-        self.max_blocks_per_sequence = (
-            self.max_input_length + self.max_output_length + self.block_size - 1
-        ) // self.block_size
+        if config.enable_streamingllm:
+            self.max_blocks_per_sequence = (
+                config.start_token_size + config.generated_token_size + self.block_size - 1
+            ) // self.block_size + 1
+        else:
+            self.max_blocks_per_sequence = (
+                self.max_input_length + self.max_output_length + self.block_size - 1
+            ) // self.block_size
         self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width
         # Logical cache blocks allocation
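
The same branch is mirrored in RPCKVCacheManager, so the feature is driven entirely by the three config fields read above. A hedged usage sketch: the field names come from the diff, but the import path and the remaining constructor arguments are assumptions about the surrounding repo:

```python
# Import path and other arguments are assumed for illustration; only the
# three streaming fields below are confirmed by the diff above.
from colossalai.inference.config import InferenceConfig

config = InferenceConfig(
    enable_streamingllm=True,   # switch both cache managers to window-sized budgets
    start_token_size=4,         # attention-sink tokens kept at the head of the cache
    generated_token_size=512,   # rolling window of recent generated tokens
    block_size=16,
)
```
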