[Inference] Add Streaming LLM (#5745)

* Add Streaming LLM

* add some parameters to llama_generation.py

* verify streamingllm config

* add test_streamingllm.py

* address review comments

* add Citation

* change _block_tables conversion to tolist()
Authored by yuehuayingxueluo, 2024-06-05 10:51:19 +08:00; committed by GitHub
parent ee6fd38373
commit b45000f839
8 changed files with 276 additions and 12 deletions
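
StreamingLLM bounds KV-cache memory during long generations by pinning a few "attention sink" tokens at the head of the sequence and keeping only a sliding window of recently generated tokens (Xiao et al., "Efficient Streaming Language Models with Attention Sinks", arXiv:2309.17453; presumably the citation this commit adds). As a rough usage sketch, assuming the inference config exposes the same three knobs this commit threads into BatchBucket:

from colossalai.inference.config import InferenceConfig

# Hypothetical usage sketch: the three streaming parameters mirror the
# BatchBucket arguments added below, but their presence on InferenceConfig
# is an assumption, not something shown in this diff.
config = InferenceConfig(
    max_batch_size=8,
    max_input_len=128,
    max_output_len=1024,
    enable_streamingllm=True,    # turn on attention-sink style cache eviction
    start_token_size=4,          # "sink" tokens pinned at the head of the cache
    generated_token_size=512,    # sliding window of recent tokens kept in cache
)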


@@ -31,6 +31,9 @@ class BatchBucket:
         fd_interm_tensor=None,
         device=None,
         dtype=torch.float16,
+        enable_streamingllm: bool = False,
+        start_token_size: int = 4,
+        generated_token_size: int = 512,
     ):
         self.num_heads = num_heads
         self.head_dim = head_dim
@@ -45,12 +48,19 @@ class BatchBucket:
         self._use_spec_dec = False
         self._num_tokens_to_verify = None
 
+        self.enable_streamingllm = enable_streamingllm
+        self.start_token_size = start_token_size
+        self.generated_token_size = generated_token_size
+
         self._current_batch_size = 0
         self._sequences_dict = dict()
         self._sequences_indexes = dict()  # deque(maxlen=self.max_batch_size)
         self._sequence_lengths = torch.zeros((self.max_batch_size,), dtype=torch.int32)
         self._sequence_lengths_helper = torch.zeros_like(self._sequence_lengths)
-        max_blocks_per_seq = (self.max_length + block_size - 1) // block_size
+        if enable_streamingllm:
+            max_blocks_per_seq = (start_token_size + generated_token_size + block_size - 1) // block_size + 1
+        else:
+            max_blocks_per_seq = (self.max_length + block_size - 1) // block_size
         self._block_tables = torch.full((self.max_batch_size, max_blocks_per_seq), -1, dtype=torch.int32)
         self._block_tables_helper = torch.full_like(self._block_tables, -1)
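
The streaming branch above sizes the block table for the pinned start tokens plus the sliding window rather than the full max_length, with one spare block: the eviction check in streamingllm_update_batch (next hunk) only fires once a sequence has grown a full block past the window. A quick sanity check of the arithmetic, with the default sizes and an assumed block_size of 16 (not fixed by this diff):

# Sanity-check of the block-table sizing above; block_size = 16 and
# max_length = 2048 are illustrative assumptions.
start_token_size, generated_token_size, block_size = 4, 512, 16

# Streaming: sink tokens + sliding window, plus one spare block so a
# sequence can reach start + generated + block_size - 1 tokens before eviction.
streaming_blocks = (start_token_size + generated_token_size + block_size - 1) // block_size + 1
assert streaming_blocks == 34  # eviction threshold is 531 tokens; ceil(531 / 16) = 34

# Non-streaming: enough blocks for the whole sequence.
max_length = 2048
assert (max_length + block_size - 1) // block_size == 128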
@@ -109,6 +119,33 @@ class BatchBucket:
             out.append(seq.input_token_id + seq.output_token_id)
         return out
 
+    def streamingllm_update_batch(self, start_token_size: int, generated_token_size: int):
+        """
+        Update sequence_lengths and block_tables when it is necessary to swap out a block.
+        """
+
+        updated_block_ids = []
+
+        if self.current_batch_size > 0:
+            need_update = False
+            sequence_lengths_list = self._sequence_lengths.tolist()
+            block_tables_list = self._block_tables[: self._current_batch_size].tolist()
+            for batch_id in range(self.current_batch_size):
+                # We assume that the start token occupies the entire first block.
+                if sequence_lengths_list[batch_id] == start_token_size + generated_token_size + self.block_size - 1:
+                    need_update = True
+                    sequence_lengths_list[batch_id] = start_token_size + generated_token_size - 1
+                    block_id = block_tables_list[batch_id].pop(1)
+                    updated_block_ids.append(block_id)
+                    block_tables_list[batch_id].append(-1)
+
+            if need_update:
+                self._sequence_lengths = torch.tensor(
+                    sequence_lengths_list, dtype=self._sequence_lengths.dtype, device=self.device
+                )
+                self._block_tables = torch.tensor(block_tables_list, dtype=self._block_tables.dtype, device=self.device)
+
+        return updated_block_ids
+
     def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
         """Set batch bucket to use speculative decoding.
         This will notify the batch bucket to adjust the lengths of inputs during modeling,
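
To make the eviction logic in streamingllm_update_batch concrete, here is a plain-Python trace of one sequence hitting the threshold, using the same assumed sizes as before (start_token_size=4, generated_token_size=512, block_size=16; the values are illustrative, not fixed by the diff):

# Pencil-and-paper trace of one eviction step for a single sequence.
start, window, block = 4, 512, 16

seq_len = start + window + block - 1      # 531: the trigger condition in the loop above
block_table = list(range(34))             # block 0 holds the pinned start tokens

evicted = block_table.pop(1)              # drop the oldest non-sink block (block id 1)
block_table.append(-1)                    # keep the table width fixed; -1 marks a free slot
seq_len = start + window - 1              # 515: length recorded after the swap-out

assert evicted == 1
assert block_table[0] == 0 and block_table[-1] == -1
assert seq_len == 515                     # exactly one block (16 tokens) was released

The returned updated_block_ids presumably lets the caller hand the freed blocks back to the KV-cache allocator; that code is not part of this hunk.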