mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[Inference/SpecDec] Support GLIDE Drafter Model (#5455)
* add glide-llama policy and modeling
* update glide modeling, compatible with transformers 4.36.2
* revise glide llama modeling/usage
* fix issues of glimpsing large kv
* revise the way of re-loading params for glide drafter
* fix drafter and engine tests
* enable convert to glide with strict=False
* revise glide llama modeling
* revise vicuna prompt template
* revise drafter and tests
* apply usage of glide model in engine
@@ -12,7 +12,7 @@ from colossalai.inference.batch_bucket import BatchBucket
 from colossalai.inference.config import InferenceConfig, InputMetaData
 from colossalai.inference.graph_runner import CUDAGraphRunner
 from colossalai.inference.modeling.policy import model_policy_map
-from colossalai.inference.spec import Drafter
+from colossalai.inference.spec import Drafter, GlideInput
 from colossalai.inference.struct import Sequence
 from colossalai.logging import get_dist_logger
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -72,6 +72,7 @@ class InferenceEngine:
         self.use_spec_dec = False
         self.drafter_model = None
         self.drafter = None
+        self.use_glide = False
         self.n_spec_tokens = self.inference_config.max_n_spec_tokens

         if model_policy is None:
@@ -229,7 +230,12 @@ class InferenceEngine:
         shard_model, _ = shardformer.optimize(model, model_policy)
         return shard_model

-    def enable_spec_dec(self, drafter_model: nn.Module = None, n_spec_tokens: int = None) -> None:
+    def enable_spec_dec(
+        self,
+        drafter_model: nn.Module = None,
+        n_spec_tokens: int = None,
+        use_glide_drafter: bool = False,
+    ) -> None:
         """Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.

         Args:
@@ -237,6 +243,8 @@ class InferenceEngine:
                 If provided, the previous drafter and drafter model, if exist, will be overwritten.
             n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
                 If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
+            use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
+                If True, the drafter model will be replaced by a glide model.

         ```python
         ...
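The docstring's own usage example is elided in this hunk, so here is a minimal sketch of how the new flag is meant to be used. The checkpoint paths, tokenizer setup, and the exact `InferenceEngine` constructor/`generate` arguments are illustrative assumptions, not part of this diff; in real use the drafter should be a GLIDE-structured model (added elsewhere in this PR), otherwise the engine warns and falls back to plain speculative decoding.

```python
# Minimal usage sketch, assuming a LLaMA main model; paths and constructor/generate
# arguments are illustrative assumptions, not copied from this diff.
import torch
from transformers import AutoTokenizer, LlamaForCausalLM

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

main_model = LlamaForCausalLM.from_pretrained("path/to/main-llama", torch_dtype=torch.float16)
# Placeholder: a real GLIDE drafter must expose `model.layers[*].cross_attn`;
# a plain LlamaForCausalLM does not, so the engine would fall back to non-GLIDE drafting.
drafter_model = LlamaForCausalLM.from_pretrained("path/to/drafter-llama", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained("path/to/main-llama")

engine = InferenceEngine(main_model, tokenizer, InferenceConfig(max_n_spec_tokens=5))

# With use_glide_drafter=True, the engine also feeds the main model's last-layer
# KV cache to the drafter on each speculation round (if the drafter is compatible).
engine.enable_spec_dec(drafter_model=drafter_model, n_spec_tokens=5, use_glide_drafter=True)
out = engine.generate(prompts=["What is speculative decoding?"])
engine.disable_spec_dec()
```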
@@ -269,6 +277,22 @@ class InferenceEngine:
                 device=self.device,
                 dtype=self.dtype,
             )
+
+        # check if the provided drafter model is compatible with GLIDE structure
+        # when `use_glide_drafter` is set to True
+        if (
+            use_glide_drafter
+            and hasattr(drafter_model, "model")
+            and hasattr(drafter_model.model, "layers")
+            and hasattr(drafter_model.model.layers[0], "cross_attn")
+        ):
+            self.use_glide = use_glide_drafter
+        elif use_glide_drafter:
+            self.logger.warning(
+                f"`use_glide_drafter` is provided as {use_glide_drafter}, "
+                f"but the provided drafter model is not compatible with GLIDE structure."
+                f"Falling back to use the default drafter model (non-GLIDE)."
+            )
         self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
         # using speculative decoding for subsequent generations
         self.use_spec_dec = True
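The compatibility test above is pure duck typing: a drafter counts as GLIDE-structured if its first decoder layer exposes a cross-attention module over the main model's KV cache. Restated as a standalone helper for clarity (illustrative only, not an engine API):

```python
import torch.nn as nn


def is_glide_compatible(drafter_model: nn.Module) -> bool:
    """Mirror of the engine's check (plus a guard against empty layer lists):
    the drafter must expose `model.layers[0].cross_attn`."""
    return (
        hasattr(drafter_model, "model")
        and hasattr(drafter_model.model, "layers")
        and len(drafter_model.model.layers) > 0
        and hasattr(drafter_model.model.layers[0], "cross_attn")
    )
```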
@@ -278,6 +302,7 @@ class InferenceEngine:
         self.request_handler.unset_spec_dec_mode()
         # set back to the maximum number of tokens to speculate
         self.n_spec_tokens = self.inference_config.max_n_spec_tokens
+        self.use_glide = False
         self.use_spec_dec = False

     def clear_spec_dec(self) -> None:
@@ -288,6 +313,7 @@ class InferenceEngine:
             self.drafter_model = None
             self.drafter = None
             torch.cuda.empty_cache()
+        self.use_glide = False
         self.use_spec_dec = False

     def steps_spec_dec(self) -> List[Sequence]:
@@ -304,6 +330,7 @@ class InferenceEngine:
         input_ids = batch.get_1D_inputs()  # bsz 1 for drafter model

         # 1. Prefill small model (Drafter) - fill past kv cache for drafter model
+        # NOTE For glide drafter models, we won't actually apply glide during prefill stage
         drafter_out = self.drafter.speculate(input_ids, 1, None)
         next_token_ids_spec = drafter_out.next_tokens
         drafter_past_key_values = drafter_out.past_key_values
@@ -326,7 +353,21 @@ class InferenceEngine:
             assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."

             # 3. Decoding - Drafter model speculates `n` tokens
-            drafter_out = self.drafter.speculate(input_ids, self.n_spec_tokens, drafter_past_key_values)
+            glide_input = None
+            if self.use_glide:
+                glide_input = GlideInput(
+                    batch.get_block_table_tensor(),
+                    self.k_cahce[-1],  # use kv caches of the last layer
+                    self.v_cache[-1],
+                    batch.get_sequence_lengths(),
+                )
+
+            drafter_out = self.drafter.speculate(
+                input_ids,
+                self.n_spec_tokens,
+                drafter_past_key_values,
+                glide_input=glide_input,
+            )
             next_token_ids_spec = drafter_out.next_tokens
             drafter_past_key_values = drafter_out.past_key_values
             drafter_spec_length = drafter_out.speculated_length
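`GlideInput` is what lets the drafter "glimpse" the verified model: it bundles the block tables and last-layer paged K/V cache of the main model together with the current sequence lengths. The field names below are assumptions inferred from the positional call above; the authoritative definition lives in `colossalai.inference.spec`.

```python
from dataclasses import dataclass

import torch


# Sketch of the payload handed to the drafter when self.use_glide is True.
# Field names are assumed from the positional arguments above, not copied from the source.
@dataclass
class GlideInputSketch:
    block_tables: torch.Tensor      # block table of the main model's paged KV cache
    large_k_cache: torch.Tensor     # K cache of the main model's last decoder layer
    large_v_cache: torch.Tensor     # V cache of the main model's last decoder layer
    sequence_lengths: torch.Tensor  # current length of each sequence in the batch
```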
@@ -339,6 +380,8 @@ class InferenceEngine:
                 already_allocated_kv_len = cur_length

             # 4. Decoding - Main model verifies `n` tokens in parallel
+            if drafter_spec_length < batch.num_tokens_to_verify:
+                batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
             logits = self.model(batch, self.k_cahce, self.v_cache)
             next_tokens = self.request_handler.search_tokens(self.generation_config, logits)

@@ -348,6 +391,7 @@ class InferenceEngine:

             # revoke appended tokens for each Sequence in the current batch
             batch.revoke_batch_tokens(drafter_spec_length - n_matches)  # revoke drafted tokens
+
             # append the last correct token generated by the main model
             self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))

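`n_matches` (computed outside this hunk) is the number of leading drafted tokens that the main model reproduces; the rejected tail is revoked and the main model's own token at that position is appended instead. A self-contained sketch of that greedy acceptance rule (not the engine's exact code):

```python
# Illustrative only: how a greedy verification step can count accepted drafted tokens.
import torch


def count_matches(drafted_tokens: torch.Tensor, verified_tokens: torch.Tensor) -> int:
    """Number of leading drafted tokens that agree with the main model's outputs.

    drafted_tokens:  [n_spec] tokens proposed by the drafter
    verified_tokens: [n_spec + 1] tokens chosen by the main model (last one is the bonus token)
    """
    n_spec = drafted_tokens.size(0)
    mismatch = (drafted_tokens != verified_tokens[:n_spec]).nonzero()
    return n_spec if mismatch.numel() == 0 else int(mismatch[0])


# Example: 3 of 5 drafted tokens accepted -> revoke 2, append verified_tokens[3].
drafted = torch.tensor([11, 12, 13, 99, 98])
verified = torch.tensor([11, 12, 13, 14, 15, 16])
assert count_matches(drafted, verified) == 3
```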
@@ -355,6 +399,7 @@ class InferenceEngine:
             drafter_past_key_values = Drafter.trim_kv_cache(
                 drafter_past_key_values, drafter_spec_length - n_matches - 1
             )
+
             # prepare inputs for the next round of speculation
             n = 1 if n_matches < drafter_spec_length else 2
             input_ids = batch.get_1D_inputs_spec_dec(n)
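`Drafter.trim_kv_cache` drops the rejected tail from the drafter's own KV cache so the next speculation round resumes from the last accepted position. A sketch of such a trim for a HuggingFace-style legacy `past_key_values` tuple (the real static method may differ in layout and edge-case handling):

```python
from typing import Tuple

import torch


def trim_kv_cache_sketch(
    past_key_values: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], n_trim: int
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
    """Drop the last `n_trim` positions from each layer's (key, value) pair.

    Tensors are assumed to be laid out as [batch, num_heads, seq_len, head_dim];
    non-positive `n_trim` leaves the cache untouched.
    """
    if n_trim <= 0:
        return past_key_values
    return tuple((k[:, :, :-n_trim, :], v[:, :, :-n_trim, :]) for k, v in past_key_values)
```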
@@ -364,6 +409,11 @@ class InferenceEngine:
             if len(finished_sequences) > 0:
                 break

+        # Reset back the number of speculated tokens of the batch,
+        # this is used to handle the last round of speculation, in which case the number of speculated tokens
+        # by the drafter is less than the number of speculated tokens set to the engine.
+        batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)
+
         return finished_sequences

     def generate(