[Inference]Adapt to baichuan2 13B (#5614)

* adapt to baichuan2 13B

* adapt to baichuan2 13B

* change BAICHUAN_MODEL_NAME_OR_PATH

* fix test_decoding_attn.py

* Modifications based on review comments.

* change BAICHUAN_MODEL_NAME_OR_PATH

* mv attn mask processes to test flash decoding

* mv get_alibi_slopes baichuan modeling

* fix bugs in test_baichuan.py
This commit is contained in:
yuehuayingxueluo
2024-04-25 23:11:30 +08:00
committed by GitHub
parent f342a93871
commit 3c91e3f176
10 changed files with 786 additions and 134 deletions

View File

@@ -124,6 +124,129 @@ def _flash_decoding_fwd_kernel(
tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l))
# Triton 2.1.0
@triton.jit
def _alibi_flash_decoding_fwd_kernel(
Q, # [batch_size * q_len, head_num, head_dim]
KCache, # [num_blocks, num_kv_heads, block_size, head_dim]
VCache, # [num_blocks, num_kv_heads, block_size, head_dim]
block_tables, # [batch_size, max_blocks_per_sequence]
mid_o, # [batch_size * q_len, head_num, kv_split_num, head_dim]
mid_o_lse, # [batch_size * q_len, head_num, kv_split_num]
kv_seq_len, # [batch_size]
q_len,
batch_size,
alibi_slopes,
stride_qt,
stride_qh,
stride_qd,
stride_cacheb,
stride_cacheh,
stride_cachebs,
stride_cached,
stride_bts,
stride_btb,
stride_mid_ot,
stride_mid_oh,
stride_mid_ob,
stride_mid_od,
stride_mid_o_lset,
stride_mid_o_lseh,
stride_mid_o_lseb,
sm_scale,
KV_GROUPS: tl.constexpr,
BLOCK_KV: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
HEAD_DIM: tl.constexpr,
):
cur_token_idx = tl.program_id(0)
cur_seq_idx = cur_token_idx // q_len
if cur_seq_idx >= batch_size:
return
cur_token_off = (cur_token_idx % q_len) - q_len + 1
cur_head_idx = tl.program_id(1)
block_start_kv = tl.program_id(2) # for splitting k/v
# NOTE It requires BLOCK_KV and BLOCK_SIZE to be the same
# TODO might want to replace with BLOCK_KV % BLOCK_SIZE == 0 (optimize BLOCK_KV as multiple of BLOCK_SIZE)
# and then support calculating multiple kv cache blocks on an instance
tl.static_assert(BLOCK_KV == BLOCK_SIZE)
# get the current (kv) sequence length
# cur_token_off is used as a "mask" here for spec-dec during verification process
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off
if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
return
offsets_dmodel = tl.arange(0, HEAD_DIM)
offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
q = tl.load(Q + offsets_q)
# block table for the current sequence
block_table_ptr = block_tables + cur_seq_idx * stride_bts
# cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE)
# cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb)
cur_block_id = tl.load(block_table_ptr + block_start_kv * stride_btb)
cur_occupied_size = tl.where(
(block_start_kv + 1) * BLOCK_SIZE <= cur_kv_seq_len, BLOCK_SIZE, cur_kv_seq_len - block_start_kv * BLOCK_SIZE
)
tl.device_assert(cur_occupied_size >= 0)
cur_kv_head_idx = cur_head_idx // KV_GROUPS
offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
K_block_ptr = tl.make_block_ptr(
base=KCache + offset_kvcache,
shape=(cur_occupied_size, HEAD_DIM),
strides=(stride_cachebs, stride_cached),
offsets=(0, 0),
block_shape=(BLOCK_SIZE, HEAD_DIM),
order=(0, 1),
)
V_block_ptr = tl.make_block_ptr(
base=VCache + offset_kvcache,
shape=(cur_occupied_size, HEAD_DIM),
strides=(stride_cachebs, stride_cached),
offsets=(0, 0),
block_shape=(BLOCK_SIZE, HEAD_DIM),
order=(0, 1),
)
k_cur_block = tl.load(K_block_ptr)
v_cur_block = tl.load(V_block_ptr)
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
# use block size of the paged/blocked kv cache
S_ij = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
alibi_slope = tl.load(alibi_slopes + cur_head_idx)
position_k_offset = block_start_kv * BLOCK_KV + tl.arange(0, BLOCK_SIZE)
# NOTE a trick to come across triton's requirement that values in both first and second input shapes must be >= 16,
# Multiplying two tensors with shapes [1, d] * [d, block_size] will fail.
# Refer to https://github.com/openai/triton/discussions/895
S_ij += tl.sum(q[None, :] * k_cur_block, 1)
S_ij *= sm_scale
S_ij -= alibi_slope * (cur_kv_seq_len - 1 - position_k_offset)
S_ij = tl.where(cur_kv_seq_len > position_k_offset, S_ij, float("-inf"))
m = tl.max(S_ij, 0)
S_ij -= m
p_ij_hat = tl.exp(S_ij)
l = tl.sum(p_ij_hat, 0)
p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty)
acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0)
acc = acc / l
offsets_mid_o = (
cur_token_idx * stride_mid_ot
+ cur_head_idx * stride_mid_oh
+ block_start_kv * stride_mid_ob
+ offsets_dmodel * stride_mid_od
)
tl.store(mid_o + offsets_mid_o, acc)
offsets_mid_o_lse = (
cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
)
# logsumexp L^(j) = m^(j) + log(l^(j))
tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l))
# Triton 2.1.0
@triton.jit
def _flash_decoding_fwd_reduce_kernel(
@@ -197,9 +320,10 @@ def flash_decoding_attention(
output: torch.Tensor = None,
mid_output: torch.Tensor = None,
mid_output_lse: torch.Tensor = None,
alibi_slopes: torch.Tensor = None,
sm_scale: int = None,
kv_group_num: int = 1,
q_len: int = 1,
q_len: int = 1, # NOTE alibi flash decoding does not support q_len > 1 at this moment.
):
"""
Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.
@@ -220,6 +344,7 @@ def flash_decoding_attention(
mid_output_lse (torch.Tensor): [max_bsz * q_len, num_heads, kv_max_split_num]
Log-sum-exp of intermediate output. `max_bsz` should be greater than or equal to `bsz`.
q_len > 1 only for verification process in speculative-decoding.
alibi_slopes (torch.Tensor): [num_heads] alibi slopes used for alibi flash decoding.
block_size (int): Size of each block in the blocked key/value cache.
num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
q_length (int): Query length. Use for speculative decoding when `q_length` > 1 (i.e. the last n tokens).
@@ -280,38 +405,74 @@ def flash_decoding_attention(
num_heads,
triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV),
)
_flash_decoding_fwd_kernel[grid](
q,
k_cache,
v_cache,
block_tables,
mid_output,
mid_output_lse,
kv_seq_len,
q_len,
bsz,
q.stride(0),
q.stride(1),
q.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
block_tables.stride(0),
block_tables.stride(1),
mid_output.stride(0),
mid_output.stride(1),
mid_output.stride(2),
mid_output.stride(3),
mid_output_lse.stride(0),
mid_output_lse.stride(1),
mid_output_lse.stride(2),
sm_scale,
KV_GROUPS=kv_group_num,
BLOCK_KV=block_size,
BLOCK_SIZE=block_size,
HEAD_DIM=head_dim,
)
if alibi_slopes is not None:
_alibi_flash_decoding_fwd_kernel[grid](
q,
k_cache,
v_cache,
block_tables,
mid_output,
mid_output_lse,
kv_seq_len,
q_len,
bsz,
alibi_slopes,
q.stride(0),
q.stride(1),
q.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
block_tables.stride(0),
block_tables.stride(1),
mid_output.stride(0),
mid_output.stride(1),
mid_output.stride(2),
mid_output.stride(3),
mid_output_lse.stride(0),
mid_output_lse.stride(1),
mid_output_lse.stride(2),
sm_scale,
KV_GROUPS=kv_group_num,
BLOCK_KV=block_size,
BLOCK_SIZE=block_size,
HEAD_DIM=head_dim,
)
else:
_flash_decoding_fwd_kernel[grid](
q,
k_cache,
v_cache,
block_tables,
mid_output,
mid_output_lse,
kv_seq_len,
q_len,
bsz,
q.stride(0),
q.stride(1),
q.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
block_tables.stride(0),
block_tables.stride(1),
mid_output.stride(0),
mid_output.stride(1),
mid_output.stride(2),
mid_output.stride(3),
mid_output_lse.stride(0),
mid_output_lse.stride(1),
mid_output_lse.stride(2),
sm_scale,
KV_GROUPS=kv_group_num,
BLOCK_KV=block_size,
BLOCK_SIZE=block_size,
HEAD_DIM=head_dim,
)
grid = (triton.next_power_of_2(bsz * q_len), num_heads)
_flash_decoding_fwd_reduce_kernel[grid](