mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[Inference/SpecDec] Add Basic Drafter Model Container (#5405)
* [Infer/Fix] Fix Dependency in test - RMSNorm kernel (#5399) fix dependency in pytest * add drafter model container (basic ver)
This commit is contained in:
29
colossalai/inference/spec/struct.py
Normal file
29
colossalai/inference/spec/struct.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class DrafterOutput:
|
||||
"""
|
||||
Dataclass for drafter model outputs.
|
||||
|
||||
Args:
|
||||
speculated_length (int): Speculated length of the output sequence
|
||||
It is always less than or equal to spec_num during drafter's speculation process
|
||||
logits (torch.FloatTensor): Logits of the output sequence
|
||||
next_tokens (torch.Tensor): Next token ids
|
||||
past_key_values (Optional[Tuple[Tuple[torch.FloatTensor]]]): Past key values of the output sequence
|
||||
"""
|
||||
|
||||
speculated_length: int = None
|
||||
logits: torch.FloatTensor = None
|
||||
next_tokens: torch.Tensor = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.speculated_length is not None and self.speculated_length >= 0
|
||||
if self.past_key_values is not None:
|
||||
assert isinstance(self.past_key_values, tuple), "Past key values should be a tuple"
|
||||
assert all([isinstance(past_key_value, tuple) for past_key_value in self.past_key_values])
|
Reference in New Issue
Block a user