Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-11 05:49:55 +00:00)
Fixed a bug in the inference frame
commit 62fd08ee44
parent 86853a37d5
committed by FrankLeeeee
@@ -1,6 +1,6 @@
 import enum
 from dataclasses import dataclass
-from typing import Any, List, Union
+from typing import Any, List, Tuple, Union

 import torch
 from ordered_set import OrderedSet
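Note: the added Tuple import appears to be part of the fix itself. get_1D_inputs below annotates its return type as Tuple[torch.LongTensor, torch.Tensor], and Python evaluates return annotations when the def executes, so the module could not even be imported without it. A minimal sketch of the mechanism, using a hypothetical function:

    # With the import present the annotation evaluates fine; dropping Tuple
    # from the import line would make the `def` itself raise NameError,
    # because Python evaluates return annotations at definition time.
    from typing import Tuple

    def get_pair() -> Tuple[int, int]:  # hypothetical function, same mechanism
        return 1, 2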
@@ -74,13 +74,6 @@ class Sequence:
         self.output_token_id = []
         self.status = RequestStatus.WAITING

-    @property
-    def prompt_len(self) -> int:
-        """
-        Get length of prompts
-        """
-        return len(self.input_token_id)
-
     @property
     def sentence_len(self) -> int:
         """
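Note: the remaining length accessors on Sequence are plain @property members, so callers read seq.sentence_len as an attribute rather than calling a method (the final hunk below fixes exactly such a call). A minimal sketch of the pattern, with hypothetical field values and an assumed definition of sentence_len:

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class MiniSequence:
        input_token_id: List[int] = field(default_factory=list)
        output_token_id: List[int] = field(default_factory=list)

        @property
        def sentence_len(self) -> int:
            # assumed: total length = prompt tokens + generated tokens
            return len(self.input_token_id) + len(self.output_token_id)

    seq = MiniSequence([1, 2, 3], [4])
    assert seq.sentence_len == 4  # attribute access, not seq.get_sentence_len()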
@@ -113,7 +106,7 @@ class Sequence:
             return True

         if self.output_token_id:
-            if self.output_token_id[-1] == self.eos_token_id or len(self.output_token_id) == self.max_output_len:
+            if self.output_token_id[-1] >= self.eos_token_id or len(self.output_token_id) == self.max_output_len:
                 self.status = RequestStatus.COMPLETED
                 return True

@@ -143,11 +136,13 @@ class Sequence:

     def __repr__(self) -> str:
         return (
-            f"Request ID(request_id={self.request_id}, "
+            f"(request_id={self.request_id}, "
             f"prompt={self.prompt}, "
             f"status={self.status.name}, "
             f"sample_params={self.sample_params}, "
-            f"logical block number={len(self.block_table_index)}"
+            f"logical_block_number={self.block_table.shape[0]},"
+            f"input_len={self.input_len}),"
+            f"output_len={self.output_len})"
         )


@@ -159,9 +154,15 @@ class BatchInfo:

     sequences_set: OrderedSet["Sequence"] = None
     is_prompts: bool = True
+    device: torch.device = None

-    @classmethod
-    def init_batch(cls, seqs: List["Sequence"] = None) -> "BatchInfo":
+    def __post_init__(self):
+        if self.device is None:
+            self.device = torch.cuda.current_device()
+        if self.sequences_set is None:
+            self.sequences_set = OrderedSet()
+
+    def init_batch(self, seqs: List["Sequence"] = None):
         """
         Initializes inference batches by input sentence list.

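Note: replacing the init_batch classmethod constructor with __post_init__ moves default handling into the dataclass itself. The generated __init__ runs first, then __post_init__ fills in values that can only be decided at runtime (the current CUDA device) or that must be fresh per instance (the OrderedSet). A standalone sketch of the pattern, with a string device standing in for torch.cuda.current_device():

    from dataclasses import dataclass
    from typing import Optional, Set

    @dataclass
    class MiniBatch:
        device: Optional[str] = None
        sequences: Optional[Set[int]] = None  # a bare set() default would be rejected by dataclasses

        def __post_init__(self):
            # runs right after the generated __init__
            if self.device is None:
                self.device = "cuda:0"  # stand-in for torch.cuda.current_device()
            if self.sequences is None:
                self.sequences = set()  # fresh container per instance

    b = MiniBatch()
    assert b.device == "cuda:0" and b.sequences == set()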
@@ -169,29 +170,29 @@ class BatchInfo:
             seqs (List["Sequence"]): List of input sequence.
         """

-        sequences_set = OrderedSet()
+        assert len(self.sequences_set) == 0, "Sequences set has been initialized."

         if seqs is not None:
             if not isinstance(seqs, list):
                 seqs = [seqs]
             for seq in seqs:
-                if seq in sequences_set:
+                if seq in self.sequences_set:
                     logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
                     continue

-                sequences_set.add(seq)
-
-        return cls(sequences_set=sequences_set)
+                self.sequences_set.add(seq)

     def get_block_table_tensor(self) -> None:
         tesnor_list = []
         block_table = None
         for seq in self.sequences_set:
             block_table = seq.block_table
-            assert block_table, f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
+            assert (
+                block_table is not None
+            ), f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
             tesnor_list.append(seq.block_table)
-        block_table = torch.concat(tesnor_list)
+        assert tesnor_list, "Batch has not been initialized yet. Please initialize batch first."
+        block_table = torch.stack(tesnor_list)
         return block_table

     def clear_batch(self) -> None:
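Note: the switch from torch.concat to torch.stack is the substantive change here. With one 1-D block table per sequence, concat splices them into a single flat vector with no per-sequence boundaries, while stack adds a batch dimension and yields one row per sequence, consistent with the new __repr__ above reading self.block_table.shape[0]. A small demonstration, assuming equal-length 1-D tables:

    import torch

    t1 = torch.tensor([0, 1, 2])  # block table of sequence 1
    t2 = torch.tensor([3, 4, 5])  # block table of sequence 2

    flat = torch.concat([t1, t2])  # shape (6,): boundaries between sequences are lost
    grid = torch.stack([t1, t2])   # shape (2, 3): one row per sequence

    assert flat.shape == (6,)
    assert grid.shape == (2, 3)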
@@ -239,7 +240,7 @@ class BatchInfo:
                 seqs = [seqs]

         for seq in seqs:
-            if seq in self.sequences_set:
+            if self.sequences_set and seq in self.sequences_set:
                 logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
                 continue
             self.sequences_set.add(seq)
@@ -251,7 +252,7 @@ class BatchInfo:
         """
         return not self.sequences_set

-    def update_batch_tokens(self, tokens: Union[List[int], List[List[int]]]) -> None:
+    def update_batch_tokens(self, tokens: Union[List[int], List[List[int]], torch.Tensor]) -> None:
         """
         Add an output token for each sentence in the batch.

@@ -259,6 +260,9 @@ class BatchInfo:
             tokens (List[int]): A batch of tokens
         """

+        if isinstance(tokens, torch.Tensor):
+            tokens = tokens.tolist()
+
         assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size."

         for seq, token in zip(self.sequences_set, tokens):
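Note: accepting a torch.Tensor here lets the sampler's raw output be passed in directly; converting with .tolist() up front means the subsequent zip appends plain Python ints to each sequence's output_token_id instead of 0-dimensional tensors. A small sketch of the normalization:

    import torch

    tokens = torch.tensor([5, 9, 2])  # e.g. one sampled token id per sequence
    if isinstance(tokens, torch.Tensor):
        tokens = tokens.tolist()      # -> [5, 9, 2], plain ints

    assert tokens == [5, 9, 2]
    assert all(isinstance(t, int) for t in tokens)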
@@ -287,19 +291,25 @@ class BatchInfo:
             else:
                 input_list.append([seq.output_token_id[-1]])

-        return torch.tensor(input_list, dtype=torch.long)
+        return torch.tensor(input_list, dtype=torch.long, device=self.device)

     def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]:
         """
         Flattening the input tokens.
         """
         input_list = []
+        input_len_list = []
         for seq in self.sequences_set:
             if self.is_prompts:
                 input_list.extend(seq.input_token_id)
+                input_len_list.append(seq.sentence_len)
             else:
                 input_list.append(seq.output_token_id[-1])
-        return torch.tensor(input_list, dtype=torch.long)
+                input_len_list.append(1)
+
+        return torch.tensor(input_list, dtype=torch.long, device=self.device), torch.tensor(
+            input_len_list, dtype=torch.int, device=device
+        )

     def get_sequence_lengths(self):
         """
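Note: get_1D_inputs now packs a whole batch into the flattened layout that continuous-batching attention kernels typically consume: one contiguous stream of token ids plus a per-sequence length tensor (whole prompts during prefill, a single token per sequence during decode). A standalone sketch of the prefill case, with hypothetical prompts:

    import torch

    prompts = [[11, 12, 13], [21, 22]]  # two sequences of different lengths

    input_list, input_len_list = [], []
    for p in prompts:
        input_list.extend(p)           # flatten into one token stream
        input_len_list.append(len(p))  # record where each sequence ends

    tokens = torch.tensor(input_list, dtype=torch.long)      # tensor([11, 12, 13, 21, 22])
    lengths = torch.tensor(input_len_list, dtype=torch.int)  # tensor([3, 2])

    assert tokens.shape == (5,) and lengths.tolist() == [3, 2]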
@@ -307,5 +317,9 @@ class BatchInfo:
         """
         len_list = []
         for seq in self.sequences_set:
-            len_list.append(seq.get_sentence_len())
-        return torch.tensor(len_list, dtype=torch.int)
+            len_list.append(seq.sentence_len)
+
+        return torch.tensor(len_list, dtype=torch.int, device=self.device)
+
+    def __repr__(self) -> str:
+        return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"