[SpecDec] Fix inputs for speculation and revise past KV trimming (#5449)

* fix drafter pastkv and usage of batch bucket
This commit is contained in:
Yuanheng Zhao
2024-03-12 17:57:01 +08:00
committed by Yuanheng
parent a37f82629d
commit 912e24b2aa
3 changed files with 40 additions and 19 deletions

View File

@@ -181,6 +181,14 @@ class RequestHandler:
# Return whatever `self.cache_manager.get_kv_cache()` returns; this method is a
# thin accessor that adds no logic of its own.
def get_kvcache(self):
return self.cache_manager.get_kv_cache()
# Switch both batch buckets into speculative-decoding mode by forwarding
# `n_spec_tokens` to `set_use_spec_dec` on `prefill_bb` and then `running_bb`.
# NOTE(review): `n_spec_tokens` is presumably the number of tokens speculated
# per step -- confirm against the bucket's `set_use_spec_dec` implementation.
def set_spec_dec_mode(self, n_spec_tokens: int):
self.prefill_bb.set_use_spec_dec(n_spec_tokens)
self.running_bb.set_use_spec_dec(n_spec_tokens)
# Counterpart of `set_spec_dec_mode`: calls `reset_use_spec_dec()` on
# `prefill_bb` and then `running_bb`, keeping the two buckets in step.
def unset_spec_dec_mode(self):
self.prefill_bb.reset_use_spec_dec()
self.running_bb.reset_use_spec_dec()
def schedule(self):
"""
The main logic of request handler.
@@ -208,7 +216,11 @@ class RequestHandler:
lst.remove(seq)
if self.running_list.ready_for_prefill():
num_seqs_to_add = min(self.running_list.prefill_seq_num, self.running_bb.available_batch_size)
num_seqs_to_add = min(self.running_list.prefill_seq_num, self.prefill_bb.available_batch_size)
# overwrite the number of sequences to add to 1 if use_spec_dec is enabled
# TODO (zhaoyuanheng): support speculative decoding for batch size > 1
if self.prefill_bb.use_spec_dec:
num_seqs_to_add = 1
for seq in self.running_list.prefill[:num_seqs_to_add]:
seq.mark_running()