mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[inference]Optimize the usage of the mid tensors space in flash attn (#5304)
* opt flash attn * opt tmp tensor * fix benchmark_llama * fix code style * fix None logic for output tensor * fix adapted to get_xine_cache * add comment * fix ci bugs * fix some codes * rm duplicated codes * rm duplicated codes * fix code style * add _get_dtype in config.py
This commit is contained in:
@@ -5,7 +5,6 @@
|
||||
#
|
||||
# Inspired and modified from Triton Tutorial - Fused Attention
|
||||
# https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
@@ -195,7 +194,9 @@ def context_attention_unpadded(
|
||||
context_lengths: torch.Tensor, # [num_seqs]
|
||||
block_tables: torch.Tensor, # [num_seqs, max_blocks_per_sequence],
|
||||
block_size: int,
|
||||
max_seq_len_in_b: Optional[int] = None,
|
||||
output: torch.Tensor = None, # [num_tokens, num_heads, head_dim]
|
||||
max_seq_len: int = None,
|
||||
sm_scale: int = None,
|
||||
):
|
||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
assert Lq == Lk == Lv
|
||||
@@ -210,10 +211,9 @@ def context_attention_unpadded(
|
||||
num_kv_group = num_heads // num_kv_heads
|
||||
|
||||
num_seqs, max_blocks_per_seq = block_tables.shape
|
||||
max_seq_len = context_lengths.max().item() if max_seq_len_in_b is None else max_seq_len_in_b
|
||||
sm_scale = 1.0 / (Lq**0.5)
|
||||
|
||||
output = torch.zeros_like(q)
|
||||
max_seq_len = context_lengths.max().item() if max_seq_len is None else max_seq_len
|
||||
sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
|
||||
output = torch.zeros_like(q) if output is None else output
|
||||
|
||||
# NOTE For now, BLOCK_M and BLOCK_N are supposed to be equivalent with
|
||||
# the size of physical cache block (i.e. `block_size`)
|
||||
|
Reference in New Issue
Block a user