From 2bb92243d4151873d75a9d6d9c2275b390e1716a Mon Sep 17 00:00:00 2001
From: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Date: Tue, 5 Dec 2023 15:12:57 +0800
Subject: [PATCH] [Inference/NFC] Clean outdated inference tests and deprecated kernels (#5159)

* [inference/nfc] remove outdated inference tests
* remove outdated kernel tests
* remove deprecated triton kernels
* remove imports from deprecated kernels
---
 colossalai/kernel/triton/__init__.py          |  12 -
 colossalai/kernel/triton/context_attention.py | 393 -----
 .../kernel/triton/copy_kv_cache_dest.py       |  71 --
 colossalai/kernel/triton/flash_decoding.py    |  50 --
 .../triton/int8_rotary_embedding_kernel.py    | 117 ----
 .../kernel/triton/self_attention_nofusion.py  | 164 -----
 colossalai/kernel/triton/smooth_attention.py  | 652 ------
 .../kernel/triton/token_attention_kernel.py   | 238 -------
 tests/test_infer/test_hybrid_bloom.py         | 121 ----
 tests/test_infer/test_hybrid_chatglm2.py      | 129 ----
 tests/test_infer/test_hybrid_llama.py         | 126 ----
 tests/test_infer/test_kvcache_manager.py      |  66 --
 .../triton/test_bloom_context_attention.py    |  52 --
 .../triton/test_copy_kv_dest.py               |  39 --
 .../triton/test_llama_context_attention.py    |  50 --
 .../triton/test_self_attention_nonfusion.py   | 143 ----
 .../triton/test_token_attn_fwd.py             |  72 --
 .../triton/test_token_softmax.py              |  48 --
 18 files changed, 2543 deletions(-)
 delete mode 100644 colossalai/kernel/triton/context_attention.py
 delete mode 100644 colossalai/kernel/triton/copy_kv_cache_dest.py
 delete mode 100644 colossalai/kernel/triton/flash_decoding.py
 delete mode 100644 colossalai/kernel/triton/int8_rotary_embedding_kernel.py
 delete mode 100644 colossalai/kernel/triton/self_attention_nofusion.py
 delete mode 100644 colossalai/kernel/triton/smooth_attention.py
 delete mode 100644 colossalai/kernel/triton/token_attention_kernel.py
 delete mode 100644 tests/test_infer/test_hybrid_bloom.py
 delete mode 100644 tests/test_infer/test_hybrid_chatglm2.py
 delete mode 100644 tests/test_infer/test_hybrid_llama.py
 delete mode 100644 tests/test_infer/test_kvcache_manager.py
 delete mode 100644 tests/test_infer_ops/triton/test_bloom_context_attention.py
 delete mode 100644 tests/test_infer_ops/triton/test_copy_kv_dest.py
 delete mode 100644 tests/test_infer_ops/triton/test_llama_context_attention.py
 delete mode 100644 tests/test_infer_ops/triton/test_self_attention_nonfusion.py
 delete mode 100644 tests/test_infer_ops/triton/test_token_attn_fwd.py
 delete mode 100644 tests/test_infer_ops/triton/test_token_softmax.py

diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py
index 20da71d39..85c4d911b 100644
--- a/colossalai/kernel/triton/__init__.py
+++ b/colossalai/kernel/triton/__init__.py
@@ -8,24 +8,12 @@ except ImportError:
     # There may exist import error even if we have triton installed.
if HAS_TRITON: - from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd - from .copy_kv_cache_dest import copy_kv_cache_to_dest from .fused_layernorm import layer_norm from .gptq_triton import gptq_fused_linear_triton - from .int8_rotary_embedding_kernel import int8_rotary_embedding_fwd - from .smooth_attention import smooth_llama_context_attn_fwd, smooth_token_attention_fwd from .softmax import softmax - from .token_attention_kernel import token_attention_fwd __all__ = [ - "llama_context_attn_fwd", - "bloom_context_attn_fwd", "softmax", "layer_norm", - "copy_kv_cache_to_dest", - "token_attention_fwd", "gptq_fused_linear_triton", - "int8_rotary_embedding_fwd", - "smooth_llama_context_attn_fwd", - "smooth_token_attention_fwd", ] diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py deleted file mode 100644 index 3d9a23d2f..000000000 --- a/colossalai/kernel/triton/context_attention.py +++ /dev/null @@ -1,393 +0,0 @@ -import math - -import torch - -try: - import triton - import triton.language as tl - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -if HAS_TRITON: - """ - this function is modified from - https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10 - """ - if triton.__version__ < "2.1.0": - @triton.jit - def _context_flash_attention_kernel( - Q, - K, - V, - sm_scale, - B_Start_Loc, - B_Seqlen, - TMP, - alibi_ptr, - Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_tmp_b, - stride_tmp_h, - stride_tmp_s, - # suggtest set-up 64, 128, 256, 512 - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - batch_id = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - - # get batch info - cur_batch_seq_len = tl.load(B_Seqlen + batch_id) - cur_batch_start_index = tl.load(B_Start_Loc + batch_id) - block_start_loc = BLOCK_M * start_m - - load_p_ptrs = ( - Q - + (cur_batch_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd - v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd - t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s - - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - if alibi_ptr is not None: - alibi_m = tl.load(alibi_ptr + cur_head) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - k = tl.load( - k_ptrs + (cur_batch_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, - other=0.0, - ) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - - if alibi_ptr is not 
None: - alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) - qk -= alibi_loc * alibi_m - - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - tl.store(t_ptrs, acc_scale) - acc_scale = tl.load(t_ptrs) - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_o = ( - (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - return - else: - # this function is modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L11 - @triton.jit - def _context_flash_attention_kernel_2( - Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen, - Out, - kv_group_num, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - stride_vbs, stride_vh, stride_vd, - stride_obs, stride_oh, stride_od, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - if kv_group_num is not None: - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd - if kv_group_num is None or kv_group_num == 1: - off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd - else: - off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - - q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K + off_k - v_ptrs = V + off_v - - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - if Alibi is not None: - alibi_m = tl.load(Alibi + cur_head) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - - if Alibi is not None: - alibi_loc = 
offs_m[:, None] - (start_n + offs_n[None, :]) - qk -= alibi_loc * alibi_m - - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - return - - @torch.no_grad() - def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None): - BLOCK = 128 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk, "context process only supports equal query, key, value length" - assert Lk == Lv, "context process only supports equal query, key, value length" - assert Lk in {16, 32, 64, 128} - - sm_scale = 1.0 / math.sqrt(Lk) - batch, head = b_seq_len.shape[0], q.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) - - num_warps = 4 if Lk <= 64 else 8 - - if triton.__version__ < "2.1.0": - tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) - _context_flash_attention_kernel[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - tmp, - alibi, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - tmp.stride(0), - tmp.stride(1), - tmp.stride(2), - # manually setting this blcok num, we can use tuning config to futher speed-up - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - else: - _context_flash_attention_kernel_2[grid]( - q, k, v, sm_scale, alibi, b_start_loc, b_seq_len, - o, - None, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - - return - - @torch.no_grad() - def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): - BLOCK = 128 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk, "context process only supports equal query, key, value length" - assert Lk == Lv, "context process only supports equal query, key, value length" - assert Lk in {16, 32, 64, 128} - - sm_scale = 1.0 / math.sqrt(Lk) - batch, head = b_seq_len.shape[0], q.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) - - tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) - num_warps = 4 if Lk <= 64 else 8 - # num_warps = 4 - - if triton.__version__ < "2.1.0": - _context_flash_attention_kernel[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - 
b_seq_len, - tmp, - None, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - tmp.stride(0), - tmp.stride(1), - tmp.stride(2), - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - else: - kv_group_num = q.shape[1] // k.shape[1] - _context_flash_attention_kernel_2[grid]( - q, - k, - v, - sm_scale, - None, - b_start_loc, - b_seq_len, - o, - kv_group_num, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1,) - - return \ No newline at end of file diff --git a/colossalai/kernel/triton/copy_kv_cache_dest.py b/colossalai/kernel/triton/copy_kv_cache_dest.py deleted file mode 100644 index b8e6ab1d0..000000000 --- a/colossalai/kernel/triton/copy_kv_cache_dest.py +++ /dev/null @@ -1,71 +0,0 @@ -import torch - -try: - import triton - import triton.language as tl - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -if HAS_TRITON: - # adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py - @triton.jit - def _fwd_copy_kv_cache_dest( - kv_cache_ptr, - dest_index_ptr, - out, - stride_k_bs, - stride_k_h, - stride_k_d, - stride_o_bs, - stride_o_h, - stride_o_d, - head_num, - BLOCK_DMODEL: tl.constexpr, - BLOCK_HEAD: tl.constexpr, - ): - cur_index = tl.program_id(0) - offs_h = tl.arange(0, BLOCK_HEAD) - offs_d = tl.arange(0, BLOCK_DMODEL) - - dest_index = tl.load(dest_index_ptr + cur_index) - - cache_offsets = stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :] - k_ptrs = kv_cache_ptr + cur_index * stride_k_bs + cache_offsets - - o_offsets = stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :] - o_ptrs = out + dest_index * stride_o_bs + o_offsets - - k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0) - tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num) - return - - # adepted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py - @torch.no_grad() - def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out): - seq_len = dest_index_ptr.shape[0] - head_num = k_ptr.shape[1] - head_dim = k_ptr.shape[2] - assert head_num == out.shape[1], "head_num should be the same for k_ptr and out" - assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out" - - num_warps = 2 - _fwd_copy_kv_cache_dest[(seq_len,)]( - k_ptr, - dest_index_ptr, - out, - k_ptr.stride(0), - k_ptr.stride(1), - k_ptr.stride(2), - out.stride(0), - out.stride(1), - out.stride(2), - head_num, - BLOCK_DMODEL=head_dim, - BLOCK_HEAD=triton.next_power_of_2(head_num), - num_warps=num_warps, - num_stages=2, - ) - return diff --git a/colossalai/kernel/triton/flash_decoding.py b/colossalai/kernel/triton/flash_decoding.py deleted file mode 100644 index 9b7b27fa1..000000000 --- a/colossalai/kernel/triton/flash_decoding.py +++ /dev/null @@ -1,50 +0,0 @@ -# adepted from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8/lightllm/models/llama/triton_kernel/flash_decoding.py -import torch -try: - from 
lightllm.models.llama.triton_kernel.flash_decoding_stage1 import flash_decode_stage1 - from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2 - HAS_LIGHTLLM_KERNEL = True -except: - print("install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8") - HAS_LIGHTLLM_KERNEL = False - - -if HAS_LIGHTLLM_KERNEL: - def token_flash_decoding(q, o_tensor, infer_state, q_head_num, head_dim, cache_k, cache_v): - BLOCK_SEQ = 256 - batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch - - - calcu_shape1 = (batch_size, q_head_num, head_dim) - - if getattr(infer_state, 'mid_o', None) is None: - infer_state.mid_o = torch.empty([batch_size, - q_head_num, - max_len_in_batch // BLOCK_SEQ + 1, - head_dim], - dtype=torch.float32, - device="cuda") - infer_state.mid_o_logexpsum = torch.empty([batch_size, - q_head_num, - max_len_in_batch // BLOCK_SEQ + 1], - dtype=torch.float32, - device="cuda") - - mid_o = infer_state.mid_o - mid_o_logexpsum = infer_state.mid_o_logexpsum - - flash_decode_stage1(q.view(calcu_shape1), - cache_k, - cache_v, - infer_state.block_loc, - infer_state.seq_len, - infer_state.max_len_in_batch, - mid_o, - mid_o_logexpsum, - BLOCK_SEQ) - flash_decode_stage2(mid_o, - mid_o_logexpsum, - infer_state.seq_len, - o_tensor.view(calcu_shape1), - BLOCK_SEQ) diff --git a/colossalai/kernel/triton/int8_rotary_embedding_kernel.py b/colossalai/kernel/triton/int8_rotary_embedding_kernel.py deleted file mode 100644 index 537dd164d..000000000 --- a/colossalai/kernel/triton/int8_rotary_embedding_kernel.py +++ /dev/null @@ -1,117 +0,0 @@ -# Adapted from ModelTC https://github.com/ModelTC/lightllm -import torch -import triton -import triton.language as tl - - -@triton.jit -def _rotary_kernel( - q, - input_scale, - output_scale, - Cos, - Sin, - q_bs_stride, - q_h_stride, - q_d_stride, - cos_bs_stride, - cos_d_stride, - total_len, - HEAD_NUM: tl.constexpr, - BLOCK_HEAD: tl.constexpr, - BLOCK_SEQ: tl.constexpr, - HEAD_DIM: tl.constexpr, -): - current_head_index = tl.program_id(0) - current_seq_index = tl.program_id(1) - - dim_range0 = tl.arange(0, HEAD_DIM // 2) - dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) - - current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) - current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) - - off_q0 = ( - current_seq_range[:, None, None] * q_bs_stride - + current_head_range[None, :, None] * q_h_stride - + dim_range0[None, None, :] * q_d_stride - ) - off_q1 = ( - current_seq_range[:, None, None] * q_bs_stride - + current_head_range[None, :, None] * q_h_stride - + dim_range1[None, None, :] * q_d_stride - ) - - off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride - - q0 = tl.load( - q + off_q0, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - other=0.0, - ) - q1 = tl.load( - q + off_q1, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - other=0.0, - ) - - cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) - sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) - - q0 = q0.to(tl.float32) * input_scale - q1 = q1.to(tl.float32) * input_scale - - out0 = (q0 * cos - q1 * sin) / output_scale - out1 = (q0 * sin + q1 * cos) / output_scale - - out0 = out0.to(tl.int8) - out1 = 
out1.to(tl.int8) - - tl.store( - q + off_q0, - out0, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - ) - tl.store( - q + off_q1, - out1, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - ) - - return - - -@torch.no_grad() -def int8_rotary_embedding_fwd(q, cos, sin, input_scale, output_scale): - total_len = q.shape[0] - head_num = q.shape[1] - head_dim = q.shape[2] - assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" - BLOCK_HEAD = 4 - BLOCK_SEQ = 32 - grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) - if head_dim >= 128: - num_warps = 8 - else: - num_warps = 4 - - _rotary_kernel[grid]( - q, - input_scale, - output_scale, - cos, - sin, - q.stride(0), - q.stride(1), - q.stride(2), - cos.stride(0), - cos.stride(1), - total_len, - HEAD_NUM=head_num, - BLOCK_HEAD=BLOCK_HEAD, - BLOCK_SEQ=BLOCK_SEQ, - HEAD_DIM=head_dim, - num_warps=num_warps, - num_stages=1, - ) - return diff --git a/colossalai/kernel/triton/self_attention_nofusion.py b/colossalai/kernel/triton/self_attention_nofusion.py deleted file mode 100644 index 50d6786bd..000000000 --- a/colossalai/kernel/triton/self_attention_nofusion.py +++ /dev/null @@ -1,164 +0,0 @@ -import torch - -try: - import triton - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -if HAS_TRITON: - from .qkv_matmul_kernel import qkv_gemm_4d_kernel - from .softmax import softmax_kernel - - # adpeted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/transformer/inference/triton/triton_matmul_kernel.py#L312 - def self_attention_forward_without_fusion( - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float - ): - r"""A function to do QKV Attention calculation by calling GEMM and softmax triton kernels - Args: - q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) - k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) - v (torch.Tensor): V embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) - input_mask (torch.Tensor): mask for softmax layer, shape should be (batch, num_heads, seq_lem, seq_len) - scale: the float scale value which is used to multiply with Q*K^T before doing softmax - - Return: - output (Torch.Tensor): The output shape is (batch, seq_len, num_heads, head_size) - """ - assert len(q.shape) == 4, "the shape of q val must be 4" - batches, M, H, K = q.shape - assert q.shape == k.shape, "the shape of q and the shape of k must be equal" - assert q.shape == v.shape, "the shape of q and the shape of v must be equal" - assert q.shape[-1] == k.shape[-1], "the last dimension of q and k must be equal" - - N = k.shape[1] - - # head_size * num_of_head - d_model = q.shape[-1] * q.shape[-2] - - score_output = torch.empty((batches, H, M, N), device=q.device, dtype=q.dtype) - - grid = lambda meta: ( - batches, - H, - triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), - ) - - qkv_gemm_4d_kernel[grid]( - q, - k, - score_output, - M, - N, - K, - q.stride(0), - q.stride(2), - q.stride(1), - q.stride(3), - k.stride(0), - k.stride(2), - k.stride(3), - k.stride(1), - score_output.stride(0), - score_output.stride(1), - score_output.stride(2), - score_output.stride(3), - 
scale=scale, - # currently manually setting, later on we can use auto-tune config to match best setting - BLOCK_SIZE_M=64, - BLOCK_SIZE_N=32, - BLOCK_SIZE_K=32, - GROUP_SIZE_M=8, - ) - - softmax_output = torch.empty(score_output.shape, device=score_output.device, dtype=score_output.dtype) - score_output_shape = score_output.shape - - score_output = score_output.view(-1, score_output.shape[-1]) - n_rows, n_cols = score_output.shape - - if n_rows <= 350000: - block_size = max(triton.next_power_of_2(n_cols), 2) - num_warps = 4 - if block_size >= 4096: - num_warps = 16 - elif block_size >= 2048: - num_warps = 8 - else: - num_warps = 4 - - softmax_kernel[(n_rows,)]( - softmax_output, - score_output, - score_output.stride(0), - n_cols, - mask_ptr=input_mask, - num_warps=num_warps, - BLOCK_SIZE=block_size, - ) - - else: - # NOTE: change softmax kernel functions to make it suitable for large size dimension - softmax_output = torch.nn.functional.softmax(score_output, dim=-1) - softmax_output = softmax_output.view(*score_output_shape) - - batches, H, M, K = softmax_output.shape - N = v.shape[-1] - - output = torch.empty((batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype) - - grid = lambda meta: ( - batches, - H, - triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), - ) - - qkv_gemm_4d_kernel[grid]( - softmax_output, - v, - output, - M, - N, - K, - softmax_output.stride(0), - softmax_output.stride(1), - softmax_output.stride(2), - softmax_output.stride(3), - v.stride(0), - v.stride(2), - v.stride(1), - v.stride(3), - output.stride(0), - output.stride(2), - output.stride(1), - output.stride(3), - BLOCK_SIZE_M=128, - BLOCK_SIZE_N=64, - BLOCK_SIZE_K=64, - GROUP_SIZE_M=8, - scale=-1, - ) - return output.view(batches, -1, d_model) - - # modified from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/transformer/inference/triton/attention.py#L212 - def self_attention_compute_using_triton( - qkv, input_mask, layer_past, alibi, scale, head_size, triangular=False, use_flash=False - ): - assert qkv.is_contiguous() - assert alibi is None, "current triton self-attention does not support alibi" - batches = qkv.shape[0] - d_model = qkv.shape[-1] // 3 - num_of_heads = d_model // head_size - - q = qkv[:, :, :d_model] - k = qkv[:, :, d_model : d_model * 2] - v = qkv[:, :, d_model * 2 :] - q = q.view(batches, -1, num_of_heads, head_size) - k = k.view(batches, -1, num_of_heads, head_size) - v = v.view(batches, -1, num_of_heads, head_size) - - data_output_triton = self_attention_forward_without_fusion(q, k, v, input_mask, scale) - - return data_output_triton diff --git a/colossalai/kernel/triton/smooth_attention.py b/colossalai/kernel/triton/smooth_attention.py deleted file mode 100644 index 071de58e2..000000000 --- a/colossalai/kernel/triton/smooth_attention.py +++ /dev/null @@ -1,652 +0,0 @@ -import math - -import torch - -try: - import triton - import triton.language as tl - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -if HAS_TRITON: - """ - this functions are modified from https://github.com/ModelTC/lightllm - """ - - # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py - @triton.jit - def _context_flash_attention_kernel( - Q, - K, - V, - q_input_scale, - k_input_scale, - v_input_scale, - pv_output_scale, - sm_scale, - B_Start_Loc, - B_Seqlen, - TMP, - alibi_ptr, - Out, - stride_qbs, - 
stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_tmp_b, - stride_tmp_h, - stride_tmp_s, - # suggtest set-up 64, 128, 256, 512 - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - batch_id = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - - # get batch info - cur_batch_seq_len = tl.load(B_Seqlen + batch_id) - cur_batch_start_index = tl.load(B_Start_Loc + batch_id) - block_start_loc = BLOCK_M * start_m - - load_p_ptrs = ( - Q - + (cur_batch_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - q = q.to(tl.float16) * q_input_scale.to(tl.float16) - - k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd - v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd - t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s - - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - if alibi_ptr is not None: - alibi_m = tl.load(alibi_ptr + cur_head) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - k = tl.load( - k_ptrs + (cur_batch_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, - other=0.0, - ) - k = k.to(tl.float16) * k_input_scale.to(tl.float16) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - - if alibi_ptr is not None: - alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) - qk -= alibi_loc * alibi_m - - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - tl.store(t_ptrs, acc_scale) - acc_scale = tl.load(t_ptrs) - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - - v = v.to(tl.float16) * v_input_scale.to(tl.float16) - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - acc = (acc / pv_output_scale.to(tl.float16)).to(tl.int8) - off_o = ( - (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - return - - @torch.no_grad() - def smooth_llama_context_attn_fwd( - q, k, v, o, q_input_scale, k_input_scale, v_input_scale, pv_output_scale, b_start_loc, b_seq_len, max_input_len - ): - BLOCK = 128 - # shape constraints - Lq, Lk, Lv = q.shape[-1], 
k.shape[-1], v.shape[-1] - assert Lq == Lk, "context process only supports equal query, key, value length" - assert Lk == Lv, "context process only supports equal query, key, value length" - assert Lk in {16, 32, 64, 128} - sm_scale = 1.0 / math.sqrt(Lk) - batch, head = b_seq_len.shape[0], q.shape[1] - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) - - tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) - num_warps = 4 if Lk <= 64 else 8 - - _context_flash_attention_kernel[grid]( - q, - k, - v, - q_input_scale, - k_input_scale, - v_input_scale, - pv_output_scale, - sm_scale, - b_start_loc, - b_seq_len, - tmp, - None, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - tmp.stride(0), - tmp.stride(1), - tmp.stride(2), - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py - @triton.jit - def _token_attn_1_kernel( - Q, - K, - q_input_scale, - k_input_scale, - sm_scale, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc_b_stride, - kv_cache_loc_s_stride, - q_batch_stride, - q_head_stride, - q_head_dim_stride, - k_batch_stride, - k_head_stride, - k_head_dim_stride, - attn_head_stride, - attn_batch_stride, - HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - start_n = tl.program_id(2) - - offs_d = tl.arange(0, HEAD_DIM) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - current_batch_start_index = max_kv_cache_len - current_batch_seq_len - current_batch_end_index = max_kv_cache_len - - off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride - - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - - block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) - - for start_mark in range(0, block_mask, 1): - q = tl.load(Q + off_q + start_mark) - q = q.to(tl.float16) * q_input_scale.to(tl.float16) - offs_n_new = current_batch_start_index + offs_n - k_loc = tl.load( - kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, - mask=offs_n_new < current_batch_end_index, - other=0, - ) - off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride - k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) - k = k.to(tl.float16) * k_input_scale.to(tl.float16) - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride - tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) - return - - # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py - @triton.jit - def _token_attn_1_alibi_kernel( - Q, - K, - q_input_scale, - k_input_scale, - sm_scale, - alibi, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc_b_stride, - kv_cache_loc_s_stride, - q_batch_stride, - 
q_head_stride, - q_head_dim_stride, - k_batch_stride, - k_head_stride, - k_head_dim_stride, - attn_head_stride, - attn_batch_stride, - HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - start_n = tl.program_id(2) - - offs_d = tl.arange(0, HEAD_DIM) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - current_batch_start_index = max_kv_cache_len - current_batch_seq_len - current_batch_end_index = max_kv_cache_len - - off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride - - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - - block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) - - for start_mark in range(0, block_mask, 1): - alibi_m = tl.load(alibi + current_head) - q = tl.load(Q + off_q + start_mark) - q = q.to(tl.float16) * q_input_scale.to(tl.float16) - - offs_n_new = current_batch_start_index + offs_n - k_loc = tl.load( - kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, - mask=offs_n_new < current_batch_end_index, - other=0, - ) - off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride - k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) - k = k.to(tl.float16) * k_input_scale.to(tl.float16) - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n) - off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride - tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) - return - - @torch.no_grad() - def token_attn_fwd_1( - q, - k, - attn_out, - q_input_scale, - k_input_scale, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - alibi=None, - ): - BLOCK = 32 - # shape constraints - q_head_dim, k_head_dim = q.shape[-1], k.shape[-1] - assert q_head_dim == k_head_dim - assert k_head_dim in {16, 32, 64, 128} - sm_scale = 1.0 / (k_head_dim**0.5) - - batch, head_num = kv_cache_loc.shape[0], q.shape[1] - - grid = (batch, head_num, triton.cdiv(max_kv_cache_len, BLOCK)) - - num_warps = 4 if k_head_dim <= 64 else 8 - num_warps = 2 - - if alibi is not None: - _token_attn_1_alibi_kernel[grid]( - q, - k, - q_input_scale, - k_input_scale, - sm_scale, - alibi, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc.stride(0), - kv_cache_loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - attn_out.stride(0), - attn_out.stride(1), - HEAD_DIM=k_head_dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - else: - _token_attn_1_kernel[grid]( - q, - k, - q_input_scale, - k_input_scale, - sm_scale, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc.stride(0), - kv_cache_loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - attn_out.stride(0), - attn_out.stride(1), - HEAD_DIM=k_head_dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py - @triton.jit - def 
_token_attn_softmax_fwd( - softmax_logics, - kv_cache_start_loc, - kv_cache_seqlen, - softmax_prob_out, - logics_head_dim_stride, - logics_batch_stride, - prob_head_dim_stride, - prob_batch_stride, - BLOCK_SIZE: tl.constexpr, - ): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - - col_offsets = tl.arange(0, BLOCK_SIZE) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - row = tl.load( - softmax_logics - + current_head * logics_head_dim_stride - + (current_batch_in_all_start_index + col_offsets) * logics_batch_stride, - mask=col_offsets < current_batch_seq_len, - other=-float("inf"), - ).to(tl.float32) - - row_minus_max = row - tl.max(row, axis=0) - numerator = tl.exp(row_minus_max) - denominator = tl.sum(numerator, axis=0) - softmax_output = numerator / denominator - - tl.store( - softmax_prob_out - + current_head * prob_head_dim_stride - + (current_batch_in_all_start_index + col_offsets) * prob_batch_stride, - softmax_output, - mask=col_offsets < current_batch_seq_len, - ) - return - - @torch.no_grad() - def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, max_kv_cache_len): - BLOCK_SIZE = triton.next_power_of_2(max_kv_cache_len) - batch, head_num = kv_cache_start_loc.shape[0], softmax_logics.shape[0] - - num_warps = 4 - if BLOCK_SIZE >= 2048: - num_warps = 8 - if BLOCK_SIZE >= 4096: - num_warps = 16 - - _token_attn_softmax_fwd[(batch, head_num)]( - softmax_logics, - kv_cache_start_loc, - kv_cache_seqlen, - softmax_prob_out, - softmax_logics.stride(0), - softmax_logics.stride(1), - softmax_prob_out.stride(0), - softmax_prob_out.stride(1), - num_warps=num_warps, - BLOCK_SIZE=BLOCK_SIZE, - ) - return - - # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py - @triton.jit - def _token_attn_2_kernel( - Prob, - V, - attn_out, - v_input_scale, - pv_output_scale, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - kv_cache_loc_b_stride, - kv_cache_loc_s_stride, - prob_head_dim_stride, - prob_batch_stride, - v_batch_stride, - v_head_stride, - v_head_dim_stride, - attn_out_batch_stride, - attn_out_head_stride, - attn_out_head_dim_stride, - HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, HEAD_DIM) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_start_index = max_kv_cache_len - current_batch_seq_len - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride - p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride - v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride - - acc = tl.zeros([HEAD_DIM], dtype=tl.float32) - for start_n in range(0, current_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - p_value = tl.load( - Prob + p_offs + start_n * kv_cache_loc_s_stride, - mask=(start_n + offs_n) < current_batch_seq_len, - other=0.0, - ) - v_loc = tl.load( - kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride, - mask=(start_n + offs_n) < current_batch_seq_len, - other=0.0, - ) - v_value = tl.load( - V + v_offs + v_loc[:, None] * 
v_batch_stride, - mask=(start_n + offs_n[:, None]) < current_batch_seq_len, - other=0.0, - ) - v_value = v_value.to(tl.float16) * v_input_scale.to(tl.float16) - acc += tl.sum(p_value[:, None] * v_value, 0) - - acc = (acc / pv_output_scale.to(tl.float16)).to(tl.int8) - off_o = ( - current_batch * attn_out_batch_stride - + current_head * attn_out_head_stride - + offs_d * attn_out_head_dim_stride - ) - out_ptrs = attn_out + off_o - tl.store(out_ptrs, acc) - return - - @torch.no_grad() - def token_attn_fwd_2( - prob, - v, - attn_out, - v_input_scale, - pv_output_scale, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - ): - if triton.__version__ >= "2.1.0": - BLOCK = 128 - else: - BLOCK = 64 - batch, head = kv_cache_loc.shape[0], v.shape[1] - grid = (batch, head) - num_warps = 4 - dim = v.shape[-1] - - _token_attn_2_kernel[grid]( - prob, - v, - attn_out, - v_input_scale, - pv_output_scale, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - kv_cache_loc.stride(0), - kv_cache_loc.stride(1), - prob.stride(0), - prob.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - attn_out.stride(0), - attn_out.stride(1), - attn_out.stride(2), - HEAD_DIM=dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - @torch.no_grad() - def smooth_token_attention_fwd( - q, - k, - v, - attn_out, - q_input_scale, - k_input_scale, - v_input_scale, - pv_output_scale, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - alibi=None, - ): - head_num = k.shape[1] - batch_size = kv_cache_seq_len.shape[0] - calcu_shape1 = (batch_size, head_num, k.shape[2]) - total_token_num = k.shape[0] - - att_m_tensor = torch.empty((head_num, total_token_num), dtype=torch.float32, device="cuda") - - token_attn_fwd_1( - q.view(calcu_shape1), - k, - att_m_tensor, - q_input_scale, - k_input_scale, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - alibi=alibi, - ) - - prob = torch.empty_like(att_m_tensor) - - token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) - att_m_tensor = None - token_attn_fwd_2( - prob, - v, - attn_out.view(calcu_shape1), - v_input_scale, - pv_output_scale, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - ) - - prob = None - - return diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py deleted file mode 100644 index de2003748..000000000 --- a/colossalai/kernel/triton/token_attention_kernel.py +++ /dev/null @@ -1,238 +0,0 @@ -# Adapted from ModelTC https://github.com/ModelTC/lightllm - - -import torch - -try: - import triton - import triton.language as tl - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -try: - from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fwd2 - from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_llama_token_att_fwd - from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd as lightllm_llama_token_softmax_fwd - from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_bloom_token_att_fwd - - HAS_TRITON_TOKEN_ATTENTION = True -except ImportError: - print("unable to import lightllm kernels") - HAS_TRITON_TOKEN_ATTENTION = False - -if 
HAS_TRITON: - - @torch.no_grad() - def token_attention_fwd( - q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, alibi=None - ): - head_num = k.shape[1] - batch_size = kv_cache_seq_len.shape[0] - calcu_shape1 = (batch_size, head_num, k.shape[2]) - total_token_num = k.shape[0] - - att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") - - if alibi is None: - lightllm_llama_token_att_fwd( - q.view(calcu_shape1), - k, - att_m_tensor, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - ) - else: - lightllm_bloom_token_att_fwd( - q.view(calcu_shape1), - k, - att_m_tensor, - alibi, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - ) - - prob = torch.empty_like(att_m_tensor) - - lightllm_llama_token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) - att_m_tensor = None - lightllm_llama_token_att_fwd2( - prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch - ) - prob = None - return - - -class Llama2TokenAttentionForwards: - @staticmethod - @triton.jit - - # this function is adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/models/llama2/triton_kernel/token_attention_nopad_softmax.py#L8 - def _fwd_kernel( - Logics, - V, - Out, - B_Loc, - B_Start_Loc, - B_Seqlen, - max_input_len, - stride_logic_h, - stride_logic_bs, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_b_loc_b, - stride_b_loc_s, - other_kv_index, # avoid nan information - kv_group_num, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch) - - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - - off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - off_b_loc = cur_batch * stride_b_loc_b + (max_input_len - cur_batch_seq_len) * stride_b_loc_s - - v_ptrs = V + off_v - - e_max = float("-inf") - e_sum = 0.0 - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, cur_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - v_index = tl.load( - B_Loc + off_b_loc + (start_n + offs_n) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_seq_len, - other=other_kv_index, - ) - - qk = tl.load( - Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs, - mask=start_n + offs_n < cur_batch_seq_len, - other=float("-inf"), - ) - - n_e_max = tl.maximum(tl.max(qk, 0), e_max) - old_scale = tl.exp(e_max - n_e_max) - p = tl.exp(qk - n_e_max) - e_sum = e_sum * old_scale + tl.sum(p, 0) - v = tl.load(v_ptrs + v_index[:, None] * stride_vbs) - acc = acc * old_scale + tl.sum(p[:, None] * v, 0) - e_max = n_e_max - - acc = acc / e_sum - off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od - out_ptrs = Out + off_o - tl.store(out_ptrs, acc) - return - - # this function is adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/models/llama2/triton_kernel/token_attention_nopad_softmax.py#L36 - @staticmethod - @torch.no_grad() - def token_softmax_reducev_fwd(logics, v, o, b_loc, b_start_loc, b_seq_len, max_input_len, other_kv_index): - BLOCK = 64 - batch, head = 
b_seq_len.shape[0], logics.shape[0] - grid = (batch, head) - kv_group_num = logics.shape[0] // v.shape[1] - - num_warps = 1 - Llama2TokenAttentionForwards._fwd_kernel[grid]( - logics, - v, - o, - b_loc, - b_start_loc, - b_seq_len, - max_input_len, - logics.stride(0), - logics.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - b_loc.stride(0), - b_loc.stride(1), - other_kv_index, - kv_group_num, - BLOCK_DMODEL=v.shape[-1], - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=3, - ) - return - - # this is the interface of llama2 attn forward - @staticmethod - @torch.no_grad() - def token_attn( - q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, other_kv_index - ): - total_token_num = k.shape[0] - batch_size, head_num, head_dim = q.shape - calcu_shape1 = (batch_size, head_num, head_dim) - att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") - - lightllm_llama_token_att_fwd( - q, - k, - att_m_tensor, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - ) - - if triton.__version__ == "2.0.0": - prob = torch.empty_like(att_m_tensor) - lightllm_llama_token_softmax_fwd( - att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch - ) - att_m_tensor = None - - lightllm_llama_token_att_fwd2( - prob, - v, - attn_out.view(calcu_shape1), - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - ) - - prob = None - return - - elif triton.__version__ >= "2.1.0": - Llama2TokenAttentionForwards.token_softmax_reducev_fwd( - att_m_tensor, - v, - attn_out.view(calcu_shape1), - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - other_kv_index, - ) - else: - raise Exception("not support triton version") diff --git a/tests/test_infer/test_hybrid_bloom.py b/tests/test_infer/test_hybrid_bloom.py deleted file mode 100644 index 8cad06dca..000000000 --- a/tests/test_infer/test_hybrid_bloom.py +++ /dev/null @@ -1,121 +0,0 @@ -import importlib.util - -import pytest -import torch -import torch.distributed as dist -import transformers -from packaging import version - -import colossalai -from colossalai.inference import InferenceEngine -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn - -CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") -HAS_LIGHTLLM_KERNEL = True - -if importlib.util.find_spec("lightllm") is None: - HAS_LIGHTLLM_KERNEL = False - - -def data_gen(): - input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) - return dict(input_ids=input_ids, attention_mask=attention_mask) - - -inputs = data_gen() -for k, v in inputs.items(): - if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: - new_shape = [1] * v.dim() - new_shape[0] = 16 - inputs[k] = v.to("cuda").repeat(*new_shape) - - -def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - model = transformers.BloomForCausalLM( - transformers.BloomConfig(vocab_size=20000, hidden_size=512, n_head=4, n_layer=4) - ) - - engine = InferenceEngine( - tp_size=tp_size, - pp_size=pp_size, - model=model, - max_output_len=max_output_len, - micro_batch_size=micro_batch_size, - ) - output = engine.generate(inputs) - if dist.get_rank() == 0: - assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}" - - 
-@parameterize("tp_size", [1]) -@parameterize("pp_size", [2]) -@parameterize("max_output_len", [4]) -@parameterize("micro_batch_size", [1]) -@clear_cache_before_run() -def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) - torch.cuda.empty_cache() - - -@parameterize("tp_size", [2]) -@parameterize("pp_size", [2]) -@parameterize("max_output_len", [4]) -@parameterize("micro_batch_size", [1]) -@clear_cache_before_run() -def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) - torch.cuda.empty_cache() - - -@parameterize("tp_size", [2]) -@parameterize("pp_size", [1]) -@parameterize("max_output_len", [2]) -@parameterize("micro_batch_size", [1]) -@clear_cache_before_run() -def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) - torch.cuda.empty_cache() - - -@parameterize("tp_size", [1]) -@parameterize("pp_size", [1]) -@parameterize("max_output_len", [2]) -@parameterize("micro_batch_size", [1]) -@clear_cache_before_run() -def run_single_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) - torch.cuda.empty_cache() - - -def check_tp_pp_inference(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_tp_pipeline_inference_test() - - -def check_tp_or_pp_inference(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_tp_inference_test() - run_pipeline_inference_test() - - -def check_single_inference(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_single_inference_test - - -@pytest.mark.skipif( - not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, - reason="kv-cache manager engine requires cuda version to be higher than 11.5", -) -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_pipeline_inference(): - spawn(check_tp_pp_inference, nprocs=4) - spawn(check_tp_or_pp_inference, nprocs=2) - spawn(check_single_inference, nprocs=1) - - -if __name__ == "__main__": - test_pipeline_inference() diff --git a/tests/test_infer/test_hybrid_chatglm2.py b/tests/test_infer/test_hybrid_chatglm2.py deleted file mode 100644 index b53bb25f4..000000000 --- a/tests/test_infer/test_hybrid_chatglm2.py +++ /dev/null @@ -1,129 +0,0 @@ -import importlib.util - -import pytest -import torch -import torch.distributed as dist -from packaging import version - -import colossalai -from colossalai.inference import InferenceEngine -from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig -from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn - -CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") -HAS_LIGHTLLM_KERNEL = True - -if importlib.util.find_spec("lightllm") is None: - HAS_LIGHTLLM_KERNEL = False - - -def data_gen(): - input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) - 
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) - return dict(input_ids=input_ids, attention_mask=attention_mask) - - -inputs = data_gen() -for k, v in inputs.items(): - if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: - new_shape = [1] * v.dim() - new_shape[0] = 16 - inputs[k] = v.to("cuda").repeat(*new_shape) - - -def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - chatglm_config = ChatGLMConfig( - num_layers=2, - vocab_size=20000, - use_cache=True, - multi_query_attention=True, - multi_query_group_num=2, - num_attention_heads=8, - hidden_size=1024, - ) - model = ChatGLMForConditionalGeneration(chatglm_config) - - engine = InferenceEngine( - tp_size=tp_size, - pp_size=pp_size, - model=model, - max_output_len=max_output_len, - micro_batch_size=micro_batch_size, - ) - output = engine.generate(inputs) - if dist.get_rank() == 0: - assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}" - - -@parameterize("tp_size", [1]) -@parameterize("pp_size", [2]) -@parameterize("max_output_len", [4]) -@parameterize("micro_batch_size", [1]) -@clear_cache_before_run() -def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) - torch.cuda.empty_cache() - - -@parameterize("tp_size", [2]) -@parameterize("pp_size", [2]) -@parameterize("max_output_len", [4]) -@parameterize("micro_batch_size", [1]) -@clear_cache_before_run() -def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) - torch.cuda.empty_cache() - - -@parameterize("tp_size", [2]) -@parameterize("pp_size", [1]) -@parameterize("max_output_len", [2]) -@parameterize("micro_batch_size", [1]) -@clear_cache_before_run() -def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) - torch.cuda.empty_cache() - - -@parameterize("tp_size", [1]) -@parameterize("pp_size", [1]) -@parameterize("max_output_len", [2]) -@parameterize("micro_batch_size", [1]) -@clear_cache_before_run() -def run_single_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) - torch.cuda.empty_cache() - - -def check_tp_pp_inference(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_tp_pipeline_inference_test() - - -def check_tp_or_pp_inference(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_tp_inference_test() - run_pipeline_inference_test() - - -def check_single_inference(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_single_inference_test - - -@pytest.mark.skipif( - not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, - reason="kv-cache manager engine requires cuda version to be higher than 11.5", -) -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_pipeline_inference(): - spawn(check_tp_pp_inference, nprocs=4) - spawn(check_tp_or_pp_inference, nprocs=2) - spawn(check_single_inference, nprocs=1) - - -if __name__ == "__main__": - test_pipeline_inference() diff --git 
diff --git a/tests/test_infer/test_hybrid_llama.py b/tests/test_infer/test_hybrid_llama.py
deleted file mode 100644
index 30b8b0a99..000000000
--- a/tests/test_infer/test_hybrid_llama.py
+++ /dev/null
@@ -1,126 +0,0 @@
-import importlib.util
-
-import pytest
-import torch
-import torch.distributed as dist
-import transformers
-from packaging import version
-
-import colossalai
-from colossalai.inference import InferenceEngine
-from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-
-CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
-
-import importlib.util
-
-HAS_LIGHTLLM_KERNEL = True
-
-if importlib.util.find_spec("lightllm") is None:
-    HAS_LIGHTLLM_KERNEL = False
-
-
-def data_gen():
-    input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
-    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
-    return dict(input_ids=input_ids, attention_mask=attention_mask)
-
-
-inputs = data_gen()
-for k, v in inputs.items():
-    if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
-        new_shape = [1] * v.dim()
-        new_shape[0] = 16
-        inputs[k] = v.to("cuda").repeat(*new_shape)
-
-
-def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
-    model = transformers.LlamaForCausalLM(
-        transformers.LlamaConfig(
-            vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
-        )
-    )
-
-    engine = InferenceEngine(
-        tp_size=tp_size,
-        pp_size=pp_size,
-        model=model,
-        max_output_len=max_output_len,
-        micro_batch_size=micro_batch_size,
-    )
-    output = engine.generate(inputs)
-    if dist.get_rank() == 0:
-        assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}"
-
-
-@parameterize("tp_size", [1])
-@parameterize("pp_size", [2])
-@parameterize("max_output_len", [4])
-@parameterize("micro_batch_size", [1])
-@clear_cache_before_run()
-def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
-    pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
-    torch.cuda.empty_cache()
-
-
-@parameterize("tp_size", [2])
-@parameterize("pp_size", [2])
-@parameterize("max_output_len", [4])
-@parameterize("micro_batch_size", [1])
-@clear_cache_before_run()
-def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
-    pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
-    torch.cuda.empty_cache()
-
-
-@parameterize("tp_size", [2])
-@parameterize("pp_size", [1])
-@parameterize("max_output_len", [2])
-@parameterize("micro_batch_size", [1])
-@clear_cache_before_run()
-def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
-    pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
-    torch.cuda.empty_cache()
-
-
-@parameterize("tp_size", [1])
-@parameterize("pp_size", [1])
-@parameterize("max_output_len", [2])
-@parameterize("micro_batch_size", [1])
-@clear_cache_before_run()
-def run_single_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
-    pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
-    torch.cuda.empty_cache()
-
-
-def check_tp_pp_inference(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
-    run_tp_pipeline_inference_test()
-
-
-def check_tp_or_pp_inference(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
host="localhost", port=port, backend="nccl") - run_tp_inference_test() - run_pipeline_inference_test() - - -def check_single_inference(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_single_inference_test - - -@pytest.mark.skipif( - not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, - reason="kv-cache manager engine requires cuda version to be higher than 11.5", -) -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_pipeline_inference(): - spawn(check_tp_pp_inference, nprocs=4) - spawn(check_tp_or_pp_inference, nprocs=2) - spawn(check_single_inference, nprocs=1) - - -if __name__ == "__main__": - test_pipeline_inference() diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py deleted file mode 100644 index e87653172..000000000 --- a/tests/test_infer/test_kvcache_manager.py +++ /dev/null @@ -1,66 +0,0 @@ -import os - -import pytest -import torch -from packaging import version - -from colossalai.inference.kv_cache import MemoryManager -from colossalai.logging import disable_existing_loggers -from colossalai.testing import rerun_if_address_is_in_use, spawn - -BATCH_SIZE = 4 -INPUT_LEN = 16 -OUTPUT_LEN = 8 -LAYER_NUM = 4 -HEAD_NUM = 32 -HEAD_DIM = 128 - -CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") - - -def create_cache_manager(rank, world_size, port, batch_size, input_len, output_len, layer_num, head_num, head_dim): - os.environ["RANK"] = str(rank) - os.environ["LOCAL_RANK"] = str(rank) - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = str(port) - disable_existing_loggers() - - size = batch_size * (input_len + output_len) - kvcache_manager = MemoryManager(size, torch.float16, head_num // world_size, head_dim, layer_num, rank) - key_buffers = kvcache_manager.key_buffer - value_buffers = kvcache_manager.value_buffer - assert len(key_buffers) == len(value_buffers) == layer_num - assert key_buffers[0].shape == value_buffers[0].shape - # required size exceeds the maximum allocated size - invalid_locs = kvcache_manager.alloc_contiguous(size + 1) - assert invalid_locs is None - # for prefill stage, allocation via alloc and alloc_contiguous should be the same - total_token_prefill = batch_size * input_len - prefill_locs = kvcache_manager.alloc(total_token_prefill) - kvcache_manager.free_all() - prefill_locs_contiguous = kvcache_manager.alloc_contiguous(total_token_prefill)[0] - assert torch.equal(prefill_locs, prefill_locs_contiguous) - assert torch.sum(kvcache_manager.mem_state).item() == size - total_token_prefill - kvcache_manager.alloc_contiguous(batch_size) - assert torch.all(kvcache_manager.mem_state[: total_token_prefill + batch_size] == False) - - -@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_cache_manager_dist(): - spawn( - create_cache_manager, - 4, - batch_size=BATCH_SIZE, - input_len=INPUT_LEN, - output_len=OUTPUT_LEN, - layer_num=LAYER_NUM, - head_num=HEAD_NUM, - head_dim=HEAD_DIM, - ) - - -if __name__ == "__main__": - test_cache_manager_dist() diff --git a/tests/test_infer_ops/triton/test_bloom_context_attention.py b/tests/test_infer_ops/triton/test_bloom_context_attention.py deleted file mode 100644 index 7a6c218a6..000000000 --- a/tests/test_infer_ops/triton/test_bloom_context_attention.py +++ 
@@ -1,52 +0,0 @@
-import pytest
-import torch
-from packaging import version
-
-try:
-    pass
-
-    from colossalai.kernel.triton import bloom_context_attn_fwd
-    from tests.test_infer_ops.triton.kernel_utils import torch_context_attention
-
-    HAS_TRITON = True
-except ImportError:
-    HAS_TRITON = False
-    print("please install triton from https://github.com/openai/triton")
-
-TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
-
-
-@pytest.mark.skipif(
-    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
-)
-def test_bloom_context_attention():
-    bs = 4
-    head_num = 8
-    seq_len = 1024
-    head_dim = 64
-
-    query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
-    k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
-    v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
-
-    max_input_len = seq_len
-    b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32)
-    b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32)
-
-    for i in range(bs):
-        b_start[i] = i * seq_len
-        b_len[i] = seq_len
-
-    o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
-    alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda")
-    bloom_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len, alibi)
-
-    torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)
-
-    assert torch.allclose(
-        torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2
-    ), "outputs from triton and torch are not matched"
-
-
-if __name__ == "__main__":
-    test_bloom_context_attention()
diff --git a/tests/test_infer_ops/triton/test_copy_kv_dest.py b/tests/test_infer_ops/triton/test_copy_kv_dest.py
deleted file mode 100644
index 34e453f78..000000000
--- a/tests/test_infer_ops/triton/test_copy_kv_dest.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import pytest
-import torch
-from packaging import version
-
-try:
-    pass
-
-    from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
-
-    HAS_TRITON = True
-except ImportError:
-    HAS_TRITON = False
-    print("please install triton from https://github.com/openai/triton")
-
-TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
-
-
-@pytest.mark.skipif(
-    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
-)
-def test_kv_cache_copy_op():
-    B_NTX = 32 * 2048
-    head_num = 8
-    head_dim = 64
-
-    cache = torch.randn((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)
-    dest_index = torch.arange(0, B_NTX, device="cuda", dtype=torch.int32)
-
-    dest_data = torch.ones((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)
-
-    copy_kv_cache_to_dest(cache, dest_index, dest_data)
-
-    assert torch.allclose(
-        cache.cpu(), dest_data.cpu(), rtol=1e-3, atol=1e-3
-    ), "copy_kv_cache_to_dest outputs from triton and torch are not matched"
-
-
-if __name__ == "__main__":
-    test_kv_cache_copy_op()
diff --git a/tests/test_infer_ops/triton/test_llama_context_attention.py b/tests/test_infer_ops/triton/test_llama_context_attention.py
deleted file mode 100644
index 95fe50cf1..000000000
--- a/tests/test_infer_ops/triton/test_llama_context_attention.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import pytest
-import torch
-from packaging import version
-
-try:
-    pass
-
-    from colossalai.kernel.triton import llama_context_attn_fwd
-    from tests.test_infer_ops.triton.kernel_utils import torch_context_attention
-
-    HAS_TRITON = True
-except ImportError:
-    HAS_TRITON = False
-    print("please install triton from https://github.com/openai/triton")
-
-TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
-
-
-@pytest.mark.skipif(
-    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
-)
-def test_llama_context_attention():
-    bs = 4
-    head_num = 8
-    seq_len = 1024
-    head_dim = 64
-
-    query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
-    k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
-    v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
-
-    max_input_len = seq_len
-    b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32)
-    b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32)
-
-    for i in range(bs):
-        b_start[i] = i * seq_len
-        b_len[i] = seq_len
-
-    o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
-    llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len)
-
-    torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)
-    assert torch.allclose(
-        torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3
-    ), "outputs from triton and torch are not matched"
-
-
-if __name__ == "__main__":
-    test_llama_context_attention()
diff --git a/tests/test_infer_ops/triton/test_self_attention_nonfusion.py b/tests/test_infer_ops/triton/test_self_attention_nonfusion.py
deleted file mode 100644
index 9bdec8664..000000000
--- a/tests/test_infer_ops/triton/test_self_attention_nonfusion.py
+++ /dev/null
@@ -1,143 +0,0 @@
-import pytest
-import torch
-import torch.nn.functional as F
-from packaging import version
-
-try:
-    import triton
-
-    from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel
-    from colossalai.kernel.triton.self_attention_nofusion import self_attention_compute_using_triton
-
-    HAS_TRITON = True
-except ImportError:
-    HAS_TRITON = False
-    print("please install triton from https://github.com/openai/triton")
-
-TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
-
-
-@pytest.mark.skipif(
-    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
-)
-def test_qkv_matmul():
-    qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16)
-    scale = 1.2
-    head_size = 32
-    batches = qkv.shape[0]
-    d_model = qkv.shape[-1] // 3
-    num_of_heads = d_model // head_size
-
-    q = qkv[:, :, :d_model]
-    k = qkv[:, :, d_model : d_model * 2]
-
-    q = q.view(batches, -1, num_of_heads, head_size)
-    k = k.view(batches, -1, num_of_heads, head_size)
-    q_copy = q.clone()
-    k_copy = k.clone()
-    q = torch.transpose(q, 1, 2).contiguous()
-    k = torch.transpose(k, 1, 2).contiguous()
-    k = torch.transpose(k, 2, 3).contiguous()
-
-    torch_ouput = torch.einsum("bnij,bnjk->bnik", q, k)
-    torch_ouput *= 1.2
-
-    q, k = q_copy, k_copy
-    batches, M, H, K = q.shape
-    N = k.shape[1]
-    score_output = torch.empty((batches, H, M, N), device=q.device, dtype=q.dtype)
-
-    grid = lambda meta: (
-        batches,
-        H,
-        triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
-    )
-
-    K = q.shape[3]
-    qkv_gemm_4d_kernel[grid](
-        q,
-        k,
-        score_output,
-        M,
-        N,
-        K,
-        q.stride(0),
-        q.stride(2),
-        q.stride(1),
-        q.stride(3),
-        k.stride(0),
-        k.stride(2),
-        k.stride(3),
-        k.stride(1),
-        score_output.stride(0),
-        score_output.stride(1),
-        score_output.stride(2),
-        score_output.stride(3),
-        scale=scale,
-        # currently manually setting, later on we can use auto-tune config to match best setting
-        BLOCK_SIZE_M=64,
-        BLOCK_SIZE_N=32,
-        BLOCK_SIZE_K=32,
-        GROUP_SIZE_M=8,
-    )
-
-    check = torch.allclose(torch_ouput.cpu(), score_output.cpu(), rtol=1e-3, atol=1e-5)
-    assert check is True, "the outputs of triton and torch are not matched"
-
-
-def self_attention_compute_using_torch(qkv, input_mask, scale, head_size):
-    batches = qkv.shape[0]
-    d_model = qkv.shape[-1] // 3
-    num_of_heads = d_model // head_size
-
-    q = qkv[:, :, :d_model]
-    k = qkv[:, :, d_model : d_model * 2]
-    v = qkv[:, :, d_model * 2 :]
-    q = q.view(batches, -1, num_of_heads, head_size)
-    k = k.view(batches, -1, num_of_heads, head_size)
-    v = v.view(batches, -1, num_of_heads, head_size)
-
-    q = torch.transpose(q, 1, 2).contiguous()
-    k = torch.transpose(k, 1, 2).contiguous()
-    v = torch.transpose(v, 1, 2).contiguous()
-
-    k = torch.transpose(k, -1, -2).contiguous()
-
-    score_output = torch.einsum("bnij,bnjk->bnik", q, k)
-    score_output *= scale
-
-    softmax_output = F.softmax(score_output, dim=-1)
-    res = torch.einsum("bnij,bnjk->bnik", softmax_output, v)
-    res = torch.transpose(res, 1, 2)
-    res = res.contiguous()
-
-    return res.view(batches, -1, d_model), score_output, softmax_output
-
-
-@pytest.mark.skipif(
-    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
-)
-def test_self_atttention_test():
-    qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16)
-    data_output_torch, score_output_torch, softmax_output_torch = self_attention_compute_using_torch(
-        qkv.clone(), input_mask=None, scale=1.2, head_size=32
-    )
-
-    data_output_triton = self_attention_compute_using_triton(
-        qkv.clone(),
-        alibi=None,
-        head_size=32,
-        scale=1.2,
-        input_mask=None,
-        layer_past=None,
-        use_flash=False,
-        triangular=True,
-    )
-
-    check = torch.allclose(data_output_triton.cpu(), data_output_torch.cpu(), rtol=1e-4, atol=1e-2)
-    assert check is True, "the triton output is not matched with torch output"
-
-
-if __name__ == "__main__":
-    test_qkv_matmul()
-    test_self_atttention_test()
diff --git a/tests/test_infer_ops/triton/test_token_attn_fwd.py b/tests/test_infer_ops/triton/test_token_attn_fwd.py
deleted file mode 100644
index 4ee1a5fb1..000000000
--- a/tests/test_infer_ops/triton/test_token_attn_fwd.py
+++ /dev/null
@@ -1,72 +0,0 @@
-import pytest
-import torch
-from packaging import version
-
-try:
-    from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
-
-    HAS_TRITON = True
-except ImportError:
-    HAS_TRITON = False
-    print("please install triton from https://github.com/openai/triton")
-
-
-import importlib.util
-
-HAS_LIGHTLLM_KERNEL = True
-
-if importlib.util.find_spec("lightllm") is None:
-    HAS_LIGHTLLM_KERNEL = False
-
-TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) >= version.parse("11.6")
-
-
-def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim):
-    xq = xq.view(bs, 1, num_head, head_dim)
-    xk = xk.view(bs, seqlen, num_head, head_dim)
-    xv = xv.view(bs, seqlen, num_head, head_dim)
-
-    logics = torch.sum(xq * xk, dim=3, keepdim=False) * 1 / (head_dim**0.5)
-    prob = torch.softmax(logics, dim=1)
-    prob = prob.view(bs, seqlen, num_head, 1)
-
-    return torch.sum(prob * xv, dim=1, keepdim=False)
-
-
-@pytest.mark.skipif(
-    not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_LIGHTLLM_KERNEL,
reason="triton requires cuda version to be higher than 11.4 or not install lightllm", -) -def test(): - Z, head_num, seq_len, head_dim = 22, 112 // 8, 2048, 128 - dtype = torch.float16 - q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) - v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda") - - max_kv_cache_len = seq_len - kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda") - kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda") - kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") - - kv_cache_seq_len[:] = seq_len - kv_cache_start_loc[0] = 0 - kv_cache_start_loc[1] = seq_len - kv_cache_start_loc[2] = 2 * seq_len - kv_cache_start_loc[3] = 3 * seq_len - - for i in range(Z): - kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda") - - token_attention_fwd(q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, alibi=alibi) - torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim) - - print("max ", torch.max(torch.abs(torch_out - o))) - print("mean ", torch.mean(torch.abs(torch_out - o))) - assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) - - -if __name__ == "__main__": - test() diff --git a/tests/test_infer_ops/triton/test_token_softmax.py b/tests/test_infer_ops/triton/test_token_softmax.py deleted file mode 100644 index 1f97f1674..000000000 --- a/tests/test_infer_ops/triton/test_token_softmax.py +++ /dev/null @@ -1,48 +0,0 @@ -import pytest -import torch -from packaging import version - -try: - pass - - from colossalai.kernel.triton.token_attention_kernel import token_attn_softmax_fwd - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -def test_softmax(): - import torch - - batch_size, seq_len, head_num, head_dim = 4, 1025, 12, 128 - - dtype = torch.float16 - - Logics = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda").normal_(mean=0.1, std=10) - ProbOut = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) - - kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") - kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") - - for i in range(batch_size): - kv_cache_start_loc[i] = i * seq_len - kv_cache_seq_len[i] = seq_len - - token_attn_softmax_fwd(Logics, kv_cache_start_loc, kv_cache_seq_len, ProbOut, seq_len) - - torch_out = Logics.reshape(head_num * batch_size, -1).softmax(-1).reshape(head_num, batch_size * seq_len) - o = ProbOut - print("max ", torch.max(torch.abs(torch_out - o))) - print("mean ", torch.mean(torch.abs(torch_out - o))) - assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) - - -if __name__ == "__main__": - test_softmax()