Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 04:50:17 +00:00
[Inference/SpecDec] Add Speculative Decoding Implementation (#5423)

* fix flash decoding mask during verification
* add spec-dec
* add test for spec-dec
* revise drafter init
* remove drafter sampling
* retire past kv in drafter
* (trivial) rename attrs
* (trivial) rename arg
* revise how we enable/disable spec-dec
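For orientation, a minimal usage sketch of the API this commit adds, expanded from the docstring example further below; the checkpoint paths, the `InferenceConfig` constructor arguments, and the engine import path are illustrative assumptions, not part of the diff:

```python
# Hedged sketch only: checkpoint paths, import paths and config kwargs are assumptions.
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine  # assumed module path

model = AutoModelForCausalLM.from_pretrained("path/to/main-model")        # verifier (large) model
drafter_model = AutoModelForCausalLM.from_pretrained("path/to/drafter")   # drafter (small) model
tokenizer = AutoTokenizer.from_pretrained("path/to/main-model")

inference_config = InferenceConfig(max_n_spec_tokens=5)  # assumed kwarg; matches the attr used in the diff
engine = InferenceEngine(model, tokenizer, inference_config)

engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
spec_out = engine.generate(prompts=["An example prompt"])    # speculative decoding

engine.disable_spec_dec()
normal_out = engine.generate(prompts=["An example prompt"])  # normal generation

engine.clear_spec_dec()  # drop the drafter and free its GPU memory
```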
@@ -12,6 +12,7 @@ from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.spec import Drafter
from colossalai.inference.struct import Sequence
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -52,19 +53,26 @@ class InferenceEngine:
verbose: bool = False,
model_policy: Policy = None,
) -> None:
assert inference_config, "Please provide inference_config."
assert tokenizer, "Please provide a tokenizer, either a defined one or str"
self.inference_config = inference_config
self.model_config = model.config
self.model = model
self.device = torch.device("cuda")
self.dtype = inference_config.dtype
self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
self.generation_config = inference_config.to_generation_config(self.model_config)
self.high_precision = inference_config.high_precision
model = model.eval()
model = model.cuda()
model.to(self.dtype)
self._verify_args()

self.generation_config = inference_config.to_generation_config(self.model_config)
model.eval()
model = model.to(self.dtype)
model = model.to(self.device)

# Model and related attrs of speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = False
self.drafter_model = None
self.drafter = None
self.n_spec_tokens = self.inference_config.max_n_spec_tokens

if model_policy is None:
if self.inference_config.pad_input:
@@ -174,21 +182,18 @@ class InferenceEngine:
if self.verbose:
self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")

def _verify_config(self) -> None:
"""
Verify the input config
"""
def _verify_args(self) -> None:
"""Verify the input args"""
if not isinstance(self.inference_config, InferenceConfig):
raise TypeError("Invalid type of inference config provided.")
if not isinstance(self.model, nn.Module):
raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
self.tokenizer, PreTrainedTokenizer
):
if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
raise TypeError(
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
)
assert (
self.model.__class__.__name__ in _supported_models
), f"Model {self.model.__class__.__name__} is not supported."
if self.model.__class__.__name__ not in _supported_models:
raise ValueError(f"Model {self.model.__class__.__name__} is not supported.")

def _shardformer(
self,
@@ -224,6 +229,138 @@ class InferenceEngine:
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model

def enable_spec_dec(self, drafter_model: nn.Module = None, n_spec_tokens: int = None) -> None:
"""Initialize the drafter (if it has not been initialized yet), and enable Speculative Decoding for subsequent generations.

Args:
drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
If provided, the previous drafter and drafter model, if they exist, will be overwritten.
n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
If not provided, `max_n_spec_tokens` in InferenceConfig will be used.

```python
...
engine = InferenceEngine(model, tokenizer, inference_config)

engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
engine.generate(...)  # Speculative Decoding

engine.disable_spec_dec()
engine.generate(...)  # Normal generation

engine.enable_spec_dec()
engine.generate(...)  # Speculative Decoding using the previously set drafter model and number of spec tokens
engine.clear_spec_dec()
```
"""
if drafter_model is None and self.drafter is None:
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
if n_spec_tokens is not None:
assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
self.n_spec_tokens = n_spec_tokens
if drafter_model is not None:
assert isinstance(drafter_model, nn.Module)
# overwrite the drafter, if it exists
self.clear_spec_dec()
self.drafter_model = drafter_model
self.drafter = Drafter(
self.drafter_model,
self.tokenizer,
device=self.device,
dtype=self.dtype,
)
# use speculative decoding for subsequent generations
self.use_spec_dec = True

def disable_spec_dec(self) -> None:
"""Disable using speculative decoding for subsequent generations."""
# set back to the maximum number of tokens to speculate
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
self.use_spec_dec = False
return

def clear_spec_dec(self) -> None:
"""Clear related structures of speculative decoding, if they exist."""
if self.drafter_model or self.drafter:
self.drafter_model = None
self.drafter = None
torch.cuda.empty_cache()
self.use_spec_dec = False
return

def steps_spec_dec(self) -> List[Sequence]:
"""
Run Speculative Decoding steps. This schedules a single batch and runs inference
with multiple rounds of speculation by the drafter model and verification by the main model.

Returns:
List[Sequence]: finished sequences generated by one step.
"""
batch = self.request_handler.schedule()  # prefill batch
batch.set_use_spec_dec(self.n_spec_tokens)  # set batch to use-spec-dec mode

assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
input_ids = batch.get_1D_inputs()  # bsz 1 for drafter model

# 1. Prefill small model (Drafter) - fill past kv cache for drafter model
drafter_out = self.drafter.speculate(input_ids, 1, None)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values

# 2. Prefill main model (Verifier) - fill past kv cache for main model
logits = self.model(batch, self.k_cache, self.v_cache)
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
# append new inputs to the batch, temporarily
batch.append_batch_tokens(next_tokens)
self.request_handler.allocate_batch_spec_dec(batch, 1)
already_allocated_kv_len = batch.seq_lengths[0].item()
input_ids = batch.get_1D_inputs_spec_dec(1)

batch.reset_use_spec_dec()  # reset batch use-spec-dec mode
finished_sequences = self.request_handler.update()

while True:
# HACK Retrieve the running batch
# Using RequestHandler.schedule here would re-allocate the same kv cache for the batch
batch = self.request_handler.running_bb  # running batch
batch.set_use_spec_dec(self.n_spec_tokens)

# 3. Decoding - Drafter model speculates `n` tokens
drafter_out = self.drafter.speculate(input_ids, self.n_spec_tokens, drafter_past_key_values)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values

for next_token_id_spec in next_token_ids_spec:
self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
cur_length = batch.seq_lengths[0].item()
if already_allocated_kv_len < cur_length:
self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
already_allocated_kv_len = cur_length

# 4. Decoding - Main model verifies `n` tokens in parallel
logits = self.model(batch, self.k_cache, self.v_cache)
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)

# 5. Compare and process the results
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
n_matches = self.n_spec_tokens if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
# revoke appended tokens for each Sequence in the current batch
batch.revoke_batch_tokens(self.n_spec_tokens - n_matches)  # revoke drafted tokens
# append the last correct token generated by the main model
self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
input_ids = batch.get_1D_inputs_spec_dec(1)
# trim past key values of the drafter model
drafter_past_key_values = Drafter.trim_kv_cache(drafter_past_key_values, self.n_spec_tokens - n_matches - 1)

self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
finished_sequences = self.request_handler.update()
if len(finished_sequences) > 0:
break

batch.reset_use_spec_dec()

return finished_sequences
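The acceptance rule in step 5 above is compact enough to illustrate on its own: the verifier produces one token per drafted position plus one bonus token, the drafted tokens are compared position-wise against the verifier's outputs, and the first mismatch truncates the draft. A self-contained sketch of that comparison (standalone tensors, no engine state; the helper name is ours):

```python
import torch

def count_accepted_tokens(verifier_tokens: torch.Tensor, drafted_tokens: torch.Tensor, n_spec_tokens: int) -> int:
    """How many drafted tokens the main model accepts in one verification round.

    verifier_tokens: n_spec_tokens + 1 tokens from the main model
                     (one per drafted position, plus one bonus token).
    drafted_tokens:  n_spec_tokens tokens proposed by the drafter.
    """
    # Positions where the drafter disagrees with the main model.
    diff_indexes = torch.nonzero(~(verifier_tokens[:-1] == drafted_tokens))
    # Full agreement -> accept all drafted tokens; otherwise accept up to the first mismatch.
    return n_spec_tokens if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()

# The drafter proposes 5 tokens; the 4th disagrees with the verifier,
# so 3 tokens are accepted and 5 - 3 = 2 drafted tokens get revoked.
verifier = torch.tensor([11, 22, 33, 44, 55, 66])
drafted = torch.tensor([11, 22, 33, 99, 55])
assert count_accepted_tokens(verifier, drafted, n_spec_tokens=5) == 3
```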

def generate(
self,
prompts: List[str] = None,
@@ -246,7 +383,6 @@ class InferenceEngine:
List[str]: Inference result returned by one generation.
"""
with torch.inference_mode():
self.generation_config = generation_config
if prompts is not None or prompts_token_ids is not None:
self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)

@@ -257,8 +393,13 @@ class InferenceEngine:
if generation_config is not None:
self.generation_config = generation_config

while self.request_handler.check_unfinished_seqs():
output_seqs_list += self.step()
if self.use_spec_dec:
assert self.drafter is not None, "Drafter Model is not initialized."
while self.request_handler.check_unfinished_seqs():
output_seqs_list += self.steps_spec_dec()
else:
while self.request_handler.check_unfinished_seqs():
output_seqs_list += self.step()

output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))

@@ -428,7 +569,8 @@ class InferenceEngine:
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
if self.inference_config.pad_input:
logits = logits[:, -1, :]
self.request_handler.search_tokens(self.generation_config, logits)
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
self.request_handler.append_next_tokens(next_tokens)

finished_sequences = self.request_handler.update()
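The `Drafter.trim_kv_cache(...)` call in step 5 of `steps_spec_dec` discards the key/value entries of the rejected draft tokens so the drafter's cache stays aligned with the accepted sequence. A hedged sketch of what such trimming could look like for a Hugging Face-style legacy cache, i.e. a tuple of per-layer `(key, value)` tensors of shape `[batch, heads, seq_len, head_dim]`; the actual cache layout inside `Drafter` is an assumption here:

```python
from typing import Tuple

import torch

# Assumed legacy Hugging Face cache layout: one (key, value) pair per layer,
# each of shape [batch_size, num_heads, seq_len, head_dim].
KVCache = Tuple[Tuple[torch.Tensor, torch.Tensor], ...]

def trim_kv_cache(past_key_values: KVCache, n_invalid_tokens: int) -> KVCache:
    """Drop the cache entries of the last `n_invalid_tokens` (rejected draft) positions."""
    if n_invalid_tokens <= 0:
        return past_key_values
    return tuple(
        (key[:, :, :-n_invalid_tokens, :], value[:, :, :-n_invalid_tokens, :])
        for key, value in past_key_values
    )
```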