[inference] add reference and fix some bugs (#4937)

* add reference and fix some bugs

* update gptq init

---------

Co-authored-by: Xu Kai <xukai16@foxamil.com>
Xu Kai authored 2023-10-20 13:39:34 +08:00, committed by GitHub
parent b8e770c832
commit 785802e809
7 changed files with 24 additions and 10 deletions


@@ -267,6 +267,7 @@ def cai_gptq_matmul_248_kernel(
    tl.store(c_ptrs, accumulator, mask=c_mask)

# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ
@autotune(
    configs=[
        triton.Config(

@@ -13,10 +13,10 @@ except ImportError:
if HAS_TRITON:
    """
    this function is modified from
    https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
    these functions are modified from https://github.com/ModelTC/lightllm
    """

    # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
    @triton.jit
    def _context_flash_attention_kernel(
        Q,
@@ -145,20 +145,16 @@ if HAS_TRITON:
        tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
        return

    @torch.no_grad()
    def smooth_llama_context_attn_fwd(
        q, k, v, o, q_input_scale, k_input_scale, v_input_scale, pv_output_scale, b_start_loc, b_seq_len, max_input_len
    ):
        BLOCK = 128
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk, "context process only supports equal query, key, value length"
        assert Lk == Lv, "context process only supports equal query, key, value length"
        assert Lk in {16, 32, 64, 128}
        BLOCK_N = 128
        sm_scale = 1.0 / math.sqrt(Lk)
        batch, head = b_seq_len.shape[0], q.shape[1]
        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
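For reference, a minimal sketch of how this entry point might be called, assuming a lightllm-style packed layout where q, k, v, and o are flattened to (total_tokens, num_heads, head_dim) and b_start_loc / b_seq_len mark where each sequence starts and how long it is; the dtypes, scale values, and shapes below are illustrative assumptions, not taken from this commit:

import torch

# Assumed shapes: packed (total_tokens, num_heads, head_dim) layout (assumption, not from this commit).
batch, max_input_len, num_heads, head_dim = 2, 128, 8, 64
total_tokens = batch * max_input_len

# int8 activations with placeholder per-tensor scales (assumption for the smooth-quant path).
q = torch.randint(-128, 127, (total_tokens, num_heads, head_dim), dtype=torch.int8, device="cuda")
k = torch.randint(-128, 127, (total_tokens, num_heads, head_dim), dtype=torch.int8, device="cuda")
v = torch.randint(-128, 127, (total_tokens, num_heads, head_dim), dtype=torch.int8, device="cuda")
o = torch.empty(total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")

b_seq_len = torch.full((batch,), max_input_len, dtype=torch.int32, device="cuda")
b_start_loc = torch.arange(batch, dtype=torch.int32, device="cuda") * max_input_len

smooth_llama_context_attn_fwd(
    q, k, v, o,
    0.02, 0.02, 0.02, 0.02,  # q/k/v input scales and pv output scale (placeholder values)
    b_start_loc, b_seq_len, max_input_len,
)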
@@ -203,6 +199,7 @@ if HAS_TRITON:
        )
        return

    # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
    @triton.jit
    def _token_attn_1_kernel(
        Q,
@@ -264,6 +261,7 @@ if HAS_TRITON:
        tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
        return

    # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
    @triton.jit
    def _token_attn_1_alibi_kernel(
        Q,
@@ -413,6 +411,7 @@ if HAS_TRITON:
        )
        return

    # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
    @triton.jit
    def _token_attn_softmax_fwd(
        softmax_logics,
@@ -479,6 +478,7 @@ if HAS_TRITON:
        )
        return

    # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
    @triton.jit
    def _token_attn_2_kernel(
        Prob,