mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
[SpecDec] Fix inputs for speculation and revise past KV trimming (#5449)
* fix drafter past KV and usage of batch bucket
This commit is contained in:
@@ -372,18 +372,22 @@ class BatchBucket:
|
||||
seq.check_finish()
|
||||
self._sequence_lengths[: self.current_batch_size] += 1
|
||||
|
||||
def revoke_batch_tokens(self, n: int) -> None:
|
||||
def revoke_batch_tokens(self, n_tokens: int, n_seqs: int = 1) -> None:
|
||||
"""Revoke the last n output tokens of the sequences in the batch
|
||||
|
||||
Args:
|
||||
n (int): The number of output tokens to revoke from each sequence.
|
||||
n_tokens (int): The number of output tokens to revoke from each sequence.
|
||||
It does not count in the context tokens (input tokens).
|
||||
n_seqs (int): The first n sequences to revoke tokens from. Defaults to 1.
|
||||
For now, speculative decoding only supports batch size 1.
|
||||
"""
|
||||
if n >= 1:
|
||||
for seq_id, seq in self._sequences_dict.items():
|
||||
assert seq.output_len >= n, "Revoking len exceeds the current output len of the sequence"
|
||||
seq.output_token_id = seq.output_token_id[:-n]
|
||||
self._sequence_lengths -= n
|
||||
if n_tokens >= 1:
|
||||
seqs_iter = iter(self._sequences_dict.items())
|
||||
for _ in range(n_seqs):
|
||||
seq_id, seq = next(seqs_iter)
|
||||
assert seq.output_len >= n_tokens, "Revoking len exceeds the current output len of the sequence"
|
||||
seq.output_token_id = seq.output_token_id[:-n_tokens]
|
||||
self._sequence_lengths[self._sequences_indexes[seq_id]] -= n_tokens
|
||||
|
||||
def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]:
|
||||
"""Clear all the sequences in the batch.
|
||||
|
Reference in New Issue
Block a user