diff --git a/colossalai/inference/batch_bucket.py b/colossalai/inference/batch_bucket.py
index e157a9215..d9aa01091 100644
--- a/colossalai/inference/batch_bucket.py
+++ b/colossalai/inference/batch_bucket.py
@@ -372,18 +372,22 @@ class BatchBucket:
                 seq.check_finish()
         self._sequence_lengths[: self.current_batch_size] += 1
 
-    def revoke_batch_tokens(self, n: int) -> None:
+    def revoke_batch_tokens(self, n_tokens: int, n_seqs: int = 1) -> None:
         """Revoke the last n output tokens of the sequences in the batch
 
         Args:
-            n (int): The number of output tokens to revoke from each sequence.
+            n_tokens (int): The number of output tokens to revoke from each sequence. It does not count in the context tokens (input tokens).
+            n_seqs (int): The first n sequences to revoke tokens from. Defaults to 1.
+                For now, speculative decoding only supports batch size 1.
         """
-        if n >= 1:
-            for seq_id, seq in self._sequences_dict.items():
-                assert seq.output_len >= n, "Revoking len exceeds the current output len of the sequence"
-                seq.output_token_id = seq.output_token_id[:-n]
-            self._sequence_lengths -= n
+        if n_tokens >= 1:
+            seqs_iter = iter(self._sequences_dict.items())
+            for _ in range(n_seqs):
+                seq_id, seq = next(seqs_iter)
+                assert seq.output_len >= n_tokens, "Revoking len exceeds the current output len of the sequence"
+                seq.output_token_id = seq.output_token_id[:-n_tokens]
+                self._sequence_lengths[self._sequences_indexes[seq_id]] -= n_tokens
 
     def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]:
         """Clear all the sequences in the batch.
diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 672d5a959..7015c1f3f 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -269,24 +269,26 @@ class InferenceEngine:
                 device=self.device,
                 dtype=self.dtype,
             )
+        self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
         # using speculative decoding for subsequent generations
         self.use_spec_dec = True
 
     def disable_spec_dec(self) -> None:
         """Disable using speculative decoding for subsequent generations."""
+        self.request_handler.unset_spec_dec_mode()
         # set back to the maximum number of tokens to speculate
         self.n_spec_tokens = self.inference_config.max_n_spec_tokens
         self.use_spec_dec = False
-        return
 
     def clear_spec_dec(self) -> None:
         """Clear relatable structures of speculative decoding, if exist."""
+        if self.use_spec_dec:
+            self.disable_spec_dec()
         if self.drafter_model or self.drafter:
             self.drafter_model = None
             self.drafter = None
             torch.cuda.empty_cache()
         self.use_spec_dec = False
-        return
 
     def steps_spec_dec(self) -> List[Sequence]:
         """
@@ -297,7 +299,6 @@ class InferenceEngine:
             List[Sequence]: finished sequences generated by one step.
         """
         batch = self.request_handler.schedule()  # prefill batch
-        batch.set_use_spec_dec(self.n_spec_tokens)  # set batch to use-spec-dec mode
         assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
         input_ids = batch.get_1D_inputs()  # bsz 1 for drafter model
@@ -316,19 +317,19 @@ class InferenceEngine:
         already_allocated_kv_len = batch.seq_lengths[0].item()
         input_ids = batch.get_1D_inputs_spec_dec(1)
-        batch.reset_use_spec_dec()  # reset batch use-spec-dec mode
 
         finished_sequences = self.request_handler.update()
 
         while True:
             # HACK Retrieve the running batch
             #      Using RequestHandler.schedule here will re-allocate same kv cache for the batch
             batch = self.request_handler.running_bb  # running batch
-            batch.set_use_spec_dec(self.n_spec_tokens)
+            assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
 
             # 3. Decoding - Drafter model speculates `n` tokens
             drafter_out = self.drafter.speculate(input_ids, self.n_spec_tokens, drafter_past_key_values)
             next_token_ids_spec = drafter_out.next_tokens
             drafter_past_key_values = drafter_out.past_key_values
+            drafter_spec_length = drafter_out.speculated_length
 
             for next_token_id_spec in next_token_ids_spec:
                 self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
@@ -343,22 +344,26 @@ class InferenceEngine:
 
             # 5. Compare and process the results
             diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
-            n_matches = self.n_spec_tokens if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
+            n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
+
             # revoke appended tokens for each Sequence in the current batch
-            batch.revoke_batch_tokens(self.n_spec_tokens - n_matches)  # revoke drafted tokens
+            batch.revoke_batch_tokens(drafter_spec_length - n_matches)  # revoke drafted tokens
+
             # append the last correct token generated by the main model
             self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
-            input_ids = batch.get_1D_inputs_spec_dec(1)
+
             # trim past key values of the drafter model
-            drafter_past_key_values = Drafter.trim_kv_cache(drafter_past_key_values, self.n_spec_tokens - n_matches - 1)
+            drafter_past_key_values = Drafter.trim_kv_cache(
+                drafter_past_key_values, drafter_spec_length - n_matches - 1
+            )
+            # prepare inputs for the next round of speculation
+            n = 1 if n_matches < drafter_spec_length else 2
+            input_ids = batch.get_1D_inputs_spec_dec(n)
 
             self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
             finished_sequences = self.request_handler.update()
             if len(finished_sequences) > 0:
                 break
 
-        batch.reset_use_spec_dec()
-
         return finished_sequences
 
     def generate(
diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py
index 6c1a232e2..327a7e9ce 100644
--- a/colossalai/inference/core/request_handler.py
+++ b/colossalai/inference/core/request_handler.py
@@ -181,6 +181,14 @@ class RequestHandler:
     def get_kvcache(self):
         return self.cache_manager.get_kv_cache()
 
+    def set_spec_dec_mode(self, n_spec_tokens: int):
+        self.prefill_bb.set_use_spec_dec(n_spec_tokens)
+        self.running_bb.set_use_spec_dec(n_spec_tokens)
+
+    def unset_spec_dec_mode(self):
+        self.prefill_bb.reset_use_spec_dec()
+        self.running_bb.reset_use_spec_dec()
+
     def schedule(self):
         """
         The main logic of request handler.
@@ -208,7 +216,11 @@ class RequestHandler:
                     lst.remove(seq)
 
         if self.running_list.ready_for_prefill():
-            num_seqs_to_add = min(self.running_list.prefill_seq_num, self.running_bb.available_batch_size)
+            num_seqs_to_add = min(self.running_list.prefill_seq_num, self.prefill_bb.available_batch_size)
+            # overwrite the number of sequences to add to 1 if use_spec_dec is enabled
+            # TODO (zhaoyuanheng): support speculative decoding for batch size > 1
+            if self.prefill_bb.use_spec_dec:
+                num_seqs_to_add = 1
 
             for seq in self.running_list.prefill[:num_seqs_to_add]:
                 seq.mark_running()
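
For readers following the `steps_spec_dec` changes above, here is a minimal, self-contained sketch (not ColossalAI code) of the verify-and-revoke step: compare the drafter's speculated tokens against the main model's outputs, keep the longest matching prefix plus one token from the main model, and revoke the rest. The names `verify_drafted_tokens`, `drafted_tokens`, and `verifier_tokens` are illustrative assumptions, not part of the engine's API.

```python
import torch


def verify_drafted_tokens(drafted_tokens: torch.Tensor, verifier_tokens: torch.Tensor):
    """Compare drafter speculation against the main model's outputs.

    drafted_tokens:  [spec_len]      tokens proposed by the drafter
    verifier_tokens: [spec_len + 1]  tokens from the main model; verifier_tokens[i] is the
                     "correct" token at the position of drafted_tokens[i], plus one bonus
                     token at the end.
    """
    spec_len = drafted_tokens.size(0)
    # positions where the drafter disagrees with the main model
    diff_indexes = torch.nonzero(~(verifier_tokens[:-1] == drafted_tokens))
    # accept the longest matching prefix; if everything matches, accept all spec_len tokens
    n_matches = spec_len if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()

    # keep n_matches drafted tokens, revoke the rest, and append one token from the
    # main model (the correction at the first mismatch, or the bonus token)
    accepted = torch.cat([drafted_tokens[:n_matches], verifier_tokens[n_matches].unsqueeze(0)])
    n_revoked = spec_len - n_matches
    return accepted, n_matches, n_revoked


if __name__ == "__main__":
    drafted = torch.tensor([11, 12, 13, 14])       # drafter speculated 4 tokens
    verified = torch.tensor([11, 12, 99, 15, 16])  # main model disagrees at position 2
    accepted, n_matches, n_revoked = verify_drafted_tokens(drafted, verified)
    print(accepted.tolist(), n_matches, n_revoked)  # [11, 12, 99] 2 2
```

In the engine code above, `drafter_spec_length - n_matches` plays the role of `n_revoked`, and the token appended from `next_tokens[n_matches]` corresponds to the last element of `accepted` here; when the whole draft is accepted, that extra token is the bonus token, which is why the diff feeds two tokens (`n = 2`) back into the drafter for the next round instead of one.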
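
Similarly, a toy illustration of the bookkeeping change in `revoke_batch_tokens`: only the first `n_seqs` sequences drop their last `n_tokens` outputs, and only those sequences' entries in the length tensor are decremented by index, whereas the old code iterated every sequence and subtracted from the whole tensor. `ToySeq` and `ToyBatch` are hypothetical stand-ins, not the real `Sequence`/`BatchBucket` classes.

```python
from dataclasses import dataclass, field
from typing import Dict, List

import torch


@dataclass
class ToySeq:
    input_len: int
    output_token_id: List[int] = field(default_factory=list)

    @property
    def output_len(self) -> int:
        return len(self.output_token_id)


class ToyBatch:
    def __init__(self, seqs: Dict[int, ToySeq]):
        self.seqs = seqs  # mirrors BatchBucket._sequences_dict
        self.seq_index = {sid: i for i, sid in enumerate(seqs)}  # mirrors _sequences_indexes
        self.seq_lengths = torch.tensor(
            [s.input_len + s.output_len for s in seqs.values()]
        )  # mirrors _sequence_lengths

    def revoke_batch_tokens(self, n_tokens: int, n_seqs: int = 1) -> None:
        if n_tokens >= 1:
            seqs_iter = iter(self.seqs.items())
            for _ in range(n_seqs):
                seq_id, seq = next(seqs_iter)
                assert seq.output_len >= n_tokens
                seq.output_token_id = seq.output_token_id[:-n_tokens]
                # decrement only this sequence's length, addressed by its batch index
                self.seq_lengths[self.seq_index[seq_id]] -= n_tokens


batch = ToyBatch({7: ToySeq(input_len=5, output_token_id=[1, 2, 3, 4])})
batch.revoke_batch_tokens(n_tokens=2)  # drop the 2 rejected drafted tokens
print(batch.seqs[7].output_token_id, batch.seq_lengths.tolist())  # [1, 2] [7]
```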