Mirror of https://github.com/hpcaitech/ColossalAI.git
[Fix] Fix Inference Example, Tests, and Requirements (#5688)
* clean requirements
* modify example inference struct
* add test ci scripts
* mark test_infer as submodule
* rm deprecated cls & deps
* import of HAS_FLASH_ATTN
* prune inference tests to be run
* prune triton kernel tests
* increment pytest timeout mins (see the sketch below)
* revert import path in openmoe
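The CI changes themselves are not reproduced on this page. As a hedged sketch only (the actual mechanism in the repo's workflow files may differ), a per-test time budget can be raised with the pytest-timeout plugin; note the marker takes seconds, so minutes must be converted:

import pytest

# Assumes the pytest-timeout plugin is installed (pip install pytest-timeout).
# Hypothetical test name; the real inference tests live under tests/test_infer.
@pytest.mark.timeout(20 * 60)  # raise the budget to 20 minutes
def test_inference_smoke():
    assert True  # placeholder body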
@@ -1,11 +1,7 @@
 import enum
 from dataclasses import dataclass
-from typing import Any, List, Tuple, Union
+from typing import Any, List
 
-import torch
-from ordered_set import OrderedSet
-
-from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.logging import get_dist_logger
 
 logger = get_dist_logger(__name__)
@@ -170,242 +166,6 @@ class Sequence:
         )
-
-
-@dataclass
-class BatchInfo:
-    """
-    Information to be passed and used for a batch of sequences.
-    """
-
-    max_batch_size: int
-    kv_max_split_num: int
-    num_heads: int
-    head_dim: int
-    sequences_set: OrderedSet[Sequence] = None
-    is_prompts: bool = True
-    device: torch.device = None
-    dtype: torch.dtype = None
-    fd_inter_tensor: FDIntermTensors = None
-
-    def __post_init__(self):
-        if self.device is None:
-            self.device = torch.cuda.current_device()
-        if self.sequences_set is None:
-            self.sequences_set = OrderedSet()
-        if self.fd_inter_tensor is None:
-            self.fd_inter_tensor = FDIntermTensors()
-
-    def init_fd_tensors(self):
-        if not self.fd_inter_tensor.is_initialized:
-            self.fd_inter_tensor.initialize(
-                max_batch_size=self.max_batch_size,
-                num_attn_heads=self.num_heads,
-                kv_max_split_num=self.kv_max_split_num,
-                head_dim=self.head_dim,
-                dtype=self.dtype,
-                device=self.device,
-            )
-
-    def get_block_table_tensor(self) -> torch.Tensor:
-        tensor_list = []
-
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        for seq in self.sequences_set:
-            block_table = seq.block_table
-            assert (
-                block_table is not None
-            ), f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
-            tensor_list.append(seq.block_table)
-
-        return torch.stack(tensor_list)
-
-    def clear_batch(self) -> None:
-        """
-        Clear the sequence set and block tables if we need to abort this batch.
-        Prefill: clear the sequence set; the sequences are moved to the running batch externally.
-        Decoding: mark unfinished sequences as aborted, then clear the set.
-        """
-        if self.is_prompts:
-            self.sequences_set.clear()
-        else:
-            for seq in self.sequences_set:
-                seq.mark_aborted()
-                if seq.check_finish():
-                    seq.mark_finished()
-
-            self.sequences_set.clear()
-
-    def filter_batch(self) -> List["Sequence"]:
-        """
-        Remove completed sequences from the batch.
-
-        Returns:
-            List["Sequence"]: The list of finished sequences.
-        """
-        finish_seqs = []
-        for seq in self.sequences_set:
-            if seq.check_finish():
-                finish_seqs.append(seq)
-        for finish_seq in finish_seqs:
-            self.sequences_set.discard(finish_seq)
-        return finish_seqs
-
-    def abort_seq(self, seq: "Sequence") -> "Sequence":
-        """
-        Remove a sequence from the batch, marking it aborted if unfinished.
-        """
-        if not seq.check_finish():
-            seq.status = RequestStatus.ABORTED
-        self.sequences_set.discard(seq)
-        return seq
-
-    def add_seqs(self, seqs: Union[Sequence, List[Sequence]]) -> None:
-        """
-        Add new sequences to the batch.
-
-        Args:
-            seqs (Union[Sequence, List[Sequence]]): A new sequence or a list of new sequences.
-        """
-        # convert a single sequence to a list
-        if isinstance(seqs, Sequence):
-            seqs = [seqs]
-
-        for seq in seqs:
-            if seq in self.sequences_set:
-                logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
-                continue
-            self.sequences_set.add(seq)
-
-    def del_seq(self, seq: Sequence) -> None:
-        """
-        Delete a sequence from the batch.
-        """
-        self.sequences_set.discard(seq)
-
-    @property
-    def is_empty(self) -> bool:
-        """
-        Check whether sequences_set is empty.
-        """
-        return not self.sequences_set
-
-    def update_batch_tokens(self, tokens: Union[List[int], List[List[int]], torch.Tensor]) -> None:
-        """
-        Add an output token for each sequence in the batch.
-
-        Args:
-            tokens (Union[List[int], List[List[int]], torch.Tensor]): A batch of tokens.
-        """
-        if isinstance(tokens, torch.Tensor):
-            tokens = tokens.tolist()
-
-        assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size."
-
-        for seq, token in zip(self.sequences_set, tokens):
-            if not isinstance(token, list):
-                if not isinstance(token, int):
-                    raise TypeError(f"The token type must be List[int] or int, but got {type(token)}.")
-                token = [token]
-            seq.output_token_id += token
-            seq.check_finish()
-
-    def get_batch_size(self) -> int:
-        """
-        Get the batch size of this batch.
-        """
-        return len(self.sequences_set)
-
-    def get_batch_inputs(self) -> torch.LongTensor:
-        """
-        Get batch inputs for the forward inference computation.
-        """
-        input_list = []
-
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        for seq in self.sequences_set:
-            if self.is_prompts:
-                if seq.output_len > 0:
-                    input_list.append(seq.input_token_id + seq.output_token_id)
-                else:
-                    input_list.append(seq.input_token_id)
-            else:
-                input_list.append([seq.output_token_id[-1]])
-
-        max_seq_len = max(len(sub_list) for sub_list in input_list)
-
-        # We assume that all sequences in the batch share the same padding id at present.
-        return _make_tensor_with_pad(input_list, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int)
-
-    def get_1D_inputs(self) -> torch.Tensor:
-        """
-        Flatten the input tokens into a single 1D tensor.
-        """
-        input_list = []
-
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        for seq in self.sequences_set:
-            if self.is_prompts:
-                input_list.extend(seq.input_token_id)
-            else:
-                input_list.append(seq.output_token_id[-1])
-
-        return torch.tensor(input_list, dtype=torch.long, device=self.device)
-
-    def get_sequence_lengths(self) -> torch.Tensor:
-        """
-        Get the current length of each sequence in this batch.
-        """
-        len_list = []
-
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        for seq in self.sequences_set:
-            len_list.append(seq.sentence_len)
-
-        return torch.tensor(len_list, dtype=torch.int, device=self.device)
-
-    def get_attn_mask(self) -> torch.Tensor:
-        """
-        Generate and return the attention mask.
-        """
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        past_values = []
-        # We assume that all sequences in the batch share the same padding id at present.
-        padding_id = self.sequences_set[0].pad_token_id
-
-        for seq in self.sequences_set:
-            past_values.append(seq.input_token_id + seq.output_token_id)
-
-        max_seq_len = max(len(sub_list) for sub_list in past_values)
-        attn_mask = _make_tensor_with_pad(past_values, max_seq_len, padding_id, dtype=torch.int, device=self.device)
-
-        return attn_mask.ne(padding_id).long()
-
-    def __repr__(self) -> str:
-        return f"(sequences_set={self.sequences_set}, is_prompts={self.is_prompts})"
-
-
-def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
-    # left-pad x with the pad id up to max_len
-    assert len(x) <= max_len
-    return [pad] * (max_len - len(x)) + x
-
-
-def _make_tensor_with_pad(
-    x: Union[List[List[int]], List[int]],
-    max_len: int,
-    pad: int,
-    dtype: torch.dtype,
-    device: Union[str, torch.device] = "cuda",
-    pin_memory: bool = False,
-):
-    padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
-    # pinning host memory is only valid for CPU tensors
-    return torch.tensor(padded_x, dtype=dtype, device=device, pin_memory=pin_memory and str(device) == "cpu")
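Since the removed padding helpers are self-contained, here is a minimal demonstration of the left-padding and mask construction they provided (a sketch assuming _pad_to_max and _make_tensor_with_pad above are in scope; the token values are made up):

import torch

# Two sequences of unequal length; _pad_to_max left-pads, so shorter
# sequences end up right-aligned in the tensor.
batch_tokens = [[5, 6, 7], [8, 9]]
max_len = max(len(seq) for seq in batch_tokens)

padded = _make_tensor_with_pad(batch_tokens, max_len, pad=0, dtype=torch.int, device="cpu")
# tensor([[5, 6, 7],
#         [0, 8, 9]], dtype=torch.int32)

# The same masking rule BatchInfo.get_attn_mask applied:
attn_mask = padded.ne(0).long()
# tensor([[1, 1, 1],
#         [0, 1, 1]])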
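And a minimal sketch of how the removed BatchInfo was driven across a prefill and a decoding step. _FakeSeq is a hypothetical stand-in exposing only the attributes these methods touch; the real Sequence (defined earlier in struct.py) also carries block_table, status, and length bookkeeping. Assumes the ordered_set package is installed and runs on CPU:

import torch

class _FakeSeq:
    # hypothetical stand-in for colossalai.inference.struct.Sequence
    def __init__(self, request_id, tokens):
        self.request_id = request_id
        self.input_token_id = list(tokens)
        self.output_token_id = []
        self.pad_token_id = 0

    def check_finish(self):
        return False  # never finishes in this toy example

batch = BatchInfo(
    max_batch_size=4,
    kv_max_split_num=16,
    num_heads=8,
    head_dim=64,
    device=torch.device("cpu"),
    dtype=torch.float16,
    fd_inter_tensor=object(),  # placeholder; the real code uses FDIntermTensors
)
batch.add_seqs([_FakeSeq(0, [1, 2, 3]), _FakeSeq(1, [4, 5])])

print(batch.get_batch_size())  # 2
print(batch.get_1D_inputs())   # tensor([1, 2, 3, 4, 5]) - flattened prompts for prefill

batch.is_prompts = False           # switch from prefill to decoding
batch.update_batch_tokens([7, 8])  # append one generated token per sequence
print(batch.get_1D_inputs())       # tensor([7, 8]) - last token of each sequence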