ColossalAI/colossalai/inference/spec/drafter.py
flybird11111 · ddbbbaab3e · [upgrade]Upgrade transformers (#6320) · 2025-05-27 14:29:01 +08:00

from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers import PreTrainedTokenizer
from transformers.cache_utils import DynamicCache

from colossalai.utils import get_current_device

from .struct import DrafterOutput, GlideInput
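

# High-level flow: `speculate` greedily drafts up to `n_spec_tokens` candidate
# tokens with the (small) drafter model; the main model in the inference engine
# verifies them, and the kv-cache entries of any rejected drafted tokens are
# dropped with `trim_kv_cache` before the next speculation round.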
class Drafter:
    """Container for the Drafter Model (Assistant Model) used in Speculative Decoding.

    Args:
        model (nn.Module): The drafter model.
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the drafter model.
        device (torch.device): The device for the drafter model.
        dtype (torch.dtype): The dtype the drafter model is cast to. Defaults to torch.float16.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer: PreTrainedTokenizer,
        device: torch.device = None,
        dtype: torch.dtype = torch.float16,
    ):
        self._tokenizer = tokenizer
        self._device = device or get_current_device()
        self._dtype = dtype
        self._drafter_model = model.to(device=self._device, dtype=self._dtype)
        self._drafter_model.eval()

    def get_model(self) -> nn.Module:
        return self._drafter_model

    @staticmethod
    def trim_kv_cache(
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], invalid_token_num: int
    ) -> Tuple[Tuple[torch.FloatTensor]]:
        """Trim the last `invalid_token_num` kv caches.

        Args:
            past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values with shape
                num_layers x 2 x (bsz x num_heads x seq_len x head_dim).
            invalid_token_num (int): The number of invalid tokens to trim.
        """
        if past_key_values is None or invalid_token_num < 1:
            return past_key_values

        trimmed_past_key_values = []
        for layer_idx in range(len(past_key_values)):
            past_key_value = past_key_values[layer_idx]
            # Drop the last `invalid_token_num` positions along the sequence dimension
            # of both the key and the value tensors of this layer.
            trimmed_past_key_values.append(
                (
                    past_key_value[0][:, :, :-invalid_token_num, :],
                    past_key_value[1][:, :, :-invalid_token_num, :],
                )
            )
        past_key_values = tuple(trimmed_past_key_values)
        return past_key_values

    @torch.inference_mode()
    def speculate(
        self,
        input_ids: torch.Tensor,
        n_spec_tokens: int,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        glide_input: Optional[GlideInput] = None,
    ) -> DrafterOutput:
        """Generate n_spec_tokens tokens using the drafter model.

        Args:
            input_ids (torch.Tensor): Input token ids.
            n_spec_tokens (int): Number of tokens to speculate.
            past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values of the input sequence.
            glide_input (Optional[GlideInput]): The packed input for glimpsing kv caches of the main model,
                when using the glide model as a drafter.
        """
        assert n_spec_tokens >= 1, f"Invalid number {n_spec_tokens} to speculate"

        # For compatibility with transformers of versions before 4.38.0
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        logits = []
        token_ids = []

        kwargs = {"return_dict": True, "use_cache": True}
        if glide_input:
            # required only when using the glide model
            kwargs["glide_input"] = glide_input

        for _ in range(n_spec_tokens):
            # Forward pass of the drafter model; also updates the past key values.
            outputs = self._drafter_model(input_ids, past_key_values=past_key_values, **kwargs)
            next_token_logits = outputs.logits[:, -1, :]

            # NOTE Only use greedy search for speculating.
            # As the drafter model usually has only a few layers with few parameters,
            # introducing sampling will make the speculation unstable and lead to worse performance.
            next_token_ids = torch.argmax(next_token_logits, dim=-1)

            logits.append(next_token_logits)
            token_ids.append(next_token_ids)
            if next_token_ids.item() == self._tokenizer.eos_token_id:
                # TODO(yuanheng-zhao) support bsz > 1
                break

            input_ids = next_token_ids[:, None]
            past_key_values = outputs.past_key_values

        speculated_length = len(token_ids)  # For now, only support bsz 1
        logits = torch.concat(logits, dim=0)
        token_ids = torch.concat(token_ids, dim=-1)
        # Newer transformers versions return a DynamicCache; convert it back to the legacy tuple format.
        if isinstance(past_key_values, DynamicCache):
            past_key_values = past_key_values.to_legacy_cache()

        out = DrafterOutput(
            speculated_length=speculated_length, logits=logits, next_tokens=token_ids, past_key_values=past_key_values
        )
        return out
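

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): how a
# Drafter might be driven stand-alone. The checkpoint name below is an
# assumption; any small Hugging Face causal LM sharing the main model's
# tokenizer would play the same role. Verification of the drafted tokens is
# done by the main model inside the inference engine and is not shown here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    drafter_model_name = "JackFram/llama-68m"  # assumed small drafter checkpoint
    tokenizer = AutoTokenizer.from_pretrained(drafter_model_name)
    model = AutoModelForCausalLM.from_pretrained(drafter_model_name)

    # Use float32 here so the sketch also runs on CPU-only machines.
    drafter = Drafter(model, tokenizer, dtype=torch.float32)

    input_ids = tokenizer("Speculative decoding lets a small model", return_tensors="pt").input_ids
    input_ids = input_ids.to(get_current_device())

    # Greedily draft up to 5 candidate tokens.
    out = drafter.speculate(input_ids, n_spec_tokens=5)
    print("speculated length:", out.speculated_length)
    print("drafted tokens:", tokenizer.decode(out.next_tokens))

    # Suppose the main model rejected the last 2 drafted tokens: drop their
    # entries from the drafter's kv cache before the next speculation round.
    trimmed_kv = Drafter.trim_kv_cache(out.past_key_values, invalid_token_num=2)
    print("kv length after trim:", trimmed_kv[0][0].shape[2])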