Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 12:01:39 +00:00)
[inference]Optimize the usage of the mid tensors space in flash attn (#5304)
* opt flash attn
* opt tmp tensor
* fix benchmark_llama
* fix code style
* fix None logic for output tensor
* fix adapted to get_xine_cache
* add comment
* fix ci bugs
* fix some codes
* rm duplicated codes
* rm duplicated codes
* fix code style
* add _get_dtype in config.py
```diff
@@ -5,6 +5,7 @@ from typing import Any, List, Tuple, Union
 import torch
 from ordered_set import OrderedSet
 
+from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.logging import get_dist_logger
 
 logger = get_dist_logger(__name__)
```
```diff
@@ -61,6 +62,7 @@ class Sequence:
         sample_params (SampleParams): The sample_params of input sequence.
         block_table (torch.Tensor): The index of input sequence in block_table.
         eos_token_id (int): The eos token id for this inference process.
+        pad_token_id (int): The pad token id for this inference process.
         max_output_len (int): Maximum output length.
     """
 
@@ -71,6 +73,7 @@ class Sequence:
     sample_params: Any  # SampleParams needs to be imported later.
     block_table: torch.Tensor
     eos_token_id: int
+    pad_token_id: int
     max_output_len: int = 256
 
     def __post_init__(self):
```
```diff
@@ -167,15 +170,23 @@ class BatchInfo:
     Information to be passed and used for a batch of sequences.
     """
 
+    max_batch_size: int
+    kv_max_split_num: int
+    num_heads: int
+    head_dim: int
     sequences_set: OrderedSet[Sequence] = None
     is_prompts: bool = True
     device: torch.device = None
+    dtype: torch.dtype = None
+    fd_inter_tensor: FDIntermTensors = None
 
     def __post_init__(self):
        if self.device is None:
            self.device = torch.cuda.current_device()
        if self.sequences_set is None:
            self.sequences_set = OrderedSet()
+       if self.fd_inter_tensor is None:
+           self.fd_inter_tensor = FDIntermTensors()
 
    def init_batch(self, seqs: List["Sequence"] = None):
        """
```
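The new `fd_inter_tensor` field carries the preallocated flash-decoding workspace (the "mid tensors" of the PR title), so per-split intermediate buffers are allocated once per batch rather than on every forward pass. Below is a minimal sketch of what such a workspace might look like; the class name, tensor names, and shapes are assumptions inferred from the `initialize()` call in a later hunk, not the actual `FDIntermTensors` implementation.

```python
import torch

# Hypothetical FDIntermTensors-like workspace (names and shapes assumed).
# Split-KV flash decoding writes one partial output and one log-sum-exp
# per KV split; a final reduction combines them into the true output.
class FDWorkspaceSketch:
    def __init__(self):
        self.is_initialized = False

    def initialize(self, max_batch_size, num_attn_heads, kv_max_split_num, head_dim, dtype, device):
        assert not self.is_initialized, "workspace already initialized"
        # Partial attention output per (sequence, head, KV split).
        self.mid_output = torch.empty(
            (max_batch_size, num_attn_heads, kv_max_split_num, head_dim), dtype=dtype, device=device
        )
        # Log-sum-exp of the scores in each split, needed to rescale partials.
        self.mid_output_lse = torch.empty(
            (max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
        )
        self.is_initialized = True
```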
```diff
@@ -185,8 +196,6 @@ class BatchInfo:
             seqs (List["Sequence"]): List of input sequence.
         """
 
-        assert len(self.sequences_set) == 0, "Sequences set has been initialized."
-
         if seqs is not None:
             if not isinstance(seqs, list):
                 seqs = [seqs]
```
```diff
@@ -197,16 +206,30 @@ class BatchInfo:
 
             self.sequences_set.add(seq)
 
+    def init_fd_tensors(self):
+        if not self.fd_inter_tensor.is_initialized:
+            self.fd_inter_tensor.initialize(
+                max_batch_size=self.max_batch_size,
+                num_attn_heads=self.num_heads,
+                kv_max_split_num=self.kv_max_split_num,
+                head_dim=self.head_dim,
+                dtype=self.dtype,
+                device=self.device,
+            )
+
     def get_block_table_tensor(self) -> None:
         tesnor_list = []
         block_table = None
+
+        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
+
         for seq in self.sequences_set:
             block_table = seq.block_table
             assert (
                 block_table is not None
             ), f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
             tesnor_list.append(seq.block_table)
         assert tesnor_list, "Batch has not been initialized yet. Please initialize batch first."
 
         block_table = torch.stack(tesnor_list)
         return block_table
```
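`init_fd_tensors` is idempotent: the workspace is filled in lazily on the first call and reused on every later decoding step. For context, here is a self-contained sketch of the standard split-KV reduction that such mid tensors enable; the function name and shapes are illustrative assumptions, not code from this PR. Each split's partial output is already normalized within its split, so a softmax over the per-split log-sum-exp values yields the correct combination weights.

```python
import torch

def reduce_kv_splits(mid_output: torch.Tensor, mid_output_lse: torch.Tensor) -> torch.Tensor:
    """Combine per-split partial attention outputs into the final output.

    mid_output:     (batch, num_heads, kv_split_num, head_dim)
    mid_output_lse: (batch, num_heads, kv_split_num)
    """
    # Split i contributes weight exp(lse_i) / sum_j exp(lse_j),
    # i.e. a softmax over the per-split log-sum-exp values.
    weights = torch.softmax(mid_output_lse, dim=-1)
    return (weights.unsqueeze(-1) * mid_output).sum(dim=2)  # (batch, num_heads, head_dim)
```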
```diff
@@ -218,7 +241,6 @@ class BatchInfo:
         """
         if self.is_prompts:
             self.sequences_set.clear()
-
         else:
             for seq in self.sequences_set:
                 seq.mark_aborted()
```
```diff
@@ -312,14 +334,14 @@ class BatchInfo:
         """
         Get bacth inputs for forward inference computation.
         """
 
         input_list = []
+
+        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
+
         for seq in self.sequences_set:
             if self.is_prompts:
                 if seq.output_len > 0:
-                    print(seq.output_token_id)
-                    seq_data = seq.input_token_id + seq.output_token_id
-                    print(seq_data)
                     input_list.append(seq.input_token_id + seq.output_token_id)
                 else:
                     input_list.append(seq.input_token_id)
```
```diff
@@ -328,7 +350,8 @@ class BatchInfo:
 
         max_seq_len = max(len(sub_list) for sub_list in input_list)
 
-        return _make_tensor_with_pad(input_list, max_seq_len, 0, dtype=torch.int)
+        # We assume that all the padding_id in seq are the same at present.
+        return _make_tensor_with_pad(input_list, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int)
 
     def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]:
         """
```
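The hard-coded pad value `0` is replaced with the batch's real `pad_token_id`, keeping the padding consistent with the mask logic in `get_attn_mask` further down. The definition of `_make_tensor_with_pad` is truncated in this view; the following is a plausible reconstruction from its call site and from `_pad_to_max` (an assumption, not the verbatim helper), with illustrative values.

```python
import torch
from typing import List

def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    assert len(x) <= max_len
    return [pad] * (max_len - len(x)) + x  # left pad, as in this PR

def _make_tensor_with_pad(x: List[List[int]], max_len: int, pad: int, dtype=torch.int) -> torch.Tensor:
    # Pad every row to the same length, then stack into a single tensor.
    return torch.tensor([_pad_to_max(row, max_len, pad) for row in x], dtype=dtype)

# Three prompts of unequal length; pad_token_id assumed to be 2 for illustration.
print(_make_tensor_with_pad([[11, 12, 13], [21, 22], [31]], max_len=3, pad=2))
# tensor([[11, 12, 13],
#         [ 2, 21, 22],
#         [ 2,  2, 31]], dtype=torch.int32)
```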
```diff
@@ -336,6 +359,9 @@ class BatchInfo:
         """
         input_list = []
         input_len_list = []
+
+        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
+
         for seq in self.sequences_set:
             if self.is_prompts:
                 input_list.extend(seq.input_token_id)
```
```diff
@@ -353,16 +379,23 @@ class BatchInfo:
         Get the input_len of each sentence in this batch.
         """
         len_list = []
 
+        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
+
         for seq in self.sequences_set:
             len_list.append(seq.sentence_len)
 
         return torch.tensor(len_list, dtype=torch.int, device=self.device)
 
-    def get_attn_mask(self, padding_id: int) -> torch.Tensor:
+    def get_attn_mask(self) -> torch.Tensor:
         """
         Generate and return attention mask.
         """
+        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
+
         past_values = []
+        # We assume that all the padding_id in seq are the same at present.
+        padding_id = self.sequences_set[0].pad_token_id
+
         for seq in self.sequences_set:
             past_values.append(seq.input_token_id + seq.output_token_id)
```
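With `padding_id` now read from the sequences themselves rather than passed in, the mask can be derived directly from the padded token tensor. A minimal sketch of that masking logic, with illustrative values (the real method iterates over `self.sequences_set`):

```python
import torch

padding_id = 2  # illustrative; the real code reads sequences_set[0].pad_token_id
past_values = [[5, 6, 7], [8, 9]]  # token ids of two sequences
max_len = max(len(v) for v in past_values)

# Left pad (matching _pad_to_max below), then mark non-pad positions attendable.
padded = torch.tensor([[padding_id] * (max_len - len(v)) + v for v in past_values])
attn_mask = padded != padding_id
# tensor([[ True,  True,  True],
#         [False,  True,  True]])
```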
```diff
@@ -378,7 +411,7 @@ class BatchInfo:
 
 def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
     assert len(x) <= max_len
-    return x + [pad] * (max_len - len(x))
+    return [pad] * (max_len - len(x)) + x
 
 
 def _make_tensor_with_pad(
```
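This one-line change flips the helper from right padding to left padding. The behavioral difference, with illustrative values:

```python
x, max_len, pad = [5, 6], 4, 0
old = x + [pad] * (max_len - len(x))  # right pad: [5, 6, 0, 0]
new = [pad] * (max_len - len(x)) + x  # left pad:  [0, 0, 5, 6]
# Left padding keeps the newest token of every sequence in the last column,
# so a decoding step can read position -1 across the whole batch.
```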