Fixed a bug in the inference frame

yuehuayingxueluo
2023-12-26 21:34:27 +08:00
committed by FrankLeeeee
parent 86853a37d5
commit 62fd08ee44
8 changed files with 261 additions and 90 deletions


@@ -1,6 +1,6 @@
 import enum
 from dataclasses import dataclass
-from typing import Any, List, Union
+from typing import Any, List, Tuple, Union

 import torch
 from ordered_set import OrderedSet
@@ -74,13 +74,6 @@ class Sequence:
         self.output_token_id = []
         self.status = RequestStatus.WAITING

-    @property
-    def prompt_len(self) -> int:
-        """
-        Get length of prompts
-        """
-        return len(self.input_token_id)
-
     @property
     def sentence_len(self) -> int:
         """
@@ -113,7 +106,7 @@ class Sequence:
             return True
         if self.output_token_id:
-            if self.output_token_id[-1] == self.eos_token_id or len(self.output_token_id) == self.max_output_len:
+            if self.output_token_id[-1] >= self.eos_token_id or len(self.output_token_id) == self.max_output_len:
                 self.status = RequestStatus.COMPLETED
                 return True
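For reference, a self-contained sketch of the stopping rule this hunk touches, mirroring the new `>=` comparison; the function name and the sample token ids are illustrative, only the field names come from the diff:

from typing import List


def is_finished(output_token_id: List[int], eos_token_id: int, max_output_len: int) -> bool:
    # Nothing generated yet, so nothing to check.
    if not output_token_id:
        return False
    # Last token at or beyond the EOS id, or the length budget exhausted,
    # marks the sequence as completed (as in the hunk above).
    return output_token_id[-1] >= eos_token_id or len(output_token_id) == max_output_len


assert is_finished([5, 9], eos_token_id=9, max_output_len=16)       # EOS reached
assert is_finished([5, 7, 3, 1], eos_token_id=9, max_output_len=4)  # length budget reached
assert not is_finished([5, 7], eos_token_id=9, max_output_len=16)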
@@ -143,11 +136,13 @@ class Sequence:
     def __repr__(self) -> str:
         return (
-            f"Request ID(request_id={self.request_id}, "
+            f"(request_id={self.request_id}, "
             f"prompt={self.prompt}, "
             f"status={self.status.name}, "
             f"sample_params={self.sample_params}, "
-            f"logical block number={len(self.block_table_index)}"
+            f"logical_block_number={self.block_table.shape[0]},"
+            f"input_len={self.input_len}),"
+            f"output_len={self.output_len})"
         )
@@ -159,9 +154,15 @@ class BatchInfo:
     sequences_set: OrderedSet["Sequence"] = None
     is_prompts: bool = True
     device: torch.device = None

-    @classmethod
-    def init_batch(cls, seqs: List["Sequence"] = None) -> "BatchInfo":
+    def __post_init__(self):
+        if self.device is None:
+            self.device = torch.cuda.current_device()
+        if self.sequences_set is None:
+            self.sequences_set = OrderedSet()
+
+    def init_batch(self, seqs: List["Sequence"] = None):
         """
         Initializes inference batches by input sentence list.
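The hunk above replaces the `@classmethod` constructor with dataclass `__post_init__` defaulting, so a bare `BatchInfo()` is already usable. A minimal sketch of the pattern, with the CUDA call guarded so the snippet also runs on a CPU-only machine (the patched code calls `torch.cuda.current_device()` unconditionally); the class name is illustrative:

from dataclasses import dataclass

import torch
from ordered_set import OrderedSet


@dataclass
class TinyBatch:
    # None sentinels instead of mutable defaults; __post_init__ fills them in.
    sequences_set: OrderedSet = None
    is_prompts: bool = True
    device: torch.device = None

    def __post_init__(self):
        if self.device is None:
            self.device = torch.cuda.current_device() if torch.cuda.is_available() else torch.device("cpu")
        if self.sequences_set is None:
            self.sequences_set = OrderedSet()


batch = TinyBatch()
print(batch.device, len(batch.sequences_set))  # e.g. "cpu 0" on a machine without CUDA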
@@ -169,29 +170,29 @@ class BatchInfo:
             seqs (List["Sequence"]): List of input sequence.
         """
-        sequences_set = OrderedSet()
+        assert len(self.sequences_set) == 0, "Sequences set has been initialized."

         if seqs is not None:
             if not isinstance(seqs, list):
                 seqs = [seqs]
             for seq in seqs:
-                if seq in sequences_set:
+                if seq in self.sequences_set:
                     logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
                     continue
-                sequences_set.add(seq)
-        return cls(sequences_set=sequences_set)
+                self.sequences_set.add(seq)

     def get_block_table_tensor(self) -> None:
         tesnor_list = []
         block_table = None
         for seq in self.sequences_set:
             block_table = seq.block_table
-            assert block_table, f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
+            assert (
+                block_table is not None
+            ), f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
             tesnor_list.append(seq.block_table)
         assert tesnor_list, "Batch has not been initialized yet. Please initialize batch first."
-        block_table = torch.concat(tesnor_list)
+        block_table = torch.stack(tesnor_list)
         return block_table

     def clear_batch(self) -> None:
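Swapping `torch.concat` for `torch.stack` changes the result's shape, not just its values: concatenation flattens the per-sequence block tables into one vector, while stacking adds a batch dimension, one row per sequence (stacking also requires every table to have the same length). A quick illustration with made-up block ids:

import torch

# Two sequences, each owning two logical blocks (ids are illustrative).
tables = [torch.tensor([0, 1]), torch.tensor([2, 3])]

print(torch.concat(tables).shape)  # torch.Size([4])    - batch structure lost
print(torch.stack(tables).shape)   # torch.Size([2, 2]) - [batch_size, blocks_per_seq]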
@@ -239,7 +240,7 @@ class BatchInfo:
                 seqs = [seqs]

         for seq in seqs:
-            if seq in self.sequences_set:
+            if self.sequences_set and seq in self.sequences_set:
                 logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
                 continue
             self.sequences_set.add(seq)
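The added `self.sequences_set and` guard short-circuits, so membership is only tested when the set is non-empty (and not None). A minimal demonstration, using strings in place of Sequence objects:

from ordered_set import OrderedSet

sequences_set = None
# `"seq-0" in None` would raise TypeError; the short-circuit never evaluates it.
print(bool(sequences_set and "seq-0" in sequences_set))  # False

sequences_set = OrderedSet(["seq-0"])
print(bool(sequences_set and "seq-0" in sequences_set))  # True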
@@ -251,7 +252,7 @@ class BatchInfo:
         """
         return not self.sequences_set

-    def update_batch_tokens(self, tokens: Union[List[int], List[List[int]]]) -> None:
+    def update_batch_tokens(self, tokens: Union[List[int], List[List[int]], torch.Tensor]) -> None:
         """
         Add an output token for each sentence in the batch.
@@ -259,6 +260,9 @@ class BatchInfo:
             tokens (List[int]): A batch of tokens
         """

+        if isinstance(tokens, torch.Tensor):
+            tokens = tokens.tolist()
+
         assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size."

         for seq, token in zip(self.sequences_set, tokens):
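With the widened `Union`, a sampler can pass its output tensor straight to `update_batch_tokens`; the new `isinstance` branch normalizes it to a plain list before the batch-size check and the per-sequence zip. The conversion in isolation (the function name is illustrative):

from typing import List, Union

import torch


def normalize_tokens(tokens: Union[List[int], List[List[int]], torch.Tensor]) -> list:
    # Mirrors the added branch: tensors become plain Python lists.
    if isinstance(tokens, torch.Tensor):
        tokens = tokens.tolist()
    return tokens


print(normalize_tokens(torch.tensor([3, 5, 8])))  # [3, 5, 8]
print(normalize_tokens([[3], [5], [8]]))          # [[3], [5], [8]]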
@@ -287,19 +291,25 @@ class BatchInfo:
else:
input_list.append([seq.output_token_id[-1]])
return torch.tensor(input_list, dtype=torch.long)
return torch.tensor(input_list, dtype=torch.long, device=self.device)
def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]:
"""
Flattening the input tokens.
"""
input_list = []
input_len_list = []
for seq in self.sequences_set:
if self.is_prompts:
input_list.extend(seq.input_token_id)
input_len_list.append(seq.sentence_len)
else:
input_list.append(seq.output_token_id[-1])
return torch.tensor(input_list, dtype=torch.long)
input_len_list.append(1)
return torch.tensor(input_list, dtype=torch.long, device=self.device), torch.tensor(
input_len_list, dtype=torch.int, device=device
)
def get_sequence_lengths(self):
"""
@@ -307,5 +317,9 @@ class BatchInfo:
         """
         len_list = []
         for seq in self.sequences_set:
-            len_list.append(seq.get_sentence_len())
-        return torch.tensor(len_list, dtype=torch.int)
+            len_list.append(seq.sentence_len)
+
+        return torch.tensor(len_list, dtype=torch.int, device=self.device)
+
+    def __repr__(self) -> str:
+        return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"