mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-25 23:11:55 +00:00
[NFC] Fix code factors on inference triton kernels (#5743)
This commit is contained in:
parent
c2c8c9cf17
commit
bd38fe6b91
@ -111,10 +111,10 @@ def _flash_decoding_fwd_kernel(
|
|||||||
m = tl.max(S_ij, 0)
|
m = tl.max(S_ij, 0)
|
||||||
S_ij -= m
|
S_ij -= m
|
||||||
p_ij_hat = tl.exp(S_ij)
|
p_ij_hat = tl.exp(S_ij)
|
||||||
l = tl.sum(p_ij_hat, 0)
|
l_i = tl.sum(p_ij_hat, 0)
|
||||||
p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty)
|
p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty)
|
||||||
acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0)
|
acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0)
|
||||||
acc = acc / l
|
acc = acc / l_i
|
||||||
|
|
||||||
offsets_mid_o = (
|
offsets_mid_o = (
|
||||||
cur_token_idx * stride_mid_ot
|
cur_token_idx * stride_mid_ot
|
||||||
@ -126,8 +126,8 @@ def _flash_decoding_fwd_kernel(
|
|||||||
offsets_mid_o_lse = (
|
offsets_mid_o_lse = (
|
||||||
cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
|
cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
|
||||||
)
|
)
|
||||||
# logsumexp L^(j) = m^(j) + log(l^(j))
|
# logsumexp l_i^(j) = m^(j) + log(l_i^(j))
|
||||||
tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l))
|
tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l_i))
|
||||||
|
|
||||||
|
|
||||||
# Triton 2.1.0
|
# Triton 2.1.0
|
||||||
@ -234,10 +234,10 @@ def _alibi_flash_decoding_fwd_kernel(
|
|||||||
m = tl.max(S_ij, 0)
|
m = tl.max(S_ij, 0)
|
||||||
S_ij -= m
|
S_ij -= m
|
||||||
p_ij_hat = tl.exp(S_ij)
|
p_ij_hat = tl.exp(S_ij)
|
||||||
l = tl.sum(p_ij_hat, 0)
|
l_i = tl.sum(p_ij_hat, 0)
|
||||||
p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty)
|
p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty)
|
||||||
acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0)
|
acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0)
|
||||||
acc = acc / l
|
acc = acc / l_i
|
||||||
|
|
||||||
offsets_mid_o = (
|
offsets_mid_o = (
|
||||||
cur_token_idx * stride_mid_ot
|
cur_token_idx * stride_mid_ot
|
||||||
@ -249,8 +249,8 @@ def _alibi_flash_decoding_fwd_kernel(
|
|||||||
offsets_mid_o_lse = (
|
offsets_mid_o_lse = (
|
||||||
cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
|
cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
|
||||||
)
|
)
|
||||||
# logsumexp L^(j) = m^(j) + log(l^(j))
|
# logsumexp l_i^(j) = m^(j) + log(l_i^(j))
|
||||||
tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l))
|
tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l_i))
|
||||||
|
|
||||||
|
|
||||||
# Triton 2.1.0
|
# Triton 2.1.0
|
||||||
@ -290,7 +290,7 @@ def _flash_decoding_fwd_reduce_kernel(
|
|||||||
# BLOCK_KV == BLOCK_SIZE for now. We might want to decrease the number of blocks of kv splitted.
|
# BLOCK_KV == BLOCK_SIZE for now. We might want to decrease the number of blocks of kv splitted.
|
||||||
kv_split_num = (cur_kv_seq_len + BLOCK_KV - 1) // BLOCK_KV
|
kv_split_num = (cur_kv_seq_len + BLOCK_KV - 1) // BLOCK_KV
|
||||||
m_i = float("-inf") # max logic
|
m_i = float("-inf") # max logic
|
||||||
l = 0.0 # sum exp
|
l_i = 0.0 # sum exp
|
||||||
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
|
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
|
||||||
|
|
||||||
offsets_mid_o = cur_token_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel
|
offsets_mid_o = cur_token_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel
|
||||||
@ -304,10 +304,10 @@ def _flash_decoding_fwd_reduce_kernel(
|
|||||||
lse -= m_ij
|
lse -= m_ij
|
||||||
exp_logic = tl.exp(lse)
|
exp_logic = tl.exp(lse)
|
||||||
acc += exp_logic * mid_o_block
|
acc += exp_logic * mid_o_block
|
||||||
l = scale * l + exp_logic
|
l_i = scale * l_i + exp_logic
|
||||||
m_i = m_ij
|
m_i = m_ij
|
||||||
|
|
||||||
acc = acc / l
|
acc = acc / l_i
|
||||||
offsets_O = cur_token_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel
|
offsets_O = cur_token_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel
|
||||||
tl.store(O + offsets_O, acc.to(O.type.element_ty))
|
tl.store(O + offsets_O, acc.to(O.type.element_ty))
|
||||||
return
|
return
|
||||||
|
Loading…
Reference in New Issue
Block a user