[Inference/SpecDec] Support GLIDE Drafter Model (#5455)

* add glide-llama policy and modeling

* update glide modeling, compatible with transformers 4.36.2

* revise glide llama modeling/usage

* fix issues of glimpsing large kv

* revise the way of re-loading params for the glide drafter

* fix drafter and engine tests

* enable converting to glide with strict=False

* revise glide llama modeling

* revise vicuna prompt template

* revise drafter and tests

* apply usage of glide model in engine
Author: Yuanheng Zhao
Date: 2024-04-01 21:54:24 +08:00
Committed by: Yuanheng
Parent: 912e24b2aa
Commit: d85d91435a

10 changed files with 722 additions and 82 deletions
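The bullets above cover two pieces: a GLIDE drafter model (a small LLaMA variant whose added cross-attention layers glimpse the main model's kv cache, per the docstring in the diff below) and the plumbing to re-load base-model weights into it. Below is a minimal, self-contained sketch of the strict=False re-loading mentioned in the commit message; the module names are toy stand-ins, not the actual ColossalAI classes.

import torch.nn as nn

# Toy stand-ins: a "base" drafter block and a "glide" variant that adds a
# cross-attention module for glimpsing the main model's kv cache.
class BaseBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = nn.Linear(16, 16)
        self.mlp = nn.Linear(16, 16)

class GlideBlock(BaseBlock):
    def __init__(self):
        super().__init__()
        self.cross_attn = nn.Linear(16, 16)  # new layer, absent from the base checkpoint

base, glide = BaseBlock(), GlideBlock()
# strict=False lets the shared weights load while the new cross-attn layer
# keeps its fresh initialization; it is reported back as "missing".
missing, unexpected = glide.load_state_dict(base.state_dict(), strict=False)
print(missing)     # ['cross_attn.weight', 'cross_attn.bias']
print(unexpected)  # []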


@@ -6,7 +6,7 @@ from transformers import PreTrainedTokenizer
 from colossalai.utils import get_current_device
-from .struct import DrafterOutput
+from .struct import DrafterOutput, GlideInput
 class Drafter:
@@ -66,6 +66,7 @@ class Drafter:
         input_ids: torch.Tensor,
         n_spec_tokens: int,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        glide_input: Optional[GlideInput] = None,
     ) -> DrafterOutput:
         """Generate n_spec_tokens tokens using the drafter model.
@@ -73,6 +74,8 @@ class Drafter:
             input_ids (torch.Tensor): Input token ids.
             n_spec_tokens (int): Number of tokens to speculate.
             past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values of the input sequence.
+            glide_input (Optional[GlideInput]): The packed input for glimpsing kv caches of the main model,
+                when using the glide model as a drafter.
         """
         assert n_spec_tokens >= 1, f"Invalid number {n_spec_tokens} to speculate"
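For context, the new glide_input argument packs references to the main (large) model's kv cache so the drafter can attend over it. A rough usage sketch follows; the GlideInput fields shown are inferred from this diff and are assumptions, not the exact schema of the imported struct.

from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class GlideInput:  # toy stand-in for the imported .struct.GlideInput
    large_k_cache: Optional[torch.Tensor] = None
    large_v_cache: Optional[torch.Tensor] = None
    sequence_lengths: Optional[torch.Tensor] = None

# pack the main model's kv cache so the glide drafter can glimpse it
glide_input = GlideInput(
    large_k_cache=torch.zeros(1, 8, 128, 64),  # (bsz, num_heads, seq_len, head_dim)
    large_v_cache=torch.zeros(1, 8, 128, 64),
    sequence_lengths=torch.tensor([128]),
)
# then: drafter.speculate(input_ids, n_spec_tokens, glide_input=glide_input)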
@@ -83,13 +86,16 @@ class Drafter:
         logits = []
         token_ids = []
+        kwargs = {"return_dict": True, "use_cache": True}
+        if glide_input:
+            # required only when using glide model
+            kwargs["glide_input"] = glide_input
         for _ in range(n_spec_tokens):
-            outputs = self._drafter_model(
-                input_ids,
-                return_dict=True,
-                use_cache=True,
-                past_key_values=past_key_values,
-            )
+            # update past key values
+            kwargs["past_key_values"] = past_key_values
+            outputs = self._drafter_model(input_ids, **kwargs)
             next_token_logits = outputs.logits[:, -1, :]
             # NOTE Only use greedy search for speculating.
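The greedy step that NOTE refers to reduces to an argmax over the last position's logits; a one-line sketch:

import torch

next_token_logits = torch.randn(1, 32000)                 # (bsz=1, vocab_size)
next_token_ids = torch.argmax(next_token_logits, dim=-1)  # greedy pick, no sampling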
@@ -100,12 +106,12 @@ class Drafter:
             logits.append(next_token_logits)
             token_ids.append(next_token_ids)
             if next_token_ids.item() == self._tokenizer.eos_token_id:
-                # TODO support bsz > 1
+                # TODO(yuanheng-zhao) support bsz > 1
                 break
             input_ids = next_token_ids[:, None]
             past_key_values = outputs.past_key_values
-        speculated_length = len(token_ids)  # TODO For now, only support bsz 1
+        speculated_length = len(token_ids)  # For now, only support bsz 1
         logits = torch.concat(logits, dim=0)
         token_ids = torch.concat(token_ids, dim=-1)
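Putting the hunks together, the loop greedily drafts up to n_spec_tokens tokens and stops early at EOS. A self-contained toy version of that control flow (the dummy model is an assumption standing in for the drafter LLM):

import torch

EOS_ID, VOCAB = 2, 100

def dummy_model(input_ids, **kwargs):
    # stands in for self._drafter_model; returns random logits per position
    bsz, seq_len = input_ids.shape
    return {"logits": torch.randn(bsz, seq_len, VOCAB)}

def toy_speculate(input_ids, n_spec_tokens):
    logits, token_ids = [], []
    for _ in range(n_spec_tokens):
        outputs = dummy_model(input_ids, return_dict=True, use_cache=True)
        next_token_logits = outputs["logits"][:, -1, :]
        next_token_ids = torch.argmax(next_token_logits, dim=-1)  # greedy only
        logits.append(next_token_logits)
        token_ids.append(next_token_ids)
        if next_token_ids.item() == EOS_ID:
            break  # bsz 1 only, matching the TODO above
        input_ids = next_token_ids[:, None]
    return torch.concat(logits, dim=0), torch.concat(token_ids, dim=-1)

logits, tokens = toy_speculate(torch.tensor([[1, 5, 7]]), n_spec_tokens=4)
print(tokens.shape)  # up to (4,) for bsz 1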