[Inference/SpecDec] Add Basic Drafter Model Container (#5405)

* [Infer/Fix] Fix Dependency in test - RMSNorm kernel (#5399)

fix dependency in pytest

* add drafter model container (basic ver)
This commit is contained in:
Yuanheng Zhao
2024-02-28 13:48:17 +08:00
committed by Yuanheng
parent d63c469f45
commit 5a9b05f7b2
4 changed files with 216 additions and 0 deletions

View File

@@ -0,0 +1,29 @@
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
@dataclass
class DrafterOutput:
"""
Dataclass for drafter model outputs.
Args:
speculated_length (int): Speculated length of the output sequence
It is always less than or equal to spec_num during drafter's speculation process
logits (torch.FloatTensor): Logits of the output sequence
next_tokens (torch.Tensor): Next token ids
past_key_values (Optional[Tuple[Tuple[torch.FloatTensor]]]): Past key values of the output sequence
"""
speculated_length: int = None
logits: torch.FloatTensor = None
next_tokens: torch.Tensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
def __post_init__(self):
assert self.speculated_length is not None and self.speculated_length >= 0
if self.past_key_values is not None:
assert isinstance(self.past_key_values, tuple), "Past key values should be a tuple"
assert all([isinstance(past_key_value, tuple) for past_key_value in self.past_key_values])