[SpecDec] Fix inputs for speculation and revise past KV trimming (#5449)

* fix drafter pastkv and usage of batch bucket
This commit is contained in:
Yuanheng Zhao
2024-03-12 17:57:01 +08:00
committed by Yuanheng
parent a37f82629d
commit 912e24b2aa
3 changed files with 40 additions and 19 deletions

View File

@@ -372,18 +372,22 @@ class BatchBucket:
seq.check_finish()
self._sequence_lengths[: self.current_batch_size] += 1
def revoke_batch_tokens(self, n_tokens: int, n_seqs: int = 1) -> None:
    """Revoke the last `n_tokens` output tokens from the first `n_seqs` sequences in the batch.

    All target sequences are validated before any of them is modified, so a
    failed call leaves the batch untouched.

    Args:
        n_tokens (int): The number of output tokens to revoke from each sequence.
            It does not count in the context tokens (input tokens).
            Values below 1 make the call a no-op.
        n_seqs (int): The first n sequences to revoke tokens from. Defaults to 1.
            For now, speculative decoding only supports batch size 1.
            Values below 1 make the call a no-op.

    Raises:
        ValueError: If `n_seqs` is larger than the number of sequences in the batch.
        AssertionError: If a target sequence has fewer than `n_tokens` output tokens.
    """
    if n_tokens < 1 or n_seqs < 1:
        return

    n_available = len(self._sequences_dict)
    if n_seqs > n_available:
        # Fail with a clear error instead of leaking a bare StopIteration from an exhausted iterator.
        raise ValueError(f"Cannot revoke tokens from {n_seqs} sequences; the batch only holds {n_available}")

    # Sequences are taken in the dict's iteration order, i.e. the first `n_seqs` entries.
    targets = list(self._sequences_dict.items())[:n_seqs]

    # Validate every target up front so a failure cannot leave the batch half-updated.
    for _, seq in targets:
        assert seq.output_len >= n_tokens, "Revoking len exceeds the current output len of the sequence"

    for seq_id, seq in targets:
        seq.output_token_id = seq.output_token_id[:-n_tokens]
        # Keep the per-sequence length bookkeeping in sync with the truncated output.
        self._sequence_lengths[self._sequences_indexes[seq_id]] -= n_tokens
def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]:
"""Clear all the sequences in the batch.