mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 12:01:39 +00:00
[SpecDec] Fix inputs for speculation and revise past KV trimming (#5449)
* fix drafter pastkv and usage of batch bucket
This commit is contained in:
@@ -181,6 +181,14 @@ class RequestHandler:
|
||||
def get_kvcache(self):
|
||||
return self.cache_manager.get_kv_cache()
|
||||
|
||||
def set_spec_dec_mode(self, n_spec_tokens: int):
|
||||
self.prefill_bb.set_use_spec_dec(n_spec_tokens)
|
||||
self.running_bb.set_use_spec_dec(n_spec_tokens)
|
||||
|
||||
def unset_spec_dec_mode(self):
|
||||
self.prefill_bb.reset_use_spec_dec()
|
||||
self.running_bb.reset_use_spec_dec()
|
||||
|
||||
def schedule(self):
|
||||
"""
|
||||
The main logic of request handler.
|
||||
@@ -208,7 +216,11 @@ class RequestHandler:
|
||||
lst.remove(seq)
|
||||
|
||||
if self.running_list.ready_for_prefill():
|
||||
num_seqs_to_add = min(self.running_list.prefill_seq_num, self.running_bb.available_batch_size)
|
||||
num_seqs_to_add = min(self.running_list.prefill_seq_num, self.prefill_bb.available_batch_size)
|
||||
# overwrite the number of sequences to add to 1 if use_spec_dec is enabled
|
||||
# TODO (zhaoyuanheng): support speculative decoding for batch size > 1
|
||||
if self.prefill_bb.use_spec_dec:
|
||||
num_seqs_to_add = 1
|
||||
|
||||
for seq in self.running_list.prefill[:num_seqs_to_add]:
|
||||
seq.mark_running()
|
||||
|
Reference in New Issue
Block a user