Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-10 13:30:19 +00:00)
[inference] Optimize the usage of the mid tensors space in flash attn (#5304)
* opt flash attn
* opt tmp tensor
* fix benchmark_llama
* fix code style
* fix None logic for output tensor
* fix adapted to get_xine_cache
* add comment
* fix ci bugs
* fix some codes
* rm duplicated codes
* rm duplicated codes
* fix code style
* add _get_dtype in config.py
colossalai/inference/flash_decoding_utils.py (new file, 58 lines)
@@ -0,0 +1,58 @@
import torch

from colossalai.context.singleton_meta import SingletonMeta
from colossalai.utils import get_current_device


class FDIntermTensors(metaclass=SingletonMeta):
    """Singleton class holding tensors used to store intermediate values in flash-decoding.

    For now, it holds the intermediate output and logsumexp (used in the reduction step along kv).
    """

    def __init__(self):
        self._tensors_initialized = False

    @property
    def is_initialized(self):
        return self._tensors_initialized

    @property
    def mid_output(self):
        assert self.is_initialized, "Intermediate tensors not initialized yet"
        return self._mid_output

    @property
    def mid_output_lse(self):
        assert self.is_initialized, "Intermediate tensors not initialized yet"
        return self._mid_output_lse

    def initialize(
        self,
        max_batch_size: int,
        num_attn_heads: int,
        kv_max_split_num: int,
        head_dim: int,
        dtype: torch.dtype = torch.float32,
        device: torch.device = get_current_device(),
    ) -> None:
        """Initialize tensors.

        Args:
            max_batch_size (int): The maximum batch size across all model forward passes.
                This can be greater than the batch size in the attention forward function when using dynamic batch size.
            num_attn_heads (int): Number of attention heads.
            kv_max_split_num (int): The maximum number of blocks split along kv in the flash-decoding algorithm.
                **The maximum length/size of blocks split along kv should be the kv cache block size.**
            head_dim (int): Head dimension.
            dtype (torch.dtype, optional): Data type to be assigned to the intermediate tensors.
            device (torch.device, optional): Device used to initialize the intermediate tensors.
        """
        assert not self.is_initialized, "Intermediate tensors used for Flash-Decoding have been initialized."

        # Per-split partial attention outputs: one [head_dim] vector per (batch, head, kv split).
        self._mid_output = torch.empty(
            size=(max_batch_size, num_attn_heads, kv_max_split_num, head_dim), dtype=dtype, device=device
        )
        # Per-split log-sum-exp of the attention scores, consumed by the reduction step.
        self._mid_output_lse = torch.empty(
            size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
        )
        self._tensors_initialized = True
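
For context, below is a minimal sketch (not part of this commit, and not ColossalAI's actual attention kernel) of how these preallocated intermediate tensors are typically consumed in the flash-decoding reduction step. The values bsz and num_splits, and the random fill standing in for the kernel's partial results, are hypothetical; only FDIntermTensors and its methods come from the file above.

import torch

from colossalai.inference.flash_decoding_utils import FDIntermTensors

max_batch_size, num_heads, kv_max_split_num, head_dim = 8, 32, 16, 128

fd_tensors = FDIntermTensors()
assert fd_tensors is FDIntermTensors()  # SingletonMeta: every call returns the same instance
if not fd_tensors.is_initialized:       # initialize exactly once, sized for the worst case
    fd_tensors.initialize(
        max_batch_size=max_batch_size,
        num_attn_heads=num_heads,
        kv_max_split_num=kv_max_split_num,
        head_dim=head_dim,
        dtype=torch.float32,
    )

# Hypothetical running state: bsz active sequences, num_splits active kv splits.
# In real use the attention kernel writes these slices; random values stand in here.
bsz, num_splits = 4, 6
fd_tensors.mid_output[:bsz, :, :num_splits].normal_()
fd_tensors.mid_output_lse[:bsz, :, :num_splits].normal_()

mid_out = fd_tensors.mid_output[:bsz, :, :num_splits]      # [bsz, heads, splits, head_dim]
mid_lse = fd_tensors.mid_output_lse[:bsz, :, :num_splits]  # [bsz, heads, splits]

# Standard flash-decoding reduction along the kv-split axis: rescale each
# partial output by exp(lse_i - lse_total), where lse_total = logsumexp_i(lse_i).
lse_total = torch.logsumexp(mid_lse, dim=-1, keepdim=True)  # [bsz, heads, 1]
weights = torch.exp(mid_lse - lse_total).unsqueeze(-1)      # [bsz, heads, splits, 1]
out = (weights * mid_out).sum(dim=2)                        # [bsz, heads, head_dim]

Because the tensors are allocated once at their maximum size and then sliced per forward pass, this avoids reallocating scratch space on every attention call, which is the space optimization the commit title refers to.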