mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[Inference]Adapt to baichuan2 13B (#5614)
* adapt to baichuan2 13B * adapt to baichuan2 13B * change BAICHUAN_MODEL_NAME_OR_PATH * fix test_decoding_attn.py * Modifications based on review comments. * change BAICHUAN_MODEL_NAME_OR_PATH * mv attn mask processes to test flash decoding * mv get_alibi_slopes baichuan modeling * fix bugs in test_baichuan.py
This commit is contained in:
@@ -185,6 +185,192 @@ def _fwd_context_paged_attention_kernel(
|
||||
return
|
||||
|
||||
|
||||
# Triton 2.1.0
|
||||
@triton.jit
|
||||
def _alibi_fwd_context_paged_attention_kernel(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
O,
|
||||
KCache,
|
||||
VCache,
|
||||
BLOCK_TABLES, # [num_seqs, max_blocks_per_sequence]
|
||||
batch_size,
|
||||
alibi_slopes,
|
||||
stride_qt,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_kt,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_vt,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_ot,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
stride_cacheb,
|
||||
stride_cacheh,
|
||||
stride_cachebs,
|
||||
stride_cached,
|
||||
stride_bts,
|
||||
stride_btb,
|
||||
context_lengths,
|
||||
sm_scale,
|
||||
KV_GROUPS: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
cur_seq_idx = tl.program_id(0)
|
||||
if cur_seq_idx >= batch_size:
|
||||
return
|
||||
cur_head_idx = tl.program_id(1)
|
||||
block_start_m = tl.program_id(2) # Br, max_input_len // Block_M
|
||||
cur_kv_head_idx = cur_head_idx // KV_GROUPS
|
||||
|
||||
global_block_start_offest = block_start_m * BLOCK_M
|
||||
|
||||
# NOTE It requires BLOCK_M, BLOCK_N, and BLOCK_SIZE to be the same
|
||||
tl.static_assert(BLOCK_M == BLOCK_N)
|
||||
tl.static_assert(BLOCK_N == BLOCK_SIZE)
|
||||
|
||||
# get the current sequence length from provided context lengths tensor
|
||||
cur_seq_len = tl.load(context_lengths + cur_seq_idx)
|
||||
# NOTE when talking to fused QKV and a nopadding context attention,
|
||||
# we assume that the input Q/K/V is contiguous, and thus here `prev_seq_len_sum`
|
||||
# could be considered as the start index of the current sequence.
|
||||
# FIXME might want to explore better way to get the summation of prev seq lengths.
|
||||
# `tl.sum(tensor[:end])` is invalid as tensor slice is not supported in triton.
|
||||
prev_seq_len_sum = 0
|
||||
for i in range(0, cur_seq_idx):
|
||||
prev_seq_len_sum += tl.load(context_lengths + i)
|
||||
|
||||
offset_q = prev_seq_len_sum * stride_qt + cur_head_idx * stride_qh
|
||||
offset_kv = prev_seq_len_sum * stride_kt + cur_kv_head_idx * stride_kh
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q + offset_q,
|
||||
shape=(cur_seq_len, HEAD_DIM),
|
||||
strides=(stride_qt, stride_qd),
|
||||
offsets=(global_block_start_offest, 0),
|
||||
block_shape=(BLOCK_M, HEAD_DIM),
|
||||
order=(1, 0),
|
||||
)
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=K + offset_kv,
|
||||
shape=(HEAD_DIM, cur_seq_len),
|
||||
strides=(stride_kd, stride_kt),
|
||||
offsets=(0, 0),
|
||||
block_shape=(HEAD_DIM, BLOCK_N),
|
||||
order=(0, 1),
|
||||
)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V + offset_kv,
|
||||
shape=(cur_seq_len, HEAD_DIM),
|
||||
strides=(stride_vt, stride_vd),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_N, HEAD_DIM),
|
||||
order=(1, 0),
|
||||
)
|
||||
O_block_ptr = tl.make_block_ptr(
|
||||
base=O + offset_q,
|
||||
shape=(cur_seq_len, HEAD_DIM),
|
||||
strides=(stride_ot, stride_od),
|
||||
offsets=(global_block_start_offest, 0),
|
||||
block_shape=(BLOCK_M, HEAD_DIM),
|
||||
order=(1, 0),
|
||||
)
|
||||
|
||||
# block table for the current sequence
|
||||
block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
|
||||
# block indexes on block table (i.e. 0, 1, 2, ..., max_blocks_per_seq)
|
||||
# Consider `block_start_m` as the logical block idx in the current block table,
|
||||
# as we have BLOCK_M the same size as the block size.
|
||||
cur_block_table_idx = block_start_m
|
||||
cur_block_id = tl.load(block_table_ptr + cur_block_table_idx * stride_btb)
|
||||
offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
|
||||
|
||||
offsets_m = global_block_start_offest + tl.arange(0, BLOCK_M)
|
||||
offsets_n = tl.arange(0, BLOCK_N)
|
||||
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
|
||||
|
||||
# load alibi_slope
|
||||
alibi_slope = tl.load(alibi_slopes + cur_head_idx)
|
||||
m_alibi_offset = tl.arange(0, BLOCK_M)[:, None] + global_block_start_offest
|
||||
n_alibi_offset = tl.arange(0, BLOCK_N)[None, :]
|
||||
|
||||
if global_block_start_offest >= cur_seq_len:
|
||||
return
|
||||
|
||||
Q_i = tl.load(Q_block_ptr, boundary_check=(1, 0))
|
||||
|
||||
for block_start_n in range(0, (block_start_m + 1) * BLOCK_M, BLOCK_N):
|
||||
block_start_n = tl.multiple_of(block_start_n, BLOCK_N)
|
||||
|
||||
k = tl.load(K_block_ptr, boundary_check=(0, 1))
|
||||
S_ij = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
S_ij += tl.dot(Q_i, k)
|
||||
S_ij *= sm_scale
|
||||
S_ij += tl.where(offsets_m[:, None] >= (block_start_n + offsets_n[None, :]), 0, float("-inf"))
|
||||
|
||||
alibi = (n_alibi_offset + block_start_n - m_alibi_offset) * alibi_slope
|
||||
alibi = tl.where((alibi <= 0) & (m_alibi_offset < cur_seq_len), alibi, float("-inf"))
|
||||
S_ij += alibi
|
||||
|
||||
m_ij = tl.max(S_ij, 1) # rowmax(Sij)
|
||||
m_ij = tl.maximum(m_i, m_ij) # m_ij
|
||||
S_ij -= m_ij[:, None]
|
||||
p_ij_hat = tl.exp(S_ij)
|
||||
scale = tl.exp(m_i - m_ij)
|
||||
l_ij = scale * l_i + tl.sum(p_ij_hat, 1)
|
||||
acc = acc * scale[:, None]
|
||||
|
||||
v = tl.load(V_block_ptr, boundary_check=(1, 0))
|
||||
p_ij_hat = p_ij_hat.to(v.type.element_ty)
|
||||
|
||||
acc += tl.dot(p_ij_hat, v)
|
||||
l_i = l_ij
|
||||
m_i = m_ij
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
|
||||
|
||||
acc = acc / l_i[:, None]
|
||||
tl.store(O_block_ptr, acc.to(O.type.element_ty), boundary_check=(1, 0))
|
||||
|
||||
if cur_head_idx % KV_GROUPS == 0:
|
||||
# Copy k to corresponding cache block
|
||||
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
||||
offsets_kt = global_block_start_offest + tl.arange(0, BLOCK_M)
|
||||
offsets_k = K + offset_kv + offsets_dmodel[None, :] * stride_kd + offsets_kt[:, None] * stride_kt
|
||||
k = tl.load(offsets_k, mask=offsets_kt[:, None] < cur_seq_len, other=0.0)
|
||||
offsets_kcachebs = tl.arange(0, BLOCK_SIZE)
|
||||
offsets_kcache = (
|
||||
KCache
|
||||
+ offset_kvcache
|
||||
+ offsets_dmodel[None, :] * stride_cached
|
||||
+ offsets_kcachebs[:, None] * stride_cachebs
|
||||
)
|
||||
tl.store(offsets_kcache, k, mask=offsets_kcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
|
||||
# Copy v to corresponding cache block
|
||||
offsets_vd = offsets_dmodel
|
||||
offsets_vt = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
offsets_v = V + offset_kv + offsets_vt[None, :] * stride_vt + offsets_vd[:, None] * stride_vd
|
||||
v = tl.load(offsets_v, mask=offsets_vt[None, :] < cur_seq_len, other=0.0)
|
||||
offsets_vcachebs = offsets_kcachebs # same block size range, just to notify here
|
||||
offsets_vcache = (
|
||||
VCache
|
||||
+ offset_kvcache
|
||||
+ offsets_vcachebs[None, :] * stride_cachebs
|
||||
+ offsets_dmodel[:, None] * stride_cached
|
||||
)
|
||||
tl.store(offsets_vcache, v, mask=offsets_vcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def context_attention_unpadded(
|
||||
q: torch.Tensor, # [num_tokens, num_heads, head_dim]
|
||||
k: torch.Tensor, # [num_tokens, num_kv_heads, head_dim]
|
||||
@@ -195,6 +381,7 @@ def context_attention_unpadded(
|
||||
block_tables: torch.Tensor, # [num_seqs, max_blocks_per_sequence],
|
||||
block_size: int,
|
||||
output: torch.Tensor = None, # [num_tokens, num_heads, head_dim]
|
||||
alibi_slopes: torch.Tensor = None, # [num_heads]
|
||||
max_seq_len: int = None,
|
||||
sm_scale: int = None,
|
||||
):
|
||||
@@ -226,40 +413,78 @@ def context_attention_unpadded(
|
||||
# To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
|
||||
grid = (triton.next_power_of_2(num_seqs), num_heads, triton.cdiv(max_seq_len, BLOCK_M))
|
||||
|
||||
_fwd_context_paged_attention_kernel[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
output,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_tables,
|
||||
num_seqs,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
output.stride(0),
|
||||
head_dim,
|
||||
1,
|
||||
k_cache.stride(0),
|
||||
k_cache.stride(1),
|
||||
k_cache.stride(2),
|
||||
k_cache.stride(3),
|
||||
block_tables.stride(0),
|
||||
block_tables.stride(1),
|
||||
context_lengths,
|
||||
sm_scale,
|
||||
num_kv_group,
|
||||
block_size,
|
||||
HEAD_DIM=Lk,
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N,
|
||||
)
|
||||
if alibi_slopes is not None:
|
||||
_alibi_fwd_context_paged_attention_kernel[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
output,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_tables,
|
||||
num_seqs,
|
||||
alibi_slopes,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
output.stride(0),
|
||||
head_dim,
|
||||
1,
|
||||
k_cache.stride(0),
|
||||
k_cache.stride(1),
|
||||
k_cache.stride(2),
|
||||
k_cache.stride(3),
|
||||
block_tables.stride(0),
|
||||
block_tables.stride(1),
|
||||
context_lengths,
|
||||
sm_scale,
|
||||
num_kv_group,
|
||||
block_size,
|
||||
HEAD_DIM=Lk,
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N,
|
||||
)
|
||||
else:
|
||||
_fwd_context_paged_attention_kernel[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
output,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_tables,
|
||||
num_seqs,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
output.stride(0),
|
||||
head_dim,
|
||||
1,
|
||||
k_cache.stride(0),
|
||||
k_cache.stride(1),
|
||||
k_cache.stride(2),
|
||||
k_cache.stride(3),
|
||||
block_tables.stride(0),
|
||||
block_tables.stride(1),
|
||||
context_lengths,
|
||||
sm_scale,
|
||||
num_kv_group,
|
||||
block_size,
|
||||
HEAD_DIM=Lk,
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N,
|
||||
)
|
||||
|
||||
return output
|
||||
|
@@ -124,6 +124,129 @@ def _flash_decoding_fwd_kernel(
|
||||
tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l))
|
||||
|
||||
|
||||
# Triton 2.1.0
|
||||
@triton.jit
|
||||
def _alibi_flash_decoding_fwd_kernel(
|
||||
Q, # [batch_size * q_len, head_num, head_dim]
|
||||
KCache, # [num_blocks, num_kv_heads, block_size, head_dim]
|
||||
VCache, # [num_blocks, num_kv_heads, block_size, head_dim]
|
||||
block_tables, # [batch_size, max_blocks_per_sequence]
|
||||
mid_o, # [batch_size * q_len, head_num, kv_split_num, head_dim]
|
||||
mid_o_lse, # [batch_size * q_len, head_num, kv_split_num]
|
||||
kv_seq_len, # [batch_size]
|
||||
q_len,
|
||||
batch_size,
|
||||
alibi_slopes,
|
||||
stride_qt,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_cacheb,
|
||||
stride_cacheh,
|
||||
stride_cachebs,
|
||||
stride_cached,
|
||||
stride_bts,
|
||||
stride_btb,
|
||||
stride_mid_ot,
|
||||
stride_mid_oh,
|
||||
stride_mid_ob,
|
||||
stride_mid_od,
|
||||
stride_mid_o_lset,
|
||||
stride_mid_o_lseh,
|
||||
stride_mid_o_lseb,
|
||||
sm_scale,
|
||||
KV_GROUPS: tl.constexpr,
|
||||
BLOCK_KV: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
):
|
||||
cur_token_idx = tl.program_id(0)
|
||||
cur_seq_idx = cur_token_idx // q_len
|
||||
if cur_seq_idx >= batch_size:
|
||||
return
|
||||
cur_token_off = (cur_token_idx % q_len) - q_len + 1
|
||||
cur_head_idx = tl.program_id(1)
|
||||
block_start_kv = tl.program_id(2) # for splitting k/v
|
||||
|
||||
# NOTE It requires BLOCK_KV and BLOCK_SIZE to be the same
|
||||
# TODO might want to replace with BLOCK_KV % BLOCK_SIZE == 0 (optimize BLOCK_KV as multiple of BLOCK_SIZE)
|
||||
# and then support calculating multiple kv cache blocks on an instance
|
||||
tl.static_assert(BLOCK_KV == BLOCK_SIZE)
|
||||
# get the current (kv) sequence length
|
||||
# cur_token_off is used as a "mask" here for spec-dec during verification process
|
||||
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off
|
||||
if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
|
||||
return
|
||||
|
||||
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
||||
offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
|
||||
q = tl.load(Q + offsets_q)
|
||||
# block table for the current sequence
|
||||
block_table_ptr = block_tables + cur_seq_idx * stride_bts
|
||||
# cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE)
|
||||
# cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb)
|
||||
cur_block_id = tl.load(block_table_ptr + block_start_kv * stride_btb)
|
||||
cur_occupied_size = tl.where(
|
||||
(block_start_kv + 1) * BLOCK_SIZE <= cur_kv_seq_len, BLOCK_SIZE, cur_kv_seq_len - block_start_kv * BLOCK_SIZE
|
||||
)
|
||||
tl.device_assert(cur_occupied_size >= 0)
|
||||
|
||||
cur_kv_head_idx = cur_head_idx // KV_GROUPS
|
||||
offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=KCache + offset_kvcache,
|
||||
shape=(cur_occupied_size, HEAD_DIM),
|
||||
strides=(stride_cachebs, stride_cached),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_SIZE, HEAD_DIM),
|
||||
order=(0, 1),
|
||||
)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=VCache + offset_kvcache,
|
||||
shape=(cur_occupied_size, HEAD_DIM),
|
||||
strides=(stride_cachebs, stride_cached),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_SIZE, HEAD_DIM),
|
||||
order=(0, 1),
|
||||
)
|
||||
k_cur_block = tl.load(K_block_ptr)
|
||||
v_cur_block = tl.load(V_block_ptr)
|
||||
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
|
||||
# use block size of the paged/blocked kv cache
|
||||
S_ij = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
|
||||
|
||||
alibi_slope = tl.load(alibi_slopes + cur_head_idx)
|
||||
position_k_offset = block_start_kv * BLOCK_KV + tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
# NOTE a trick to come across triton's requirement that values in both first and second input shapes must be >= 16,
|
||||
# Multiplying two tensors with shapes [1, d] * [d, block_size] will fail.
|
||||
# Refer to https://github.com/openai/triton/discussions/895
|
||||
S_ij += tl.sum(q[None, :] * k_cur_block, 1)
|
||||
S_ij *= sm_scale
|
||||
S_ij -= alibi_slope * (cur_kv_seq_len - 1 - position_k_offset)
|
||||
S_ij = tl.where(cur_kv_seq_len > position_k_offset, S_ij, float("-inf"))
|
||||
|
||||
m = tl.max(S_ij, 0)
|
||||
S_ij -= m
|
||||
p_ij_hat = tl.exp(S_ij)
|
||||
l = tl.sum(p_ij_hat, 0)
|
||||
p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty)
|
||||
acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0)
|
||||
acc = acc / l
|
||||
|
||||
offsets_mid_o = (
|
||||
cur_token_idx * stride_mid_ot
|
||||
+ cur_head_idx * stride_mid_oh
|
||||
+ block_start_kv * stride_mid_ob
|
||||
+ offsets_dmodel * stride_mid_od
|
||||
)
|
||||
tl.store(mid_o + offsets_mid_o, acc)
|
||||
offsets_mid_o_lse = (
|
||||
cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
|
||||
)
|
||||
# logsumexp L^(j) = m^(j) + log(l^(j))
|
||||
tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l))
|
||||
|
||||
|
||||
# Triton 2.1.0
|
||||
@triton.jit
|
||||
def _flash_decoding_fwd_reduce_kernel(
|
||||
@@ -197,9 +320,10 @@ def flash_decoding_attention(
|
||||
output: torch.Tensor = None,
|
||||
mid_output: torch.Tensor = None,
|
||||
mid_output_lse: torch.Tensor = None,
|
||||
alibi_slopes: torch.Tensor = None,
|
||||
sm_scale: int = None,
|
||||
kv_group_num: int = 1,
|
||||
q_len: int = 1,
|
||||
q_len: int = 1, # NOTE alibi flash decoding does not support q_len > 1 at this moment.
|
||||
):
|
||||
"""
|
||||
Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.
|
||||
@@ -220,6 +344,7 @@ def flash_decoding_attention(
|
||||
mid_output_lse (torch.Tensor): [max_bsz * q_len, num_heads, kv_max_split_num]
|
||||
Log-sum-exp of intermediate output. `max_bsz` should be greater than or equal to `bsz`.
|
||||
q_len > 1 only for verification process in speculative-decoding.
|
||||
alibi_slopes (torch.Tensor): [num_heads] alibi slopes used for alibi flash decoding.
|
||||
block_size (int): Size of each block in the blocked key/value cache.
|
||||
num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
|
||||
q_length (int): Query length. Use for speculative decoding when `q_length` > 1 (i.e. the last n tokens).
|
||||
@@ -280,38 +405,74 @@ def flash_decoding_attention(
|
||||
num_heads,
|
||||
triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV),
|
||||
)
|
||||
_flash_decoding_fwd_kernel[grid](
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_tables,
|
||||
mid_output,
|
||||
mid_output_lse,
|
||||
kv_seq_len,
|
||||
q_len,
|
||||
bsz,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k_cache.stride(0),
|
||||
k_cache.stride(1),
|
||||
k_cache.stride(2),
|
||||
k_cache.stride(3),
|
||||
block_tables.stride(0),
|
||||
block_tables.stride(1),
|
||||
mid_output.stride(0),
|
||||
mid_output.stride(1),
|
||||
mid_output.stride(2),
|
||||
mid_output.stride(3),
|
||||
mid_output_lse.stride(0),
|
||||
mid_output_lse.stride(1),
|
||||
mid_output_lse.stride(2),
|
||||
sm_scale,
|
||||
KV_GROUPS=kv_group_num,
|
||||
BLOCK_KV=block_size,
|
||||
BLOCK_SIZE=block_size,
|
||||
HEAD_DIM=head_dim,
|
||||
)
|
||||
|
||||
if alibi_slopes is not None:
|
||||
_alibi_flash_decoding_fwd_kernel[grid](
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_tables,
|
||||
mid_output,
|
||||
mid_output_lse,
|
||||
kv_seq_len,
|
||||
q_len,
|
||||
bsz,
|
||||
alibi_slopes,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k_cache.stride(0),
|
||||
k_cache.stride(1),
|
||||
k_cache.stride(2),
|
||||
k_cache.stride(3),
|
||||
block_tables.stride(0),
|
||||
block_tables.stride(1),
|
||||
mid_output.stride(0),
|
||||
mid_output.stride(1),
|
||||
mid_output.stride(2),
|
||||
mid_output.stride(3),
|
||||
mid_output_lse.stride(0),
|
||||
mid_output_lse.stride(1),
|
||||
mid_output_lse.stride(2),
|
||||
sm_scale,
|
||||
KV_GROUPS=kv_group_num,
|
||||
BLOCK_KV=block_size,
|
||||
BLOCK_SIZE=block_size,
|
||||
HEAD_DIM=head_dim,
|
||||
)
|
||||
else:
|
||||
_flash_decoding_fwd_kernel[grid](
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_tables,
|
||||
mid_output,
|
||||
mid_output_lse,
|
||||
kv_seq_len,
|
||||
q_len,
|
||||
bsz,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k_cache.stride(0),
|
||||
k_cache.stride(1),
|
||||
k_cache.stride(2),
|
||||
k_cache.stride(3),
|
||||
block_tables.stride(0),
|
||||
block_tables.stride(1),
|
||||
mid_output.stride(0),
|
||||
mid_output.stride(1),
|
||||
mid_output.stride(2),
|
||||
mid_output.stride(3),
|
||||
mid_output_lse.stride(0),
|
||||
mid_output_lse.stride(1),
|
||||
mid_output_lse.stride(2),
|
||||
sm_scale,
|
||||
KV_GROUPS=kv_group_num,
|
||||
BLOCK_KV=block_size,
|
||||
BLOCK_SIZE=block_size,
|
||||
HEAD_DIM=head_dim,
|
||||
)
|
||||
|
||||
grid = (triton.next_power_of_2(bsz * q_len), num_heads)
|
||||
_flash_decoding_fwd_reduce_kernel[grid](
|
||||
|
Reference in New Issue
Block a user