[inference] Optimize the usage of the mid tensor space in flash attn (#5304)

* opt flash attn

* opt tmp tensor

* fix benchmark_llama

* fix code style

* fix None logic for output tensor

* fix adaptation to get_xine_cache

* add comment

* fix ci bugs

* fix some code

* rm duplicated code

* fix code style

* add _get_dtype in config.py
Author: yuehuayingxueluo
Date: 2024-01-26 14:00:10 +08:00
Committed by: GitHub
Parent: af8359c430
Commit: 4f28cb43c0
16 changed files with 199 additions and 57 deletions

View File

@@ -10,7 +10,6 @@ except ImportError:
 if HAS_TRITON:
     from .context_attn_unpad import context_attention_unpadded
     from .flash_decoding import flash_decoding_attention
-    from .flash_decoding_utils import FDIntermTensors
     from .fused_rotary_embedding import fused_rotary_embedding
     from .gptq_triton import gptq_fused_linear_triton
     from .kvcache_copy import copy_kv_to_blocked_cache
@@ -27,7 +26,6 @@ if HAS_TRITON:
         "rms_layernorm",
         "gptq_fused_linear_triton",
         "rotary_embedding",
-        "FDIntermTensors",
         "fused_rotary_embedding",
         "get_xine_cache",
     ]

View File

@@ -5,7 +5,6 @@
 #
 # Inspired and modified from Triton Tutorial - Fused Attention
 # https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
-from typing import Optional

 import torch
 import triton
@@ -195,7 +194,9 @@ def context_attention_unpadded(
     context_lengths: torch.Tensor,  # [num_seqs]
     block_tables: torch.Tensor,  # [num_seqs, max_blocks_per_sequence],
     block_size: int,
-    max_seq_len_in_b: Optional[int] = None,
+    output: torch.Tensor = None,  # [num_tokens, num_heads, head_dim]
+    max_seq_len: int = None,
+    sm_scale: int = None,
 ):
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
     assert Lq == Lk == Lv
@@ -210,10 +211,9 @@ def context_attention_unpadded(
     num_kv_group = num_heads // num_kv_heads

     num_seqs, max_blocks_per_seq = block_tables.shape
-    max_seq_len = context_lengths.max().item() if max_seq_len_in_b is None else max_seq_len_in_b
-    sm_scale = 1.0 / (Lq**0.5)
-
-    output = torch.zeros_like(q)
+    max_seq_len = context_lengths.max().item() if max_seq_len is None else max_seq_len
+    sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
+    output = torch.zeros_like(q) if output is None else output

     # NOTE For now, BLOCK_M and BLOCK_N are supposed to be equivalent with
     # the size of physical cache block (i.e. `block_size`)
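
With the new keyword arguments, a caller can keep one prefill output buffer alive for the whole serving loop and pass `max_seq_len` and `sm_scale` up front, avoiding both the per-call `torch.zeros_like(q)` allocation and the host sync implied by `context_lengths.max().item()`. A caller-side sketch with made-up sizes; the leading q/k/v and cache arguments are elided because they are not shown in this hunk:

```python
import torch

num_tokens, num_heads, head_dim = 1024, 32, 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reusable prefill output buffer, shaped like q: [num_tokens, num_heads, head_dim].
prefill_out = torch.zeros(num_tokens, num_heads, head_dim, dtype=torch.float16, device=device)

# Hypothetical call shape (only the parameters visible in the diff are spelled out):
#   attn_out = context_attention_unpadded(
#       q, k, v, ..., context_lengths, block_tables, block_size,
#       output=prefill_out,            # reused instead of torch.zeros_like(q) each call
#       max_seq_len=max_seq_len,       # skips the context_lengths.max().item() sync
#       sm_scale=1.0 / head_dim**0.5,  # skips recomputing the scale
#   )
```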

View File

@@ -195,6 +195,7 @@ def flash_decoding_attention(
     block_tables: torch.Tensor,
     block_size: int,
     max_seq_len_in_batch: int = None,
+    output: torch.Tensor = None,
     mid_output: torch.Tensor = None,
     mid_output_lse: torch.Tensor = None,
     sm_scale: int = None,
@@ -211,6 +212,7 @@ def flash_decoding_attention(
             records the (kv) sequence lengths incorporating past kv sequence lengths.
         block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
         max_seq_len_in_batch (int): Maximum sequence length in the batch.
+        output (torch.Tensor): [bsz, 1, num_heads, head_dim]
         mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim]
             Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
         mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num]
@@ -292,7 +294,7 @@ def flash_decoding_attention(
         HEAD_DIM=head_dim,
     )

-    output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device)  # already overlapped
+    output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output

     grid = (triton.next_power_of_2(bsz), num_heads)

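
The decoding path gets the same treatment: `output`, `mid_output`, and `mid_output_lse` can all be preallocated at the maximum batch size and reused on every step, with the kernel allocating `output` itself only as a fallback. A sketch of sizing those buffers from the shapes in the docstring above; treating `kv_max_split_num` as one split per KV-cache block is an assumption carried over from the deleted helper below:

```python
import torch

max_bsz, num_heads, head_dim = 32, 32, 128
block_size, max_seq_len = 16, 2048
kv_max_split_num = (max_seq_len + block_size - 1) // block_size  # assumed: one KV split per cache block

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Allocated once, reused for every decoding step.
mid_output = torch.empty(max_bsz, num_heads, kv_max_split_num, head_dim,
                         dtype=torch.float32, device=device)  # float32 matches the deleted helper's default
mid_output_lse = torch.empty(max_bsz, num_heads, kv_max_split_num, dtype=torch.float32, device=device)
decode_out = torch.empty(max_bsz, 1, num_heads, head_dim, dtype=torch.float16, device=device)  # dtype should match q

# Hypothetical call shape (leading tensor arguments elided, as in the hunks above):
#   out = flash_decoding_attention(
#       q, ..., block_tables, block_size,
#       max_seq_len_in_batch=max_seq_len_in_batch,
#       output=decode_out[:bsz],
#       mid_output=mid_output,
#       mid_output_lse=mid_output_lse,
#       sm_scale=1.0 / head_dim**0.5,
#   )
```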

View File

@@ -1,58 +0,0 @@
-import torch
-
-from colossalai.context.singleton_meta import SingletonMeta
-from colossalai.utils import get_current_device
-
-
-class FDIntermTensors(metaclass=SingletonMeta):
-    """Singleton class to hold tensors used for storing intermediate values in flash-decoding.
-    For now, it holds intermediate output and logsumexp (which will be used in reduction step along kv)
-    """
-
-    def __init__(self):
-        self._tensors_initialized = False
-
-    @property
-    def is_initialized(self):
-        return self._tensors_initialized
-
-    @property
-    def mid_output(self):
-        assert self.is_initialized, "Intermediate tensors not initialized yet"
-        return self._mid_output
-
-    @property
-    def mid_output_lse(self):
-        assert self.is_initialized, "Intermediate tensors not initialized yet"
-        return self._mid_output_lse
-
-    def initialize(
-        self,
-        max_batch_size: int,
-        num_attn_heads: int,
-        kv_max_split_num: int,
-        head_dim: int,
-        dtype: torch.dtype = torch.float32,
-        device: torch.device = get_current_device(),
-    ) -> None:
-        """Initialize tensors.
-
-        Args:
-            max_batch_size (int): The maximum batch size over all the model forward.
-                This could be greater than the batch size in attention forward func when using dynamic batch size.
-            num_attn_heads (int)): Number of attention heads.
-            kv_max_split_num (int): The maximum number of blocks splitted on kv in flash-decoding algorithm.
-                **The maximum length/size of blocks splitted on kv should be the kv cache block size.**
-            head_dim (int): Head dimension.
-            dtype (torch.dtype, optional): Data type to be assigned to intermediate tensors.
-            device (torch.device, optional): Device used to initialize intermediate tensors.
-        """
-        assert not self.is_initialized, "Intermediate tensors used for Flash-Decoding have been initialized."
-
-        self._mid_output = torch.empty(
-            size=(max_batch_size, num_attn_heads, kv_max_split_num, head_dim), dtype=dtype, device=device
-        )
-        self._mid_output_lse = torch.empty(
-            size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
-        )
-        self._tensors_initialized = True
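
The deleted class relied on `SingletonMeta` so that every caller in the process shared one set of flash-decoding workspaces. For readers unfamiliar with that pattern, here is a self-contained sketch of the singleton behaviour it depends on; the metaclass below is a stand-in with assumed semantics, not colossalai's actual `SingletonMeta`, and `DecodeWorkspace` is a hypothetical name:

```python
import torch

class SingletonMeta(type):
    """Metaclass that returns the same instance on every construction."""
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]

class DecodeWorkspace(metaclass=SingletonMeta):
    """Hypothetical holder mirroring the deleted FDIntermTensors contract."""

    def __init__(self):
        self.mid_output = None
        self.mid_output_lse = None

    def initialize(self, max_batch_size, num_attn_heads, kv_max_split_num, head_dim,
                   dtype=torch.float32, device="cpu"):
        self.mid_output = torch.empty(max_batch_size, num_attn_heads, kv_max_split_num, head_dim,
                                      dtype=dtype, device=device)
        self.mid_output_lse = torch.empty(max_batch_size, num_attn_heads, kv_max_split_num,
                                          dtype=dtype, device=device)

DecodeWorkspace().initialize(max_batch_size=4, num_attn_heads=8, kv_max_split_num=16, head_dim=64)
assert DecodeWorkspace().mid_output is DecodeWorkspace().mid_output  # one shared workspace per process
```

The buffers held this way are exactly what the new `mid_output=` / `mid_output_lse=` keyword arguments of `flash_decoding_attention` expect on every decoding step.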