Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 13:00:52 +00:00
[SpecDec] Fix inputs for speculation and revise past KV trimming (#5449)
* fix drafter pastkv and usage of batch bucket
@@ -269,24 +269,26 @@ class InferenceEngine:
             device=self.device,
             dtype=self.dtype,
         )
         self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
         # using speculative decoding for subsequent generations
         self.use_spec_dec = True

     def disable_spec_dec(self) -> None:
         """Disable using speculative decoding for subsequent generations."""
         self.request_handler.unset_spec_dec_mode()
         # set back to the maximum number of tokens to speculate
         self.n_spec_tokens = self.inference_config.max_n_spec_tokens
         self.use_spec_dec = False
         return

     def clear_spec_dec(self) -> None:
         """Clear related structures of speculative decoding, if they exist."""
         if self.use_spec_dec:
             self.disable_spec_dec()
         if self.drafter_model or self.drafter:
             self.drafter_model = None
             self.drafter = None
             torch.cuda.empty_cache()
         self.use_spec_dec = False
         return

     def steps_spec_dec(self) -> List[Sequence]:
         """
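The three methods above form the engine's speculative-decoding lifecycle: turn speculation on with a drafter model, turn it off while keeping the drafter resident, or tear everything down and free GPU memory. A minimal usage sketch follows; the model names, constructor arguments, and exact signatures are assumptions for illustration, not part of this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

# Assumed setup; exact constructor and argument names may differ by version.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
drafter_model = AutoModelForCausalLM.from_pretrained("JackFram/llama-68m")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

engine = InferenceEngine(model, tokenizer, InferenceConfig(max_n_spec_tokens=5))

engine.enable_spec_dec(drafter_model, n_spec_tokens=5)   # speculation on
out_spec = engine.generate(prompts=["An example prompt"])

engine.disable_spec_dec()   # speculation off, drafter stays loaded
engine.clear_spec_dec()     # drafter dropped, CUDA cache emptied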
@@ -297,7 +299,6 @@ class InferenceEngine:
             List[Sequence]: finished sequences generated by one step.
         """
         batch = self.request_handler.schedule()  # prefill batch
         batch.set_use_spec_dec(self.n_spec_tokens)  # set batch to use-spec-dec mode

         assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
         input_ids = batch.get_1D_inputs()  # bsz 1 for drafter model
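The prefill pass feeds the whole prompt through get_1D_inputs, while later speculation rounds only need the newest one or two tokens via get_1D_inputs_spec_dec(n). A toy sketch of that contract, under the assumption (drawn from how the diff uses these helpers, not from the real BatchBucket code) that the latter simply returns the last n accepted token ids:

import torch

# Toy stand-in for the batch-bucket input helpers; this contract is an
# assumption inferred from the diff, not the engine's real implementation.
class ToyBatchBucket:
    def __init__(self, token_ids):
        self.token_ids = token_ids

    def get_1D_inputs(self):
        # Prefill: the full prompt as a 1-D tensor (bsz 1 only).
        return torch.tensor(self.token_ids)

    def get_1D_inputs_spec_dec(self, n):
        # Decoding rounds: only the last n accepted token ids.
        return torch.tensor(self.token_ids[-n:])

bucket = ToyBatchBucket([101, 7592, 2088])
print(bucket.get_1D_inputs())             # tensor([ 101, 7592, 2088])
print(bucket.get_1D_inputs_spec_dec(2))   # tensor([7592, 2088])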
@@ -316,19 +317,19 @@ class InferenceEngine:
         already_allocated_kv_len = batch.seq_lengths[0].item()
         input_ids = batch.get_1D_inputs_spec_dec(1)

         batch.reset_use_spec_dec()  # reset batch use-spec-dec mode
         finished_sequences = self.request_handler.update()

         while True:
             # HACK Retrieve the running batch
             #      Using RequestHandler.schedule here will re-allocate same kv cache for the batch
             batch = self.request_handler.running_bb  # running batch
             batch.set_use_spec_dec(self.n_spec_tokens)
             assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."

             # 3. Decoding - Drafter model speculates `n` tokens
             drafter_out = self.drafter.speculate(input_ids, self.n_spec_tokens, drafter_past_key_values)
             next_token_ids_spec = drafter_out.next_tokens
             drafter_past_key_values = drafter_out.past_key_values
             drafter_spec_length = drafter_out.speculated_length

             for next_token_id_spec in next_token_ids_spec:
                 self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
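The drafter returns its sampled tokens together with its KV cache and the number of tokens it actually speculated. A minimal sketch of such an output container, with field names taken from the attributes the diff reads off drafter_out and everything else assumed:

from dataclasses import dataclass
from typing import Optional, Tuple

import torch

@dataclass
class ToyDrafterOutput:
    # Field names mirror the diff's drafter_out attributes; the container
    # itself is an assumption, not the library's real DrafterOutput.
    next_tokens: torch.Tensor    # speculated token ids, shape (speculated_length,)
    speculated_length: int       # tokens actually drafted, <= n_spec_tokens
    past_key_values: Optional[Tuple] = None  # drafter KV cache for the next round

The point of tracking speculated_length separately is that it may be smaller than n_spec_tokens when drafting stops early (e.g. on an end-of-sequence token); the rest of this commit switches the downstream bookkeeping from self.n_spec_tokens to this actual length.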
@@ -343,22 +344,26 @@ class InferenceEngine:
             # 5. Compare and process the results
             diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
-            n_matches = self.n_spec_tokens if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
+            n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()

             # revoke appended tokens for each Sequence in the current batch
-            batch.revoke_batch_tokens(self.n_spec_tokens - n_matches)  # revoke drafted tokens
+            batch.revoke_batch_tokens(drafter_spec_length - n_matches)  # revoke drafted tokens
             # append the last correct token generated by the main model
             self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
-            input_ids = batch.get_1D_inputs_spec_dec(1)

             # trim past key values of the drafter model
-            drafter_past_key_values = Drafter.trim_kv_cache(drafter_past_key_values, self.n_spec_tokens - n_matches - 1)
+            drafter_past_key_values = Drafter.trim_kv_cache(
+                drafter_past_key_values, drafter_spec_length - n_matches - 1
+            )
+            # prepare inputs for the next round of speculation
+            n = 1 if n_matches < drafter_spec_length else 2
+            input_ids = batch.get_1D_inputs_spec_dec(n)

             self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
             finished_sequences = self.request_handler.update()
             if len(finished_sequences) > 0:
                 break

         batch.reset_use_spec_dec()

         return finished_sequences

     def generate(
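The acceptance arithmetic in this final hunk is easiest to check with toy numbers: if the drafter speculated 5 tokens and the verifier first disagrees at index 2, then n_matches = 2, the 3 unverified drafted tokens are revoked, the verifier's own token at position 2 is appended, and the drafter's KV cache is trimmed by 5 - 2 - 1 = 2 entries. The self-contained re-enactment below uses made-up token ids, and its trim function is only an assumed semantics for Drafter.trim_kv_cache (drop the last num_to_trim positions of every layer's key/value tensors, assumed shape (batch, heads, seq_len, head_dim)); none of it is the engine's real code:

import torch

# --- acceptance arithmetic, mirroring the diff ---
drafter_spec_length = 5
next_token_ids_spec = torch.tensor([11, 12, 13, 14, 15])  # drafter's guesses (toy ids)
next_tokens = torch.tensor([11, 12, 99, 42, 7, 3])        # verifier output, incl. one bonus token

# First position where verifier and drafter disagree (same comparison as the diff).
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
assert n_matches == 2                              # tokens 11 and 12 were accepted
assert drafter_spec_length - n_matches == 3        # drafted tokens to revoke
assert drafter_spec_length - n_matches - 1 == 2    # drafter KV entries to trim

# Next-round input: 1 token after a partial match, 2 (corrected + bonus) after a full match.
n = 1 if n_matches < drafter_spec_length else 2
assert n == 1

# --- assumed trim semantics: drop the last num_to_trim positions per layer ---
def trim_kv_cache_sketch(past_key_values, num_to_trim):
    if num_to_trim <= 0:
        return past_key_values
    return tuple(
        (k[:, :, :-num_to_trim, :], v[:, :, :-num_to_trim, :]) for k, v in past_key_values
    )

kv = ((torch.zeros(1, 8, 10, 64), torch.zeros(1, 8, 10, 64)),)  # one toy layer, seq_len 10
trimmed = trim_kv_cache_sketch(kv, drafter_spec_length - n_matches - 1)
assert trimmed[0][0].shape[2] == 8  # 10 - 2 positions remain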