[inference] add reference and fix some bugs (#4937)
* add reference and fix some bugs
* update gptq init
---------
Co-authored-by: Xu Kai <xukai16@foxamil.com>
@@ -267,6 +267,7 @@ def cai_gptq_matmul_248_kernel(
    tl.store(c_ptrs, accumulator, mask=c_mask)


# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ
@autotune(
    configs=[
        triton.Config(
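For readers unfamiliar with the autotune decorator in this hunk, the sketch below uses the standard triton.autotune API on a throwaway element-wise kernel, purely for illustration. The kernel name, block sizes, and num_warps values are made up and are not ColossalAI's settings; the point is that each triton.Config supplies compile-time meta-parameters plus launch settings, and Triton benchmarks the listed configs whenever the arguments named in key change.

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 256}, num_warps=8),
    ],
    key=["n_elements"],  # re-benchmark the configs when this argument changes
)
@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = out.numel()
    # BLOCK_SIZE is injected by the chosen config, so the grid reads it from META
    grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),)
    _add_kernel[grid](x, y, out, n_elements)
    return out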
@@ -13,10 +13,10 @@ except ImportError:

if HAS_TRITON:
    """
    this function is modified from
    https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
    this functions are modified from https://github.com/ModelTC/lightllm
    """

    # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
    @triton.jit
    def _context_flash_attention_kernel(
        Q,
@@ -145,20 +145,16 @@ if HAS_TRITON:
        tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
        return


    @torch.no_grad()
    def smooth_llama_context_attn_fwd(
        q, k, v, o, q_input_scale, k_input_scale, v_input_scale, pv_output_scale, b_start_loc, b_seq_len, max_input_len
    ):

        BLOCK = 128
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk, "context process only supports equal query, key, value length"
        assert Lk == Lv, "context process only supports equal query, key, value length"
        assert Lk in {16, 32, 64, 128}
        BLOCK_N = 128
        sm_scale = 1.0 / math.sqrt(Lk)
        batch, head = b_seq_len.shape[0], q.shape[1]
        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
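A condensed sketch of the launch-side bookkeeping that smooth_llama_context_attn_fwd performs in the hunk above, assuming q, k and v are laid out as (total_tokens, heads, head_dim); only the checks and grid computation are shown, not the kernel call itself, and the helper name _context_attn_launch_setup is hypothetical.

import math

import triton

def _context_attn_launch_setup(q, k, v, b_seq_len, max_input_len, BLOCK=128):
    # head dims of query, key and value must match and fit a supported block size
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk == Lv, "context process only supports equal query, key, value length"
    assert Lk in {16, 32, 64, 128}
    sm_scale = 1.0 / math.sqrt(Lk)  # 1/sqrt(d) softmax scaling factor
    batch, head = b_seq_len.shape[0], q.shape[1]
    # one Triton program per (sequence, head, block of query positions)
    grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
    return grid, sm_scale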
@@ -203,6 +199,7 @@ if HAS_TRITON:
        )
        return

    # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
    @triton.jit
    def _token_attn_1_kernel(
        Q,
@@ -264,6 +261,7 @@ if HAS_TRITON:
        tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
        return

    # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
    @triton.jit
    def _token_attn_1_alibi_kernel(
        Q,
@@ -413,6 +411,7 @@ if HAS_TRITON:
        )
        return

    # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
    @triton.jit
    def _token_attn_softmax_fwd(
        softmax_logics,
@@ -479,6 +478,7 @@ if HAS_TRITON:
        )
        return

    # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
    @triton.jit
    def _token_attn_2_kernel(
        Prob,
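The three kernels referenced in the last hunks (_token_attn_1_kernel, _token_attn_softmax_fwd, _token_attn_2_kernel) appear to split single-token decode attention into score computation, softmax, and value reduction. Below is a plain PyTorch reference of that three-stage decomposition, with shapes assumed as q: (heads, head_dim) and k_cache, v_cache: (seq_len, heads, head_dim); it is an interpretation for illustration, not the Triton code itself.

import torch

def token_attention_reference(q, k_cache, v_cache, sm_scale):
    # stage 1 (_token_attn_1_kernel): scaled dot products of the new query with cached keys
    logits = torch.einsum("hd,shd->hs", q, k_cache) * sm_scale
    # stage 2 (_token_attn_softmax_fwd): softmax over the sequence axis
    probs = torch.softmax(logits, dim=-1)
    # stage 3 (_token_attn_2_kernel): probability-weighted sum of the cached values
    return torch.einsum("hs,shd->hd", probs, v_cache)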