Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-12 12:47:21 +00:00)
[Inference/SpecDec] Support GLIDE Drafter Model (#5455)
* add glide-llama policy and modeling
* update glide modeling, compatible with transformers 4.36.2
* revise glide llama modeling/usage
* fix issues of glimpsing large kv
* revise the way of re-loading params for the glide drafter
* fix drafter and engine tests
* enable converting to glide with strict=False
* revise glide llama modeling
* revise vicuna prompt template
* revise drafter and tests
* apply usage of glide model in engine
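
One bullet above, "enable converting to glide with strict=False", leans on standard PyTorch state-dict loading: with strict=False, weights shared with the base model are copied over, while glide-only modules (the drafter's extra cross-attention) keep their fresh initialization instead of raising on missing keys. A self-contained toy sketch of that mechanism (illustrative classes, not the PR's actual glide-llama code):

import torch.nn as nn

class Base(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(8, 8)  # weight shared with the drafter

class GlideLike(Base):
    def __init__(self):
        super().__init__()
        self.cross_attn = nn.Linear(8, 8)  # glide-only module, absent from the base checkpoint

base, drafter = Base(), GlideLike()
# strict=False copies the shared weights and merely reports, rather than
# raises on, the drafter-only keys the base state dict cannot fill.
result = drafter.load_state_dict(base.state_dict(), strict=False)
print(result.missing_keys)  # ['cross_attn.weight', 'cross_attn.bias']
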
__init__.py
@@ -1,4 +1,4 @@
 from .drafter import Drafter
-from .struct import DrafterOutput
+from .struct import DrafterOutput, GlideInput
 
-__all__ = ["Drafter", "DrafterOutput"]
+__all__ = ["Drafter", "DrafterOutput", "GlideInput"]

drafter.py
@@ -6,7 +6,7 @@ from transformers import PreTrainedTokenizer
 
 from colossalai.utils import get_current_device
 
-from .struct import DrafterOutput
+from .struct import DrafterOutput, GlideInput
 
 
 class Drafter:

@@ -66,6 +66,7 @@ class Drafter:
         input_ids: torch.Tensor,
         n_spec_tokens: int,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        glide_input: Optional[GlideInput] = None,
     ) -> DrafterOutput:
         """Generate n_spec_tokens tokens using the drafter model.
 
@@ -73,6 +74,8 @@ class Drafter:
             input_ids (torch.Tensor): Input token ids.
             n_spec_tokens (int): Number of tokens to speculate.
             past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values of the input sequence.
+            glide_input (Optional[GlideInput]): The packed input for glimpsing kv caches of the main model,
+                when using the glide model as a drafter.
         """
         assert n_spec_tokens >= 1, f"Invalid number {n_spec_tokens} to speculate"
 
@@ -83,13 +86,16 @@ class Drafter:
         logits = []
         token_ids = []
 
+        kwargs = {"return_dict": True, "use_cache": True}
+        if glide_input:
+            # required only when using glide model
+            kwargs["glide_input"] = glide_input
+
         for _ in range(n_spec_tokens):
-            outputs = self._drafter_model(
-                input_ids,
-                return_dict=True,
-                use_cache=True,
-                past_key_values=past_key_values,
-            )
+            # update past key values
+            kwargs["past_key_values"] = past_key_values
+
+            outputs = self._drafter_model(input_ids, **kwargs)
             next_token_logits = outputs.logits[:, -1, :]
 
             # NOTE Only use greedy search for speculating.
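
(The hunk cuts off at the NOTE above; the greedy pick it refers to is presumably a plain argmax over the last-position logits. A one-line sketch, assuming the usual [bsz, vocab_size] logits shape:)

next_token_ids = torch.argmax(next_token_logits, dim=-1)  # assumed greedy step, not shown in this hunk
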
@@ -100,12 +106,12 @@ class Drafter:
             logits.append(next_token_logits)
             token_ids.append(next_token_ids)
             if next_token_ids.item() == self._tokenizer.eos_token_id:
-                # TODO support bsz > 1
+                # TODO(yuanheng-zhao) support bsz > 1
                 break
             input_ids = next_token_ids[:, None]
             past_key_values = outputs.past_key_values
 
-        speculated_length = len(token_ids)  # TODO For now, only support bsz 1
+        speculated_length = len(token_ids)  # For now, only support bsz 1
         logits = torch.concat(logits, dim=0)
         token_ids = torch.concat(token_ids, dim=-1)
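
Taken together, the drafter now threads an optional GlideInput through each forward pass of the drafter model. A minimal usage sketch under stated assumptions (`drafter` and `input_ids` are pre-built placeholders; the tensor shapes are illustrative dummies):

import torch

# Illustrative shapes: 1 sequence, 4 blocks of 16 slots, 8 kv heads, head size 64.
glide_input = GlideInput(
    block_tables=torch.zeros(1, 4, dtype=torch.int32),
    large_k_cache=torch.zeros(4, 8, 16, 64),
    large_v_cache=torch.zeros(4, 8, 16, 64),
    sequence_lengths=torch.tensor([10]),
)

# bsz 1 only for now; omitting glide_input falls back to the plain drafter path.
draft_out = drafter.speculate(
    input_ids,
    n_spec_tokens=5,
    past_key_values=None,
    glide_input=glide_input,
)
# draft_out is a DrafterOutput bundling the speculated token ids, their
# logits, and the drafter's updated past_key_values.
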
struct.py
@@ -27,3 +27,29 @@ class DrafterOutput:
         if self.past_key_values is not None:
             assert isinstance(self.past_key_values, tuple), "Past key values should be a tuple"
             assert all([isinstance(past_key_value, tuple) for past_key_value in self.past_key_values])
+
+
+@dataclass
+class GlideInput:
+    """Dataclass for Glide models (e.g. `colossalai/inference/modeling/models/glide_llama.py`).
+    Used to pack data that will be used during glimpsing the KV caches of the main model.
+
+    Args:
+        block_tables (torch.Tensor): [num_seqs, max_blocks_per_seq] The block tables of the KV caches.
+        large_k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_size]
+            Blocked key cache of the main model.
+        large_v_cache (torch.Tensor): Blocked value cache of the main model, with the same shape as the key cache.
+        sequence_lengths (torch.Tensor): [num_seqs] Sequence lengths of the current batch.
+    """
+
+    block_tables: torch.Tensor = None
+    large_k_cache: torch.Tensor = None
+    large_v_cache: torch.Tensor = None
+    sequence_lengths: torch.Tensor = None
+
+    @property
+    def glimpse_ready(self):
+        return all(
+            attr is not None
+            for attr in [self.block_tables, self.large_k_cache, self.large_v_cache, self.sequence_lengths]
+        )
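
glimpse_ready simply gates glimpsing on all four tensors being populated, which lets callers build a GlideInput incrementally. For instance (dummy shapes; the import path is an assumption based on the __init__.py hunk above):

import torch
from colossalai.inference.spec import GlideInput  # assumed package path

gi = GlideInput(block_tables=torch.zeros(1, 4, dtype=torch.int32))
print(gi.glimpse_ready)  # False: caches and sequence lengths still unset

gi.large_k_cache = torch.zeros(4, 8, 16, 64)  # [num_blocks, num_kv_heads, block_size, head_size]
gi.large_v_cache = torch.zeros_like(gi.large_k_cache)
gi.sequence_lengths = torch.tensor([10])
print(gi.glimpse_ready)  # True: safe to pass as glide_input to Drafter.speculate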