From 28052a71fb76745ee861f1aa14884a7495c9590f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cuiqing=20Li=20=28=E6=9D=8E=E5=B4=94=E5=8D=BF=29?=
Date: Thu, 16 Nov 2023 16:43:15 +0800
Subject: [PATCH] [Kernels] Update triton kernels to 2.1.0 (#5046)

* update flash-context-attention
* adding kernels
* fix
* reset
* add build script
* add building process
* add llama2 example
* add colossal-llama2 test
* clean
* fall back test setting
* fix test file
* clean
* clean
* clean

---------

Co-authored-by: cuiqing.li
---
 colossalai/inference/README.md                |  17 +-
 colossalai/inference/build.sh                 |  24 ++
 .../tensor_parallel/modeling/llama.py         |  46 +--
 colossalai/kernel/triton/context_attention.py | 371 ++++++++++++------
 .../kernel/triton/token_attention_kernel.py   |  20 +-
 examples/inference/bench_llama.py             |   1 -
 examples/inference/colossal_llama2_demo.py    |  81 ++++
 requirements/requirements-test.txt            |   2 +-
 .../triton/test_llama_context_attention.py    |   1 -
 9 files changed, 392 insertions(+), 171 deletions(-)
 create mode 100644 colossalai/inference/build.sh
 create mode 100644 examples/inference/colossal_llama2_demo.py

diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md
index cf5dbf245..ce9b6658b 100644
--- a/colossalai/inference/README.md
+++ b/colossalai/inference/README.md
@@ -69,11 +69,11 @@ cd lightllm
 git checkout 28c1267cfca536b7b4f28e921e03de735b003039
 pip3 install -e .
-# also, install xformers from source:
-pip install ninja
-# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types
-pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
+# install flash-attention
+git clone --recursive https://github.com/Dao-AILab/flash-attention
+cd flash-attention
+pip install -e .
 ```

 ### Docker
@@ -95,10 +95,11 @@ cd lightllm
 git checkout 28c1267cfca536b7b4f28e921e03de735b003039
 pip3 install -e .
-# install xformers from source
-pip install ninja
-# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types
-pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
+# install flash-attention
+git clone --recursive https://github.com/Dao-AILab/flash-attention
+cd flash-attention
+pip install -e .
+
 ```

 ### Dive into fast-inference!
diff --git a/colossalai/inference/build.sh b/colossalai/inference/build.sh
new file mode 100644
index 000000000..6a73f6f0b
--- /dev/null
+++ b/colossalai/inference/build.sh
@@ -0,0 +1,24 @@
+#!/usr/bin/env bash
+
+# install triton
+pip install triton
+pip install transformers
+
+# install lightllm and flash-attention
+mkdir 3rdParty
+cd 3rdParty
+git clone https://github.com/ModelTC/lightllm
+cd lightllm
+git checkout 28c1267cfca536b7b4f28e921e03de735b003039
+pip install -e .
+cd ..
+
+git clone --recursive https://github.com/Dao-AILab/flash-attention
+cd flash-attention
+pip install -e .
+ +cd ../../ + + + + diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 62c2aad3c..448943b12 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -8,15 +8,10 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecode from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards - from ._utils import copy_kv_to_mem_cache - try: - from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd as lightllm_llama2_context_attention_fwd, - ) from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd as lightllm_context_attention_fwd, + context_attention_fwd as lightllm_llama_context_attention_fwd, ) from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd @@ -56,32 +51,20 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): def llama_triton_context_attention( query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1 ): - if num_key_value_groups == 1: - if HAS_LIGHTLLM_KERNEL is False: - llama_context_attn_fwd( - query_states, - key_states, - value_states, - attn_output, - infer_state.start_loc, - infer_state.seq_len, - # infer_state.cache_manager.past_key_values_length, - infer_state.max_len_in_batch, - ) - else: - lightllm_context_attention_fwd( - query_states, - key_states, - value_states, - attn_output, - infer_state.start_loc, - infer_state.seq_len, - # infer_state.cache_manager.past_key_values_length, - infer_state.max_len_in_batch, - ) + # if num_key_value_groups == 1: + if HAS_LIGHTLLM_KERNEL is False: + llama_context_attn_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) else: - assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run llama2 model" - lightllm_llama2_context_attention_fwd( + lightllm_llama_context_attention_fwd( query_states, key_states, value_states, @@ -107,6 +90,7 @@ def llama_triton_token_attention(query_states, attn_output, infer_state, num_key # infer_state.cache_manager.past_key_values_length, infer_state.max_len_in_batch, ) + else: Llama2TokenAttentionForwards.token_attn( query_states, diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py index 5ce6f2c21..1ad7a80eb 100644 --- a/colossalai/kernel/triton/context_attention.py +++ b/colossalai/kernel/triton/context_attention.py @@ -15,127 +15,223 @@ if HAS_TRITON: this function is modified from https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10 """ + if triton.__version__ < "2.1.0": + @triton.jit + def _context_flash_attention_kernel( + Q, + K, + V, + sm_scale, + B_Start_Loc, + B_Seqlen, + TMP, + alibi_ptr, + Out, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_tmp_b, + stride_tmp_h, + stride_tmp_s, + # suggtest set-up 64, 128, 256, 512 + 
BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + batch_id = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) - @triton.jit - def _context_flash_attention_kernel( - Q, - K, - V, - sm_scale, - B_Start_Loc, - B_Seqlen, - TMP, - alibi_ptr, - Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_tmp_b, - stride_tmp_h, - stride_tmp_s, - # suggtest set-up 64, 128, 256, 512 - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - batch_id = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # get batch info + cur_batch_seq_len = tl.load(B_Seqlen + batch_id) + cur_batch_start_index = tl.load(B_Start_Loc + batch_id) + block_start_loc = BLOCK_M * start_m - # get batch info - cur_batch_seq_len = tl.load(B_Seqlen + batch_id) - cur_batch_start_index = tl.load(B_Start_Loc + batch_id) - block_start_loc = BLOCK_M * start_m - - load_p_ptrs = ( - Q - + (cur_batch_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd - v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd - t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s - - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - if alibi_ptr is not None: - alibi_m = tl.load(alibi_ptr + cur_head) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - k = tl.load( - k_ptrs + (cur_batch_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, - other=0.0, + load_p_ptrs = ( + Q + + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd ) + q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale + k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd + v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd + t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) if alibi_ptr is not None: - alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) - qk -= alibi_loc * alibi_m + alibi_m = tl.load(alibi_ptr + cur_head) - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - m_ij = tl.max(qk, 1) - p = tl.exp(qk - 
m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - tl.store(t_ptrs, acc_scale) - acc_scale = tl.load(t_ptrs) - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + k = tl.load( + k_ptrs + (cur_batch_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + other=0.0, + ) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + + if alibi_ptr is not None: + alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) + qk -= alibi_loc * alibi_m + + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + tl.store(t_ptrs, acc_scale) + acc_scale = tl.load(t_ptrs) + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0, + ) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_o = ( + (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od ) + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + return + else: + @triton.jit + def _context_flash_attention_kernel_2( + Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen, + Out, + kv_group_num, + stride_qbs, stride_qh, stride_qd, + stride_kbs, stride_kh, stride_kd, + stride_vbs, stride_vh, stride_vd, + stride_obs, stride_oh, stride_od, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + if kv_group_num is not None: + cur_kv_head = cur_head // kv_group_num - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - off_o = ( - (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - return + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd + if kv_group_num is None or kv_group_num == 1: + off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * 
stride_kd + off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd + else: + off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + + q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + + k_ptrs = K + off_k + v_ptrs = V + off_v + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + if Alibi is not None: + alibi_m = tl.load(Alibi + cur_head) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + + if Alibi is not None: + alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) + qk -= alibi_loc * alibi_m + + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + return @torch.no_grad() def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None): @@ -152,10 +248,9 @@ if HAS_TRITON: grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) num_warps = 4 if Lk <= 64 else 8 - - tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) if triton.__version__ < "2.1.0": + tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) _context_flash_attention_kernel[grid]( q, k, @@ -189,7 +284,28 @@ if HAS_TRITON: num_stages=1, ) else: - raise Exception("Please install lightllm kernels from https://github.com/ModelTC/lightllm since your triton version is larger than 2.0.0") + _context_flash_attention_kernel_2[grid]( + q, k, v, sm_scale, alibi, b_start_loc, b_seq_len, + o, + None, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) return @@ -220,7 +336,7 @@ if HAS_TRITON: b_start_loc, b_seq_len, tmp, - None, + None, o, q.stride(0), q.stride(1), @@ -244,6 +360,33 @@ if HAS_TRITON: num_stages=1, ) else: - raise Exception("Please install 
lightllm kernels from https://github.com/ModelTC/lightllm since your triton version is larger than 2.0.0") + kv_group_num = q.shape[1] // k.shape[1] + _context_flash_attention_kernel_2[grid]( + q, + k, + v, + sm_scale, + None, + b_start_loc, + b_seq_len, + o, + kv_group_num, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1,) return \ No newline at end of file diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py index 8dc919bad..de2003748 100644 --- a/colossalai/kernel/triton/token_attention_kernel.py +++ b/colossalai/kernel/triton/token_attention_kernel.py @@ -13,17 +13,7 @@ except ImportError: print("please install triton from https://github.com/openai/triton") try: - from lightllm.models.llama2.triton_kernel.token_attention_nopad_att1 import ( - token_att_fwd as lightllm_llama2_token_att_fwd, - ) - from lightllm.models.llama2.triton_kernel.token_attention_nopad_reduceV import ( - token_att_fwd2 as lightllm_llama2_token_att_fwd2, - ) - from lightllm.models.llama2.triton_kernel.token_attention_nopad_softmax import ( - token_softmax_fwd as lightllm_llama2_token_softmax_fwd, - ) - - from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fw2 + from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fwd2 from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_llama_token_att_fwd from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd as lightllm_llama_token_softmax_fwd from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_bloom_token_att_fwd @@ -72,7 +62,7 @@ if HAS_TRITON: lightllm_llama_token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) att_m_tensor = None - lightllm_llama_token_att_fw2( + lightllm_llama_token_att_fwd2( prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch ) prob = None @@ -203,7 +193,7 @@ class Llama2TokenAttentionForwards: calcu_shape1 = (batch_size, head_num, head_dim) att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") - lightllm_llama2_token_att_fwd( + lightllm_llama_token_att_fwd( q, k, att_m_tensor, @@ -215,12 +205,12 @@ class Llama2TokenAttentionForwards: if triton.__version__ == "2.0.0": prob = torch.empty_like(att_m_tensor) - lightllm_llama2_token_softmax_fwd( + lightllm_llama_token_softmax_fwd( att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch ) att_m_tensor = None - lightllm_llama2_token_att_fwd2( + lightllm_llama_token_att_fwd2( prob, v, attn_out.view(calcu_shape1), diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py index 4db32c71a..c6eb3a5c6 100644 --- a/examples/inference/bench_llama.py +++ b/examples/inference/bench_llama.py @@ -28,7 +28,6 @@ def run_llama_test(args): tokenizer.pad_token_id = tokenizer.unk_token_id model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) model = model.half() - model.config shard_config = ShardConfig( enable_tensor_parallelism=True if args.tp_size > 1 else False, 
extra_kwargs={"inference_only": True} diff --git a/examples/inference/colossal_llama2_demo.py b/examples/inference/colossal_llama2_demo.py new file mode 100644 index 000000000..72abab2a4 --- /dev/null +++ b/examples/inference/colossal_llama2_demo.py @@ -0,0 +1,81 @@ +import os +import warnings + +import torch +import torch.distributed as dist +import argparse +from packaging import version + +import colossalai +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from transformers import AutoModelForCausalLM, AutoTokenizer + + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +TPSIZE = 1 +BATCH_SIZE = 4 +MAX_INPUT_LEN = 32 +MAX_OUTPUT_LEN = 128 + +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5') + + +@parameterize('test_config', [{ + 'tp_size': TPSIZE, +}]) +def run_llama_test(test_config, args): + + model_path = args.path + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + tokenizer.pad_token_id = tokenizer.unk_token_id + model = AutoModelForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id) + model = model.half() + + text = ["Introduce London.", "What is the genus of Poodle?"] + input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True) + + print(input_ids) + + shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False, + extra_kwargs={"inference_only": True}) + infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + + generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) + outputs = infer_engine.generate(input_ids, **generate_kwargs) + + assert outputs is not None + + if not dist.is_initialized() or dist.get_rank() == 0: + for o in outputs: + output_text = tokenizer.decode(o) + print(output_text) + + +def check_llama(rank, world_size, port, args): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_llama_test(args=args) + + +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama(args): + spawn(check_llama, args.tp_size, args=args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-p", "--path", type=str, default = "hpcai-tech/Colossal-LLaMA-2-7b-base", help="Model path") + parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("-b", "--batch_size", type=int, default=32, help="Maximum batch size") + parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length") + parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") + parser.add_argument( + "--test_mode", type=str, help="Test mode", default="e2e_test", choices=["e2e_test", "decoder_test"] + ) + args = parser.parse_args() + test_llama(args) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index f54b13c7e..61b58055e 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -12,7 +12,7 @@ torchx-nightly==2022.6.29 # torchrec 0.2.0 requires torchx-nightly. 
This package torchrec==0.2.0 contexttimer einops -triton==2.0.0.dev20221202 +triton==2.1.0 requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611 SentencePiece ninja diff --git a/tests/test_infer_ops/triton/test_llama_context_attention.py b/tests/test_infer_ops/triton/test_llama_context_attention.py index be6de6db2..95fe50cf1 100644 --- a/tests/test_infer_ops/triton/test_llama_context_attention.py +++ b/tests/test_infer_ops/triton/test_llama_context_attention.py @@ -41,7 +41,6 @@ def test_llama_context_attention(): llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len) torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim) - assert torch.allclose( torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3 ), "outputs from triton and torch are not matched"
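
Note: the context-attention kernels patched above compute causal attention over a packed batch, where all sequences are concatenated along the token axis and located via `b_start_loc`/`b_seq_len`. As a reading aid, here is a minimal PyTorch sketch of that computation under the layout the kernels assume (`[total_tokens, num_heads, head_dim]`, scale `1/sqrt(head_dim)`, equal numbers of query and key/value heads). The function name is hypothetical and this is not the `torch_context_attention` reference used in the test above; it is only meant to make the kernels' semantics explicit.

```python
import math

import torch


def naive_context_attention(q, k, v, b_start_loc, b_seq_len):
    """Causal attention over a packed batch of sequences.

    q, k, v: [total_tokens, num_heads, head_dim] tensors (assumed layout);
    b_start_loc, b_seq_len: int tensors of shape [batch].
    """
    out = torch.empty_like(q)
    scale = 1.0 / math.sqrt(q.shape[-1])  # corresponds to sm_scale in the kernels
    for start, seq_len in zip(b_start_loc.tolist(), b_seq_len.tolist()):
        end = start + seq_len
        # reshape to [num_heads, seq_len, head_dim] for batched matmul
        qs = q[start:end].transpose(0, 1)
        ks = k[start:end].transpose(0, 1)
        vs = v[start:end].transpose(0, 1)
        scores = torch.matmul(qs, ks.transpose(-1, -2)) * scale
        causal = torch.ones(seq_len, seq_len, device=q.device).tril().bool()
        scores = scores.masked_fill(~causal, float("-inf"))
        probs = torch.softmax(scores.float(), dim=-1).to(vs.dtype)
        out[start:end] = torch.matmul(probs, vs).transpose(0, 1)
    return out
```

The Triton kernels produce the same result block by block, using the online-softmax rescaling (`m_i`, `l_i`, `acc`) so the full `seq_len x seq_len` score matrix never has to be materialized.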