Mirror of https://github.com/hpcaitech/ColossalAI.git
[inference] Optimize the usage of the mid tensors space in flash attn (#5304)
* opt flash attn
* opt tmp tensor
* fix benchmark_llama
* fix code style
* fix None logic for output tensor
* fix adapted to get_xine_cache
* add comment
* fix ci bugs
* fix some codes
* rm duplicated codes
* rm duplicated codes
* fix code style
* add _get_dtype in config.py
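The optimization visible in the hunks below is to make the output and intermediate buffers of the attention kernels caller-provided: `context_attention_unpadded` and `flash_decoding_attention` gain optional `output` (plus `max_seq_len` / `sm_scale`) arguments that are only allocated internally when the caller does not pass them. A minimal sketch of the intended calling pattern follows; only the keyword arguments and shapes shown in the diff are taken from the commit, all sizes and variable names are illustrative.

import torch

bsz, num_heads, head_dim = 4, 32, 128                      # illustrative sizes, not from the commit
device = "cuda" if torch.cuda.is_available() else "cpu"

# Allocated once and reused for every decoding step; shape and dtype follow the
# docstring added in the diff: [bsz, 1, num_heads, head_dim], same dtype as q.
preallocated_output = torch.empty((bsz, 1, num_heads, head_dim), dtype=torch.float16, device=device)

# out = flash_decoding_attention(
#     ...,                                  # q / kv-cache arguments are not part of this excerpt
#     output=preallocated_output,           # new optional argument: skip the per-call allocation
#     mid_output=mid_output,                # intermediate buffers (see FDIntermTensors below)
#     mid_output_lse=mid_output_lse,
#     sm_scale=1.0 / (head_dim ** 0.5),
# )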
@@ -10,7 +10,6 @@ except ImportError:
 if HAS_TRITON:
     from .context_attn_unpad import context_attention_unpadded
     from .flash_decoding import flash_decoding_attention
-    from .flash_decoding_utils import FDIntermTensors
     from .fused_rotary_embedding import fused_rotary_embedding
     from .gptq_triton import gptq_fused_linear_triton
     from .kvcache_copy import copy_kv_to_blocked_cache
@@ -27,7 +26,6 @@ if HAS_TRITON:
         "rms_layernorm",
         "gptq_fused_linear_triton",
         "rotary_embedding",
-        "FDIntermTensors",
         "fused_rotary_embedding",
         "get_xine_cache",
     ]
@@ -5,7 +5,6 @@
 #
 # Inspired and modified from Triton Tutorial - Fused Attention
 # https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
-from typing import Optional

 import torch
 import triton
@@ -195,7 +194,9 @@ def context_attention_unpadded(
     context_lengths: torch.Tensor,  # [num_seqs]
     block_tables: torch.Tensor,  # [num_seqs, max_blocks_per_sequence],
     block_size: int,
-    max_seq_len_in_b: Optional[int] = None,
+    output: torch.Tensor = None,  # [num_tokens, num_heads, head_dim]
+    max_seq_len: int = None,
+    sm_scale: int = None,
 ):
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
     assert Lq == Lk == Lv
@@ -210,10 +211,9 @@ def context_attention_unpadded(
     num_kv_group = num_heads // num_kv_heads

     num_seqs, max_blocks_per_seq = block_tables.shape
-    max_seq_len = context_lengths.max().item() if max_seq_len_in_b is None else max_seq_len_in_b
-    sm_scale = 1.0 / (Lq**0.5)
-
-    output = torch.zeros_like(q)
+    max_seq_len = context_lengths.max().item() if max_seq_len is None else max_seq_len
+    sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
+    output = torch.zeros_like(q) if output is None else output

     # NOTE For now, BLOCK_M and BLOCK_N are supposed to be equivalent with
     # the size of physical cache block (i.e. `block_size`)
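The body change above swaps unconditional allocation for the usual lazy-default idiom, so existing callers keep the old behavior while callers that manage their own buffers avoid a fresh `torch.zeros_like` per prefill step. A minimal stand-alone sketch of the same pattern (hypothetical function name, not the Triton kernel itself):

import torch

def prefill_defaults(q, context_lengths, output=None, max_seq_len=None, sm_scale=None):
    # Each optional argument is materialized only if the caller did not supply it.
    max_seq_len = context_lengths.max().item() if max_seq_len is None else max_seq_len
    sm_scale = 1.0 / (q.shape[-1] ** 0.5) if sm_scale is None else sm_scale
    output = torch.zeros_like(q) if output is None else output  # [num_tokens, num_heads, head_dim]
    return output, max_seq_len, sm_scale

q = torch.randn(16, 32, 128)                      # [num_tokens, num_heads, head_dim]
ctx_lens = torch.tensor([7, 9])
out, msl, scale = prefill_defaults(q, ctx_lens)   # no args passed: falls back to fresh allocation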
@@ -195,6 +195,7 @@ def flash_decoding_attention(
     block_tables: torch.Tensor,
     block_size: int,
     max_seq_len_in_batch: int = None,
+    output: torch.Tensor = None,
     mid_output: torch.Tensor = None,
     mid_output_lse: torch.Tensor = None,
     sm_scale: int = None,
@@ -211,6 +212,7 @@ def flash_decoding_attention(
             records the (kv) sequence lengths incorporating past kv sequence lengths.
         block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
         max_seq_len_in_batch (int): Maximum sequence length in the batch.
+        output (torch.Tensor): [bsz, 1, num_heads, head_dim]
         mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim]
             Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
         mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num]
@@ -292,7 +294,7 @@ def flash_decoding_attention(
         HEAD_DIM=head_dim,
     )

-    output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device)  # already overlapped
+    output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output

     grid = (triton.next_power_of_2(bsz), num_heads)
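For reference, the reason `flash_decoding_attention` carries `mid_output` and `mid_output_lse` at all: flash-decoding processes the KV cache in splits, writes one partial output plus its log-sum-exp per split, and then reduces them. A plain PyTorch sketch of that reduction (illustrative only, not the Triton reduce kernel in this file):

import torch

def combine_kv_splits(mid_output, mid_output_lse):
    # mid_output:     [bsz, num_heads, kv_split_num, head_dim]  partial attention outputs
    # mid_output_lse: [bsz, num_heads, kv_split_num]             log-sum-exp of each split's scores
    lse_total = torch.logsumexp(mid_output_lse, dim=-1, keepdim=True)   # [bsz, num_heads, 1]
    weights = torch.exp(mid_output_lse - lse_total).unsqueeze(-1)       # softmax weight of each split
    return (weights * mid_output).sum(dim=2)                            # [bsz, num_heads, head_dim]

Preallocating these two tensors once at the maximum batch size, and slicing them to the live batch, is what makes the per-step allocation unnecessary.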
@@ -1,58 +0,0 @@
-import torch
-
-from colossalai.context.singleton_meta import SingletonMeta
-from colossalai.utils import get_current_device
-
-
-class FDIntermTensors(metaclass=SingletonMeta):
-    """Singleton class to hold tensors used for storing intermediate values in flash-decoding.
-    For now, it holds intermediate output and logsumexp (which will be used in reduction step along kv)
-    """
-
-    def __init__(self):
-        self._tensors_initialized = False
-
-    @property
-    def is_initialized(self):
-        return self._tensors_initialized
-
-    @property
-    def mid_output(self):
-        assert self.is_initialized, "Intermediate tensors not initialized yet"
-        return self._mid_output
-
-    @property
-    def mid_output_lse(self):
-        assert self.is_initialized, "Intermediate tensors not initialized yet"
-        return self._mid_output_lse
-
-    def initialize(
-        self,
-        max_batch_size: int,
-        num_attn_heads: int,
-        kv_max_split_num: int,
-        head_dim: int,
-        dtype: torch.dtype = torch.float32,
-        device: torch.device = get_current_device(),
-    ) -> None:
-        """Initialize tensors.
-
-        Args:
-            max_batch_size (int): The maximum batch size over all the model forward.
-                This could be greater than the batch size in attention forward func when using dynamic batch size.
-            num_attn_heads (int)): Number of attention heads.
-            kv_max_split_num (int): The maximum number of blocks splitted on kv in flash-decoding algorithm.
-                **The maximum length/size of blocks splitted on kv should be the kv cache block size.**
-            head_dim (int): Head dimension.
-            dtype (torch.dtype, optional): Data type to be assigned to intermediate tensors.
-            device (torch.device, optional): Device used to initialize intermediate tensors.
-        """
-        assert not self.is_initialized, "Intermediate tensors used for Flash-Decoding have been initialized."
-
-        self._mid_output = torch.empty(
-            size=(max_batch_size, num_attn_heads, kv_max_split_num, head_dim), dtype=dtype, device=device
-        )
-        self._mid_output_lse = torch.empty(
-            size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
-        )
-        self._tensors_initialized = True
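The deleted file above defines `FDIntermTensors`, the singleton that owns those intermediate buffers; this excerpt only shows its removal from `colossalai.kernel.triton`, so where the class lives after this commit is not visible here. A hedged usage sketch against the pre-commit API shown above (the module path is assumed from the relative import removed in the first hunk; all sizes are illustrative):

import torch

from colossalai.kernel.triton.flash_decoding_utils import FDIntermTensors  # pre-commit path, assumed

bsz = 4                                     # live batch size for this decoding step
interm = FDIntermTensors()                  # SingletonMeta: every call returns the same instance
if not interm.is_initialized:
    interm.initialize(
        max_batch_size=32,                  # must cover every batch size seen at runtime
        num_attn_heads=32,
        kv_max_split_num=8,                 # roughly max_kv_len // block_size per the docstring above
        head_dim=128,
        dtype=torch.float32,
    )

# Slice to the live batch size and hand the views to flash_decoding_attention
# as mid_output / mid_output_lse, so nothing is allocated per decoding step.
mid_output = interm.mid_output[:bsz]
mid_output_lse = interm.mid_output_lse[:bsz]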