From be396ad6cc102fa610731291bf28e531a5641c7a Mon Sep 17 00:00:00 2001 From: Steve Luo <36296769+SunflowerAries@users.noreply.github.com> Date: Thu, 18 Apr 2024 16:45:07 +0800 Subject: [PATCH] [Inference/Kernel] Add Paged Decoding kernel, sequence split within the same thread block (#5531) * feat flash decoding for paged attention * refactor flashdecodingattention * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../modeling/models/nopadding_llama.py | 13 + .../benchmark_ops/benchmark_decoding_attn.py | 15 +- .../benchmark_flash_decoding_attention.py | 173 +++++++++ .../csrc/cuda/attention/attention_utils.h | 206 ++++++++++ .../cuda/flash_decoding_attention_kernel.cu | 353 ++++++++++++++++++ extensions/csrc/cuda/funcs/binary_functor.h | 222 ++++++++--- extensions/csrc/cuda/funcs/cast_functor.h | 154 ++++++-- extensions/csrc/cuda/funcs/ternary_functor.h | 212 +++++++++++ extensions/csrc/cuda/funcs/unary_functor.h | 36 +- extensions/csrc/cuda/pybind/inference.cpp | 19 + extensions/csrc/cuda/rms_layernorm_kernel.cu | 172 ++------- extensions/csrc/cuda/utils/vec_type_traits.h | 61 ++- extensions/inference/inference_ops_cuda.py | 1 + .../cuda/test_flash_decoding_attention.py | 274 ++++++++++++++ .../test_ops/triton/kernel_utils.py | 65 ++++ 15 files changed, 1765 insertions(+), 211 deletions(-) create mode 100644 examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py create mode 100644 extensions/csrc/cuda/attention/attention_utils.h create mode 100644 extensions/csrc/cuda/flash_decoding_attention_kernel.cu create mode 100644 extensions/csrc/cuda/funcs/ternary_functor.h create mode 100644 tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index 010abc1db..5ef576e51 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -437,6 +437,19 @@ class NopadLlamaAttention(LlamaAttention): block_tables, high_precision, ) + # inference_ops.flash_decoding_attention( + # attn_output, + # query_states, + # k_cache, + # v_cache, + # sequence_lengths, + # block_tables, + # block_size, + # kv_seq_len, + # fd_inter_tensor.mid_output, + # fd_inter_tensor.mid_output_lse, + # sm_scale, + # ) else: if is_verifier: rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) diff --git a/examples/inference/benchmark_ops/benchmark_decoding_attn.py b/examples/inference/benchmark_ops/benchmark_decoding_attn.py index ae68aedf5..ae104c807 100644 --- a/examples/inference/benchmark_ops/benchmark_decoding_attn.py +++ b/examples/inference/benchmark_ops/benchmark_decoding_attn.py @@ -4,8 +4,8 @@ from colossalai.kernel.triton import flash_decoding_attention from colossalai.utils import get_current_device from tests.test_infer.test_ops.triton.kernel_utils import ( convert_kv_unpad_to_padded, + create_attention_mask, generate_caches_and_block_tables_v2, - prepare_padding_mask, torch_attn_ref, ) from tests.test_infer.test_ops.triton.test_decoding_attn import prepare_data @@ -67,9 +67,18 @@ def bench_kernel( if provider == "torch": k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_seq_len_in_b) v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_seq_len_in_b) - torch_padding_mask = 
prepare_padding_mask(kv_lengths, bsz, max_seq_len_in_b, q.device) + torch_padding_mask = create_attention_mask(kv_lengths, bsz, Q_LEN, max_seq_len_in_b, q.device) fn = lambda: torch_attn_ref( - q, k_torch, v_torch, torch_padding_mask, bsz, 1, max_seq_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM + q, + k_torch, + v_torch, + torch_padding_mask, + bsz, + Q_LEN, + max_seq_len_in_b, + num_attn_heads, + num_kv_heads, + HEAD_DIM, ) ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles) if provider == "triton": diff --git a/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py b/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py new file mode 100644 index 000000000..e33d9a9dc --- /dev/null +++ b/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py @@ -0,0 +1,173 @@ +import torch + +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.kernel.triton import flash_decoding_attention +from colossalai.utils import get_current_device +from tests.test_infer.test_ops.triton.kernel_utils import ( + generate_caches_and_block_tables_v2, + generate_caches_and_block_tables_vllm, +) + +try: + import triton # noqa +except ImportError: + print("please install triton from https://github.com/openai/triton") + +inference_ops = InferenceOpsLoader().load() + +# Triton benchmark plot attributions +configs = [ + triton.testing.Benchmark( + x_names=["MAX_NUM_BLOCKS_PER_SEQ"], + x_vals=[2**i for i in range(3, 8)], + line_arg="provider", + line_vals=[ + "vllm_paged_decoding_attention", + "triton_flash_decoding_attention", + "cuda_flash_decoding_attention", + ], + line_names=[ + "vllm_paged_decoding_attention", + "triton_flash_decoding_attention", + "cuda_flash_decoding_attention", + ], + styles=[("red", "-"), ("blue", "-"), ("yellow", "-")], + ylabel="ms", + plot_name=f"FlashDecodingAttention benchmarking results", + args={"BATCH_SIZE": 16, "BLOCK_SIZE": 32, "HEAD_SIZE": 128, "KV_GROUP_NUM": 2}, + ) +] + + +def prepare_data( + BATCH_SIZE: int, + HEAD_SIZE: int, + NUM_ATTN_HEADS: int, + NUM_KV_HEADS: int, + MAX_SEQ_LEN: int, + dtype=torch.float16, + device="cuda", +): + # Use the provided maximum sequence length for each sequence when testing with teh same context length, + # otherwise generate random context lengths. 
+ # returns + # q [BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE] + # k_unpad/v_unpad [num_tokens, NUM_KV_HEADS, HEAD_SIZE] + kv_lengths = torch.randint(low=1, high=MAX_SEQ_LEN, size=(BATCH_SIZE,), dtype=torch.int32, device=device) + num_tokens = torch.sum(kv_lengths).item() + + q_size = (BATCH_SIZE, 1, NUM_ATTN_HEADS, HEAD_SIZE) + q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5).transpose(1, 2) + kv_size = (num_tokens, 2 * NUM_KV_HEADS, HEAD_SIZE) + kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) + k_unpad, v_unpad = torch.split(kv_unpad, [NUM_KV_HEADS, NUM_KV_HEADS], dim=-2) + + return q, k_unpad, v_unpad, kv_lengths + + +@triton.testing.perf_report(configs) +def benchmark_flash_decoding_attention( + provider: str, + BATCH_SIZE: int, + BLOCK_SIZE: int, + MAX_NUM_BLOCKS_PER_SEQ: int, + HEAD_SIZE: int, + KV_GROUP_NUM: int, +): + try: + from vllm._C import ops as vllm_ops + except ImportError: + raise ImportError("Please install vllm from https://github.com/vllm-project/vllm") + + warmup = 10 + rep = 1000 + + dtype = torch.float16 + + NUM_ATTN_HEADS = 16 + + NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM + assert isinstance(NUM_KV_HEADS, int) and NUM_KV_HEADS > 0, "Invalid number of kv heads." + MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ + device = get_current_device() + + q, k_unpad, v_unpad, kv_seq_lengths = prepare_data( + BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device + ) + + k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2( + k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device + ) + + vllm_k_cache, vllm_v_cache, _ = generate_caches_and_block_tables_vllm( + k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device + ) + + block_tables = block_tables.to(device=device) + max_seq_len_across_batch = kv_seq_lengths.max().item() + kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE + output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device) + sm_scale = 1.0 / (HEAD_SIZE**0.5) + + mid_output = torch.empty( + size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device + ) + mid_output_lse = torch.empty( + size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device + ) + + if provider == "vllm_paged_decoding_attention": + alibi_slopes = None + fn = lambda: vllm_ops.paged_attention_v1( + output, + q.squeeze(2), + vllm_k_cache, + vllm_v_cache, + NUM_KV_HEADS, + sm_scale, + block_tables, + kv_seq_lengths, + BLOCK_SIZE, + max_seq_len_across_batch, + alibi_slopes, + "auto", + ) + elif provider == "triton_flash_decoding_attention": + fn = lambda: flash_decoding_attention( + q.squeeze(2), + k_cache, + v_cache, + kv_seq_lengths, + block_tables, + BLOCK_SIZE, + max_seq_len_across_batch, + output, + mid_output, + mid_output_lse, + sm_scale=sm_scale, + kv_group_num=KV_GROUP_NUM, + ) # [bsz, 1, num_heads, head_dim] + elif provider == "cuda_flash_decoding_attention": + fn = lambda: inference_ops.flash_decoding_attention( + output, + q.squeeze(2), + k_cache, + v_cache, + kv_seq_lengths, + block_tables, + BLOCK_SIZE, + max_seq_len_across_batch, + mid_output, + mid_output_lse, + sm_scale, + ) + else: + raise ValueError("Undefined provider.") + + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + + return ms + + +if __name__ == "__main__": + 
benchmark_flash_decoding_attention.run(save_path=".", print_data=True) diff --git a/extensions/csrc/cuda/attention/attention_utils.h b/extensions/csrc/cuda/attention/attention_utils.h new file mode 100644 index 000000000..c55033636 --- /dev/null +++ b/extensions/csrc/cuda/attention/attention_utils.h @@ -0,0 +1,206 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2024, The Colossal-AI team. + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +#include "../funcs/binary_functor.h" +#include "../funcs/cast_functor.h" +#include "../funcs/ternary_functor.h" +#include "../funcs/unary_functor.h" +#include "../utils/vec_type_traits.h" + +namespace colossalAI { +namespace cuda { +namespace attention { + +using colossalAI::cuda::funcs::BinaryOpFunctor; +using colossalAI::cuda::funcs::BinaryOpType; +using colossalAI::cuda::funcs::TernaryOpFunctor; +using colossalAI::cuda::funcs::TernaryOpType; +using colossalAI::cuda::funcs::UnaryOpFunctor; +using colossalAI::cuda::funcs::UnaryOpType; +using colossalAI::cuda::utils::FloatVecTypeTrait; + +#define WARP_SIZE 32 +#define VEC_SIZE_8 8 + +#define SHFL_XOR_SYNC(var, lane_mask) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask) +#define SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane) + +// Q*K^T operation. +template +inline __device__ float qk_dot_(const VecT (&q)[N], const VecT (&k)[N]) { + using A_vec = typename FloatVecTypeTrait::Type; + // Compute the parallel products for Q*K^T (treat vector lanes separately). + BinaryOpFunctor mul_vect; + UnaryOpFunctor sum_vect; + TernaryOpFunctor fma; + + A_vec qk_vec = mul_vect(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ii++) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. + float qk = sum_vect(qk_vec); +#pragma unroll + for (int mask = (NUM_THREADS_PER_TOKEN >> 1); mask > 0; mask >>= 1) { + qk += SHFL_XOR_SYNC(qk, mask); + } + return qk; +} + +template +struct Qk_dot { + template + static inline __device__ float dot(const VecT (&q)[N], const VecT (&k)[N]) { + return qk_dot_(q, k); + } +}; + +template +inline __device__ float block_max(float* red_smem, float max) { + int warp = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + +// Perform reduction across the threads in the same warp to get the max value +// for each warp, the 1st out of NUM_THREADS_PER_TOKEN thread already has the +// max value among every NUM_THREADS_PER_TOKEN threads. +#pragma unroll + for (int mask = (WARP_SIZE >> 1); mask >= NUM_THREADS_PER_TOKEN; mask >>= 1) { + max = fmaxf(max, SHFL_XOR_SYNC(max, mask)); + } + + if (lane == 0) red_smem[warp] = max; + __syncthreads(); + + // The warps compute the final maxs. + max = lane < NUM_WARPS ? 
red_smem[lane] : -FLT_MAX; + +// Parallel reduction of all tokens from the same sequence inside the warp. +#pragma unroll + for (int mask = (NUM_WARPS >> 1); mask > 0; mask >>= 1) { + max = fmaxf(max, SHFL_XOR_SYNC(max, mask)); + } + + // Broadcast to other threads. + return SHFL_SYNC(max, 0); +} + +// here we need another block_sum instead of using block_reduce +// since we need manage shared memory in a explicit way +template +inline __device__ float block_sum(float* red_smem, float sum) { + int warp = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + +// Compute the sum per warp. +#pragma unroll + for (int mask = (WARP_SIZE >> 1); mask > 0; mask >>= 1) { + sum += SHFL_XOR_SYNC(sum, mask); + } + + if (lane == 0) red_smem[warp] = sum; + __syncthreads(); + + if (lane < NUM_WARPS) { + sum = red_smem[lane]; + } + +// Parallel reduction of all tokens from the same sequence inside the warp. +#pragma unroll + for (int mask = (NUM_WARPS >> 1); mask > 0; mask >>= 1) { + sum += SHFL_XOR_SYNC(sum, mask); + } + + // Broadcast to other threads. + return SHFL_SYNC(sum, 0); +} + +// here VecT is a vector of float, whose size is N +template +inline __device__ void block_sum(float* red_smem, VecT& acc) { + float* acc_ptr = reinterpret_cast(&acc); + int warp = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + +#pragma unroll + for (int i = 0; i < N; i++) { +#pragma unroll + for (int mask = (WARP_SIZE >> 1); mask >= NUM_THREADS_PER_GROUP; + mask >>= 1) { + acc_ptr[i] += SHFL_XOR_SYNC(acc_ptr[i], mask); + } + } + +#pragma unroll + for (int limit = NUM_WARPS; limit > 1; limit >>= 1) { + int mid = limit >> 1; + if (warp >= mid && warp < limit) { + float* dst = red_smem + (warp - mid) * N * NUM_THREADS_PER_GROUP; + if (lane < NUM_THREADS_PER_GROUP) { + if constexpr (N == VEC_SIZE_8) { + VecT* vdst = &((reinterpret_cast(dst))[lane]); + (reinterpret_cast(vdst))[0] = + (reinterpret_cast(acc_ptr))[0]; + (reinterpret_cast(vdst))[1] = + (reinterpret_cast(acc_ptr))[1]; + } else { + (reinterpret_cast(dst))[lane] = acc; + } + } + } + __syncthreads(); + + if (warp < mid) { + float* src = red_smem + warp * N * NUM_THREADS_PER_GROUP; + VecT src_reg; + if (lane < NUM_THREADS_PER_GROUP) { + float* src_ptr = reinterpret_cast(&src_reg); + if constexpr (N == VEC_SIZE_8) { + VecT* vsrc = &((reinterpret_cast(src))[lane]); + (reinterpret_cast(src_ptr))[0] = + (reinterpret_cast(vsrc))[0]; + (reinterpret_cast(src_ptr))[1] = + (reinterpret_cast(vsrc))[1]; + } else { + src_reg = (reinterpret_cast(src))[lane]; + } +#pragma unroll + for (int j = 0; j < N; j++) { + acc_ptr[j] += src_ptr[j]; + } + } + } + __syncthreads(); + } +} + +#undef SHFL_SYNC +#undef SHFL_XOR_SYNC + +} // namespace attention +} // namespace cuda +} // namespace colossalAI diff --git a/extensions/csrc/cuda/flash_decoding_attention_kernel.cu b/extensions/csrc/cuda/flash_decoding_attention_kernel.cu new file mode 100644 index 000000000..69b50616b --- /dev/null +++ b/extensions/csrc/cuda/flash_decoding_attention_kernel.cu @@ -0,0 +1,353 @@ +/*This code adapted from vllm: + * https://github.com/vllm-project/vllm/blob/main/csrc/attention/attention_kernels.cu + * with different kvcache layout. */ + +#include +#include +#include +#include + +#include "../common/micros.h" +#include "funcs/cast_functor.h" +#include "funcs/ternary_functor.h" +#include "funcs/binary_functor.h" +#include "utils/vec_type_traits.h" +#include "attention/attention_utils.h" + +#define WARP_SIZE 32 +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? 
(a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) +// 2^n => 2^n, 2^n-d => 2^(n-1) +#define ROUND_DOWN_HIGHEST_POWER_OF_TWO(x) (nextHighestPowerOf2((x - (x + 1) / 2 + 1))) + +// a bit magic, you can ask chatgpt for help +// 2^n => 2^n, 2^n-d => 2^n +constexpr unsigned int nextHighestPowerOf2(unsigned int v) { + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v++; + return v; +} + +using colossalAI::cuda::funcs::BinaryOpType; +using colossalAI::cuda::funcs::CastFunctor; +using colossalAI::cuda::funcs::TernaryOpFunctor; +using colossalAI::cuda::funcs::TernaryOpType; +using colossalAI::cuda::funcs::zero; +using colossalAI::cuda::utils::VecTypeTrait; +using colossalAI::cuda::utils::FloatVecTypeTrait; +using namespace colossalAI::cuda::attention; + + +// We only support head size of { 64, 128, 256 } +// models like Phi-2, whose head size is 80, is not supported right now +template +__global__ void flash_decoding_attention_kernel( + scalar_t* __restrict__ out, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_tokens, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, block_size, head_size] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, block_size, head_size] + const int* __restrict__ context_lens, // [num_tokens] + const int* __restrict__ block_tables, // [num_tokens, max_num_blocks_per_seq] + const int max_seq_len, + const int num_kv_heads, + const float scale, + const int max_num_blocks_per_seq, + const int q_stride, // num_heads * head_size + const int kv_block_stride, + const int kv_head_stride) { + const int seq_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int thread_idx = threadIdx.x; + const int lane = thread_idx % WARP_SIZE; + const int warp_idx = thread_idx / WARP_SIZE; + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / num_kv_heads; + const int kv_head_idx = head_idx / num_queries_per_kv; + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + constexpr int Q_SHARED_SIZE = (HEAD_SIZE * sizeof(scalar_t)) / sizeof(float4); + // here thread_group does not determine the number of threads responsible for a key + // but only the VEC_SIZE of each thread + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int VEC_SIZE = MIN(ROUND_DOWN_HIGHEST_POWER_OF_TWO((HEAD_SIZE / THREAD_GROUP_SIZE)), sizeof(float4) / sizeof(scalar_t)); + constexpr int NUM_VECS_PER_TOKEN = HEAD_SIZE / VEC_SIZE; + constexpr int NUM_THREADS_PER_TOKEN = MIN(NUM_VECS_PER_TOKEN, WARP_SIZE); + constexpr int NUM_ROUNDS_PER_TOKEN = NUM_VECS_PER_TOKEN / NUM_THREADS_PER_TOKEN; + constexpr int WARP_STRIDE = WARP_SIZE * NUM_ROUNDS_PER_TOKEN; + + using K_vec = typename VecTypeTrait::Type; + using V_vec = typename VecTypeTrait::Type; + using L_vec = typename VecTypeTrait::Type; + using Float_vec = typename FloatVecTypeTrait::Type; + + const int context_len = context_lens[seq_idx]; + const int thread_group_offset = thread_idx % NUM_THREADS_PER_TOKEN; + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + + __shared__ float4 q_shared[Q_SHARED_SIZE]; + __shared__ float red_shared_mem[2 * NUM_WARPS]; + extern __shared__ char shared_mem[]; + float* logits = reinterpret_cast(shared_mem); + float* out_shared_mem = reinterpret_cast(shared_mem); + float qk_max = -FLT_MAX; + + const float4* q_ptr = reinterpret_cast(q + seq_idx * q_stride + 
head_idx * HEAD_SIZE); + #pragma unroll + for (int idx = thread_idx; idx < Q_SHARED_SIZE; idx += blockDim.x) { + q_shared[idx] = q_ptr[idx]; + } + __syncthreads(); + + scalar_t* q_shared_ptr = reinterpret_cast(q_shared); + // each warp access a whole block + for (int block_idx = warp_idx; block_idx < num_context_blocks; block_idx += NUM_WARPS) { + const int64_t physical_block_number = static_cast(block_table[block_idx]); + #pragma unroll + for (int idx = lane; idx < BLOCK_SIZE * NUM_VECS_PER_TOKEN; idx += WARP_STRIDE) { + const int token_idx = block_idx * BLOCK_SIZE + idx / NUM_VECS_PER_TOKEN; + const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + + idx * VEC_SIZE; + + K_vec k_vecs[NUM_ROUNDS_PER_TOKEN]; + K_vec q_vecs[NUM_ROUNDS_PER_TOKEN]; + + // we must calculate at least one row of hidden vectors + #pragma unroll + for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) { + k_vecs[i] = (reinterpret_cast(k_ptr))[i * WARP_SIZE]; + q_vecs[i] = (reinterpret_cast(q_shared_ptr))[(idx + i * WARP_SIZE) % NUM_VECS_PER_TOKEN]; + } + + float qk = scale * Qk_dot::dot(q_vecs, k_vecs); + + if (thread_group_offset == 0) { + const bool mask = token_idx >= context_len; + logits[token_idx] = mask ? 0.f : qk; + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + } + } + } + + // there exists a __syncthreads within this function + qk_max = block_max(red_shared_mem, qk_max); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + + exp_sum = block_sum(&red_shared_mem[NUM_WARPS], exp_sum); + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); + + Float_vec accs[NUM_ROUNDS_PER_TOKEN]; + #pragma unroll + for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) { + zero(accs[i]); + } + + V_vec zero_value; + zero(zero_value); + for (int block_idx = warp_idx; block_idx < num_context_blocks; block_idx += NUM_WARPS) { + const int64_t physical_block_number = static_cast(block_table[block_idx]); + scalar_t logit; + + #pragma unroll + for (int idx = lane; idx < BLOCK_SIZE * NUM_VECS_PER_TOKEN; idx += WARP_STRIDE) { + const int token_idx = block_idx * BLOCK_SIZE + idx / NUM_VECS_PER_TOKEN; + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + + idx * VEC_SIZE; + + V_vec v_vecs[NUM_ROUNDS_PER_TOKEN]; + + #pragma unroll + for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) { + v_vecs[i] = (reinterpret_cast(v_ptr))[i * WARP_SIZE]; + } + + if (token_idx >= context_len) { + #pragma unroll + for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) { + v_vecs[i] = zero_value; + } + } + + logit = CastFunctor()(logits[token_idx]); + #pragma unroll + for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) { + accs[i] = TernaryOpFunctor()(logit, v_vecs[i], accs[i]); + } + } + } + + // must insert a sync since both logits and out_shared_mem occupy the same buffer space + __syncthreads(); + + #pragma unroll + for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) { + block_sum(out_shared_mem, accs[i]); + } + + scalar_t* out_ptr = out + seq_idx * q_stride + head_idx * HEAD_SIZE; + L_vec out_reg; + #pragma unroll + for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) { + if (thread_idx < NUM_THREADS_PER_TOKEN) { + out_reg = CastFunctor()(accs[i]); + (reinterpret_cast(out_ptr))[thread_idx + i * NUM_THREADS_PER_TOKEN] = 
out_reg; + } + } +} + +#define LAUNCH_FLASH_DECODING_ATTENTION_V1(HEAD_SIZE) \ + cudaFuncSetAttribute( \ + ((void*)flash_decoding_attention_kernel), \ + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ + flash_decoding_attention_kernel \ + <<>>( \ + reinterpret_cast(out.data_ptr()), \ + reinterpret_cast(query.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + context_lens.data_ptr(), \ + block_tables.data_ptr(), \ + max_context_len, \ + num_kv_heads, \ + scale, \ + max_num_blocks_per_seq, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride); + +template< + typename T, + typename CACHE_T, + int BLOCK_SIZE, + int NUM_THREADS = 128> +void flash_decoding_attention_v1_launcher( + torch::Tensor& out, // [num_tokens, num_heads, head_size] + torch::Tensor& query, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_kv_heads, block_size, head_size] + torch::Tensor& value_cache, // [num_blocks, num_kv_heads, block_size, head_size] + torch::Tensor& context_lens, // [num_tokens] + torch::Tensor& block_tables, // [num_tokens, max_num_blocks_per_seq] + int max_context_len, + float scale) { + int num_tokens = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int num_kv_heads = key_cache.size(1); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + const int VEC_SIZE = MIN(ROUND_DOWN_HIGHEST_POWER_OF_TWO((head_size / THREAD_GROUP_SIZE)), sizeof(float4) / sizeof(T)); + const int NUM_VECS_PER_TOKEN = head_size / VEC_SIZE; + const int NUM_THREADS_PER_TOKEN = MIN(NUM_VECS_PER_TOKEN, WARP_SIZE); + + int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_context_len * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * NUM_THREADS_PER_TOKEN * VEC_SIZE * sizeof(float); + // Keep that in sync with the logic here! + int shared_mem_size = std::max(logits_size, outputs_size); + + dim3 grid(num_heads, num_tokens, 1); + dim3 block(NUM_THREADS); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. + case 64: + LAUNCH_FLASH_DECODING_ATTENTION_V1(64); + break; + case 128: + LAUNCH_FLASH_DECODING_ATTENTION_V1(128); + break; + case 256: + LAUNCH_FLASH_DECODING_ATTENTION_V1(256); + break; + default: + AT_ERROR("head size must be 64, 128, 256"); + break; + } +} + +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE) \ + flash_decoding_attention_v1_launcher( \ + out, \ + query, \ + key_cache, \ + value_cache, \ + context_lens, \ + block_tables, \ + max_context_len, \ + scale); + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. 
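// For illustration (a sketch, assuming the launcher is instantiated as
// flash_decoding_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE> with the
// template declared above): a dispatch such as CALL_V1_LAUNCHER(half, half, 16)
// expands to
//   flash_decoding_attention_v1_launcher<half, half, 16>(
//       out, query, key_cache, value_cache, context_lens, block_tables,
//       max_context_len, scale);
// so each supported (dtype, block_size) pair gets its own compiled instantiation.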
+#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T) \ + switch (block_size) { \ + case 8: \ + CALL_V1_LAUNCHER(T, CACHE_T, 8); \ + break; \ + case 16: \ + CALL_V1_LAUNCHER(T, CACHE_T, 16); \ + break; \ + case 32: \ + CALL_V1_LAUNCHER(T, CACHE_T, 32); \ + break; \ + default: \ + AT_ERROR("block size must be 8, 16, 32"); \ + break; \ + } + +void flash_decoding_attention( + torch::Tensor& out, // [num_tokens, num_heads, head_size] + torch::Tensor& query, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_kv_heads, block_size, head_size] + torch::Tensor& value_cache, // [num_blocks, num_kv_heads, block_size, head_size] + torch::Tensor& context_lens, // [num_tokens] + torch::Tensor& block_tables, // [num_tokens, max_num_blocks_per_seq] + int block_size, + int max_context_len, + torch::Tensor& tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size] + torch::Tensor& tmp_out_lse, // [num_tokens, num_heads, max_num_partitions] + float scale) { + switch (query.scalar_type()) { + case at::ScalarType::Float: + CALL_V1_LAUNCHER_BLOCK_SIZE(float, float); + break; + case at::ScalarType::Half: + CALL_V1_LAUNCHER_BLOCK_SIZE(half, half); + break; + case at::ScalarType::BFloat16: + CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16); + break; + default: + AT_ERROR("Unsupported data type: ", toString(query.scalar_type())); + } +} + + +#undef LAUNCH_FLASH_DECODING_ATTENTION_V1 +#undef CALL_V1_LAUNCHER +#undef CALL_V1_LAUNCHER_BLOCK_SIZE diff --git a/extensions/csrc/cuda/funcs/binary_functor.h b/extensions/csrc/cuda/funcs/binary_functor.h index 2f26e7197..e5a68d938 100644 --- a/extensions/csrc/cuda/funcs/binary_functor.h +++ b/extensions/csrc/cuda/funcs/binary_functor.h @@ -8,11 +8,20 @@ #include #include "../utils/micros.h" +#include "../utils/vec_type_traits.h" +#include "cast_functor.h" namespace colossalAI { namespace cuda { namespace funcs { +using utils::bfloat164; +using utils::bfloat168; +using utils::float4_; +using utils::float8_; +using utils::half4; +using utils::half8; + enum class BinaryOpType { kAdd = 0, kMinus, kMul, kDiv, kMax, kMin }; // Note(LiuYang): This file provides base math operation for data type @@ -22,73 +31,182 @@ enum class BinaryOpType { kAdd = 0, kMinus, kMul, kDiv, kMax, kMin }; template struct BinaryOpFunctor; -#define COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BINARY_OP_TYPE, STMT, \ - FUNCTION_MODIFIER, ARGS...) \ - template \ - struct BinaryOpFunctor \ - : public std::binary_function { \ - FUNCTION_MODIFIER T operator()(T lhs, T rhs) { return STMT; } \ +#define STMTS_WRAPPER(...) __VA_ARGS__ + +#define COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( \ + LT, RT, RET, BINARY_OP_TYPE, FUNCTION_MODIFIER, STMTS, ARGS...) 
\ + template \ + struct BinaryOpFunctor \ + : public std::binary_function { \ + FUNCTION_MODIFIER RET operator()(LT lhs, RT rhs) STMTS \ }; -COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kAdd, lhs + rhs, - HOSTDEVICE, typename T) -COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMinus, lhs - rhs, - HOSTDEVICE, typename T) -COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMul, lhs* rhs, - HOSTDEVICE, typename T) -COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kDiv, lhs / rhs, - HOSTDEVICE, typename T) -COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMax, max(lhs, rhs), - HOSTDEVICE, typename T) -COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMin, min(lhs, rhs), - HOSTDEVICE, typename T) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kAdd, HOSTDEVICE, + STMTS_WRAPPER({ return lhs + rhs; }), + typename T) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMinus, + HOSTDEVICE, + STMTS_WRAPPER({ return lhs - rhs; }), + typename T) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMul, HOSTDEVICE, + STMTS_WRAPPER({ return lhs * rhs; }), + typename T) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kDiv, HOSTDEVICE, + STMTS_WRAPPER({ return lhs / rhs; }), + typename T) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMax, HOSTDEVICE, + STMTS_WRAPPER({ return max(lhs, rhs); }), + typename T) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMin, HOSTDEVICE, + STMTS_WRAPPER({ return min(lhs, rhs); }), + typename T) -COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, BinaryOpType::kAdd, - __hadd(lhs, rhs), DEVICE) -COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, BinaryOpType::kAdd, - __hadd2(lhs, rhs), DEVICE) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kAdd, + DEVICE, STMTS_WRAPPER({ + return __hadd(lhs, rhs); + })) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, half2, half2, BinaryOpType::kAdd, + DEVICE, STMTS_WRAPPER({ + return __hadd2(lhs, rhs); + })) #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kAdd, - __hadd(lhs, rhs), DEVICE) -COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, BinaryOpType::kAdd, - __hadd2(lhs, rhs), DEVICE) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16, + __nv_bfloat16, BinaryOpType::kAdd, + DEVICE, STMTS_WRAPPER({ + return __hadd(lhs, rhs); + })) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, __nv_bfloat162, + __nv_bfloat162, BinaryOpType::kAdd, + DEVICE, STMTS_WRAPPER({ + return __hadd2(lhs, rhs); + })) #else -COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kAdd, - __float2bfloat16(__bfloat162float(lhs) + - __bfloat162float(rhs)), - DEVICE) COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( - __nv_bfloat162, BinaryOpType::kAdd, - __floats2bfloat162_rn(__low2float(lhs) + __low2float(rhs), - __high2float(lhs) + __high2float(rhs)), - DEVICE) -#endif + __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, BinaryOpType::kAdd, DEVICE, + STMTS_WRAPPER({ + return __float2bfloat16(__bfloat162float(lhs) + __bfloat162float(rhs)); + })) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( + __nv_bfloat162, __nv_bfloat162, __nv_bfloat162, BinaryOpType::kAdd, DEVICE, + STMTS_WRAPPER({ + return __floats2bfloat162_rn(__low2float(lhs) + __low2float(rhs), + __high2float(lhs) + __high2float(rhs)); + })) +#endif /* defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 */ -COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, BinaryOpType::kMul, - 
__hmul(lhs, rhs), DEVICE) -COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, BinaryOpType::kMul, - __hmul2(lhs, rhs), DEVICE) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kMul, + DEVICE, STMTS_WRAPPER({ + return __hmul(lhs, rhs); + })) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, half2, half2, BinaryOpType::kMul, + DEVICE, STMTS_WRAPPER({ + return __hmul2(lhs, rhs); + })) #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kMul, - __hmul(lhs, rhs), DEVICE) -COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, BinaryOpType::kMul, - __hmul2(lhs, rhs), DEVICE) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16, + __nv_bfloat16, BinaryOpType::kMul, + DEVICE, STMTS_WRAPPER({ + return __hmul(lhs, rhs); + })) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, __nv_bfloat162, + __nv_bfloat162, BinaryOpType::kMul, + DEVICE, STMTS_WRAPPER({ + return __hmul2(lhs, rhs); + })) #else -COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kMul, - __float2bfloat16(__bfloat162float(lhs) * - __bfloat162float(rhs)), - DEVICE) COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( - __nv_bfloat162, BinaryOpType::kMul, - __floats2bfloat162_rn(__low2float(lhs) * __low2float(rhs), - __high2float(lhs) * __high2float(rhs)), - DEVICE) -#endif + __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, BinaryOpType::kMul, DEVICE, + STMTS_WRAPPER({ + return __float2bfloat16(__bfloat162float(lhs) * __bfloat162float(rhs)); + })) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( + __nv_bfloat162, __nv_bfloat162, __nv_bfloat162, BinaryOpType::kMul, DEVICE, + STMTS_WRAPPER({ + return __floats2bfloat162_rn(__low2float(lhs) * __low2float(rhs), + __high2float(lhs) * __high2float(rhs)); + })) +#endif /* defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 */ + +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( + float2, float2, float2, BinaryOpType::kMul, DEVICE, + STMTS_WRAPPER({ return make_float2(lhs.x * rhs.x, lhs.y * rhs.y); })) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(float4, float4, float4, + BinaryOpType::kMul, DEVICE, + STMTS_WRAPPER({ + return make_float4( + lhs.x * rhs.x, lhs.y * rhs.y, + lhs.z * rhs.z, lhs.w * rhs.w); + })) + +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( + __nv_bfloat162, __nv_bfloat162, float2, BinaryOpType::kMul, DEVICE, + STMTS_WRAPPER({ + CastFunctor<__nv_bfloat162, float2> cast; + BinaryOpFunctor mul; + float2 fa = cast(lhs); + float2 fb = cast(rhs); + return mul(fa, fb); + })) + +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( + bfloat164, bfloat164, float4_, BinaryOpType::kMul, DEVICE, STMTS_WRAPPER({ + float4_ fc; + BinaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, + BinaryOpType::kMul> + mul; + fc.x = mul(lhs.x, rhs.x); + fc.y = mul(lhs.y, rhs.y); + return fc; + })) + +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( + bfloat168, bfloat168, float8_, BinaryOpType::kMul, DEVICE, STMTS_WRAPPER({ + float8_ fc; + BinaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, + BinaryOpType::kMul> + mul; + fc.x = mul(lhs.x, rhs.x); + fc.y = mul(lhs.y, rhs.y); + fc.z = mul(lhs.z, rhs.z); + fc.w = mul(lhs.w, rhs.w); + return fc; + })) + +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( + half2, half2, float2, BinaryOpType::kMul, DEVICE, STMTS_WRAPPER({ + CastFunctor cast; + BinaryOpFunctor mul; + float2 fa = cast(lhs); + float2 fb = cast(rhs); + return mul(fa, fb); + })) + +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( + half4, half4, float4_, BinaryOpType::kMul, DEVICE, STMTS_WRAPPER({ + float4_ fc; + BinaryOpFunctor mul; + fc.x = mul(lhs.x, rhs.x); + fc.y 
= mul(lhs.y, rhs.y); + return fc; + })) + +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( + half8, half8, float8_, BinaryOpType::kMul, DEVICE, STMTS_WRAPPER({ + float8_ fc; + BinaryOpFunctor mul; + fc.x = mul(lhs.x, rhs.x); + fc.y = mul(lhs.y, rhs.y); + fc.z = mul(lhs.z, rhs.z); + fc.w = mul(lhs.w, rhs.w); + return fc; + })) #undef COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION +#undef STMTS_WRAPPER + } // namespace funcs } // namespace cuda } // namespace colossalAI diff --git a/extensions/csrc/cuda/funcs/cast_functor.h b/extensions/csrc/cuda/funcs/cast_functor.h index 05fffb766..d78ca4af2 100644 --- a/extensions/csrc/cuda/funcs/cast_functor.h +++ b/extensions/csrc/cuda/funcs/cast_functor.h @@ -8,6 +8,7 @@ #include #include "../utils/micros.h" +#include "../utils/vec_type_traits.h" // Note(LiuYang): This file provides base math operation for data type // include POD and cuda built-in type such as half and __nv_bfloat16 @@ -16,39 +17,150 @@ namespace colossalAI { namespace cuda { namespace funcs { +using utils::bfloat164; +using utils::bfloat168; +using utils::float4_; +using utils::float8_; +using utils::half4; +using utils::half8; + template struct CastFunctor : public std::unary_function { HOSTDEVICE To operator()(From val) { return static_cast(val); } }; -#define COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(FROM, TO, STMT, \ +#define COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(FROM, TO, STMTS, \ FUNCTION_MODIFIER) \ template <> \ struct CastFunctor : public std::unary_function { \ - FUNCTION_MODIFIER TO operator()(FROM val) { return STMT; } \ + FUNCTION_MODIFIER TO operator()(FROM val) STMTS \ }; -COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(int2, float2, make_float2(val.x, val.y), - DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + int2, float2, { return make_float2(val.x, val.y); }, DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + float, float2, { return make_float2(val, val); }, DEVICE) -COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, float2, make_float2(val, val), - DEVICE) -COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half, __float2half(val), DEVICE) -COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat16, - __float2bfloat16(val), DEVICE) -COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat162, - __float2bfloat162_rn(val), DEVICE) -COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half2, __float2half2_rn(val), - DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + half2, float2, { return __half22float2(val); }, DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + float2, half2, { return __float22half2_rn(val); }, DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + float, half, { return __float2half_rn(val); }, DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + float, half2, { return __float2half2_rn(val); }, DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + half, half2, { return __half2half2(val); }, DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + half, float, { return __half2float(val); }, DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + float4, half4, + { + half4 dst; + dst.x = __floats2half2_rn(val.x, val.y); + dst.y = __floats2half2_rn(val.z, val.w); + return dst; + }, + DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + float4_, half4, + { + half4 dst; + dst.x = __float22half2_rn(val.x); + dst.y = __float22half2_rn(val.y); + return dst; + }, + DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + float8_, half8, + { + half8 dst; + dst.x = __float22half2_rn(val.x); + dst.y = __float22half2_rn(val.y); + dst.z = __float22half2_rn(val.z); + dst.w = __float22half2_rn(val.w); + return dst; + }, + DEVICE) 
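// Usage sketch (illustrative, not part of the kernel code itself): these wide
// cast specializations let an fp32 accumulator be packed back to the cache
// dtype in a single functor call, e.g.
//   float8_ acc;                                   // accumulated in fp32
//   half8 packed = CastFunctor<float8_, half8>()(acc);
// which is how the decoding kernel converts its accumulators before the final
// vectorized store.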
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, float, __half2float(val), DEVICE) -COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, float, - __bfloat162float(val), DEVICE) - -COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half2, float2, __half22float2(val), DEVICE) -COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, half2, __float22half2_rn(val), - DEVICE) -COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, half2, __half2half2(val), DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + float, __nv_bfloat162, { return __float2bfloat162_rn(val); }, DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + float, __nv_bfloat16, { return __float2bfloat16_rn(val); }, DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + float4, bfloat164, + { + bfloat164 dst; + dst.x = __floats2bfloat162_rn(val.x, val.y); + dst.y = __floats2bfloat162_rn(val.z, val.w); + return dst; + }, + DEVICE) +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + __nv_bfloat16, __nv_bfloat162, { return __bfloat162bfloat162(val); }, + DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + __nv_bfloat162, float2, { return __bfloat1622float2(val); }, DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + float2, __nv_bfloat162, { return __float22bfloat162_rn(val); }, DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + float4_, bfloat164, + { + bfloat164 dst; + dst.x = __float22bfloat162_rn(val.x); + dst.y = __float22bfloat162_rn(val.y); + return dst; + }, + DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + float8_, bfloat168, + { + bfloat168 dst; + dst.x = __float22bfloat162_rn(val.x); + dst.y = __float22bfloat162_rn(val.y); + dst.z = __float22bfloat162_rn(val.z); + dst.w = __float22bfloat162_rn(val.w); + return dst; + }, + DEVICE) +#else +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + __nv_bfloat16, __nv_bfloat162, + { + __nv_bfloat162 dst; + dst.x = val; + dst.y = val; + return dst; + }, + DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + __nv_bfloat162, float2, + { return make_float2(__low2float(val), __high2float(val)); }, DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + float2, __nv_bfloat162, { return __floats2bfloat162_rn(val.x, val.y); }, + DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + float4_, bfloat164, + { + bfloat164 dst; + dst.x = __floats2bfloat162_rn(val.x.x, val.x.y); + dst.y = __floats2bfloat162_rn(val.y.x, val.y.y); + return dst; + }, + DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + float8_, bfloat168, + { + bfloat168 dst; + dst.x = __floats2bfloat162_rn(val.x.x, val.x.y); + dst.y = __floats2bfloat162_rn(val.y.x, val.y.y); + dst.z = __floats2bfloat162_rn(val.z.x, val.z.y); + dst.w = __floats2bfloat162_rn(val.w.x, val.w.y); + return dst; + }, + DEVICE) +#endif /* defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 */ #undef COLOSSAL_CAST_FUNCTOR_SPECIALIZATION } // namespace funcs diff --git a/extensions/csrc/cuda/funcs/ternary_functor.h b/extensions/csrc/cuda/funcs/ternary_functor.h new file mode 100644 index 000000000..34b01cdf5 --- /dev/null +++ b/extensions/csrc/cuda/funcs/ternary_functor.h @@ -0,0 +1,212 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +#include "../funcs/cast_functor.h" +#include "../utils/micros.h" + +namespace colossalAI { +namespace cuda { +namespace funcs { + +enum class TernaryOpType { kFma = 0 }; + +template +struct TernaryOpFunctor; + +#define STMTS_WRAPPER(...) __VA_ARGS__ + +#define COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( \ + LT, RT, RET, TERNARY_OP_TYPE, FUNCTION_MODIFIER, STMTS, ARGS...) 
\ + template \ + struct TernaryOpFunctor { \ + FUNCTION_MODIFIER RET operator()(LT a, RT b, RET c) STMTS \ + }; + +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float, float, float, + TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + float d; + d = fma(a, b, c); + return d; + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float2, float2, float2, + TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + float2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float, float2, float2, + TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + float2 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float4, float4, float4, + TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + float4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float, float4, float4, + TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + float4 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; + })) + +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + half, half, float, TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ return __half2float(a) * __half2float(b) + c; })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + half2, half2, float2, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ + CastFunctor cast; + TernaryOpFunctor fma; + float2 fa = cast(a); + float2 fb = cast(b); + return fma(fa, fb, c); + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + half, half2, float2, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ + CastFunctor cast; + TernaryOpFunctor fma; + return fma(cast(a), b, c); + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + half4, half4, float4_, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ + float4_ fd; + TernaryOpFunctor fma; + fd.x = fma(a.x, b.x, c.x); + fd.y = fma(a.y, b.y, c.y); + return fd; + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + half, half4, float4_, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ + float4_ fd; + CastFunctor cast; + TernaryOpFunctor fma; + half2 s = cast(a); + fd.x = fma(s, b.x, c.x); + fd.y = fma(s, b.y, c.y); + return fd; + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + half8, half8, float8_, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ + float8_ fd; + TernaryOpFunctor fma; + fd.x = fma(a.x, b.x, c.x); + fd.y = fma(a.y, b.y, c.y); + fd.z = fma(a.z, b.z, c.z); + fd.w = fma(a.w, b.w, c.w); + return fd; + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + half, half8, float8_, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ + float8_ fd; + CastFunctor cast; + TernaryOpFunctor fma; + half2 s = cast(a); + fd.x = fma(s, b.x, c.x); + fd.y = fma(s, b.y, c.y); + fd.z = fma(s, b.z, c.z); + fd.w = fma(s, b.w, c.w); + return fd; + })) + +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + __nv_bfloat16, __nv_bfloat16, float, TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ return __bfloat162float(a) * __bfloat162float(b) + c; })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + __nv_bfloat162, __nv_bfloat162, float2, TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + CastFunctor<__nv_bfloat162, float2> cast; + TernaryOpFunctor fma; + float2 fa = cast(a); + float2 fb = cast(b); + return fma(fa, fb, c); + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + __nv_bfloat16, __nv_bfloat162, float2, TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + CastFunctor<__nv_bfloat16, __nv_bfloat162> cast; + TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, + 
TernaryOpType::kFma> + fma; + return fma(cast(a), b, c); + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + bfloat164, bfloat164, float4_, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ + float4_ fd; + TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, + TernaryOpType::kFma> + fma; + fd.x = fma(a.x, b.x, c.x); + fd.y = fma(a.y, b.y, c.y); + return fd; + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + __nv_bfloat16, bfloat164, float4_, TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + float4_ fd; + CastFunctor<__nv_bfloat16, __nv_bfloat162> cast; + TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, + TernaryOpType::kFma> + fma; + __nv_bfloat162 s = cast(a); + fd.x = fma(s, b.x, c.x); + fd.y = fma(s, b.y, c.y); + return fd; + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + bfloat168, bfloat168, float8_, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ + float8_ fd; + TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, + TernaryOpType::kFma> + fma; + fd.x = fma(a.x, b.x, c.x); + fd.y = fma(a.y, b.y, c.y); + fd.z = fma(a.z, b.z, c.z); + fd.w = fma(a.w, b.w, c.w); + return fd; + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + __nv_bfloat16, bfloat168, float8_, TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + float8_ fd; + CastFunctor<__nv_bfloat16, __nv_bfloat162> cast; + TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, + TernaryOpType::kFma> + fma; + __nv_bfloat162 s = cast(a); + fd.x = fma(s, b.x, c.x); + fd.y = fma(s, b.y, c.y); + fd.z = fma(s, b.z, c.z); + fd.w = fma(s, b.w, c.w); + return fd; + })) + +#undef COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION + +#undef STMTS_WRAPPER + +} // namespace funcs +} // namespace cuda +} // namespace colossalAI diff --git a/extensions/csrc/cuda/funcs/unary_functor.h b/extensions/csrc/cuda/funcs/unary_functor.h index ea57fae7a..b8cd3c1a1 100644 --- a/extensions/csrc/cuda/funcs/unary_functor.h +++ b/extensions/csrc/cuda/funcs/unary_functor.h @@ -13,9 +13,24 @@ namespace colossalAI { namespace cuda { namespace funcs { +template +inline __device__ void zero(T& dst) { + constexpr int WORDS = sizeof(T) / 4; + union { + T raw; + uint32_t words[WORDS]; + } tmp; + +#pragma unroll + for (int ii = 0; ii < WORDS; ii++) { + tmp.words[ii] = 0u; + } + dst = tmp.raw; +} + // Note(LiuYang): As a retrieved table to check which operation is supported // already -enum class UnaryOpType { kLog2Ceil = 0, kAbs }; +enum class UnaryOpType { kLog2Ceil = 0, kAbs, kSum }; // Note(LiuYang): Implementation of common and simple unary operators should be // placed here, otherwise, they should be placed in a new file under functors @@ -42,6 +57,25 @@ COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(int, int, UnaryOpType::kLog2Ceil, return log2_value; }) +COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(float2, float, UnaryOpType::kSum, DEVICE, + { return val.x + val.y; }) + +COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(float4, float, UnaryOpType::kSum, DEVICE, + { return val.x + val.y + val.z + val.w; }) + +COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(float4_, float, UnaryOpType::kSum, DEVICE, + { + return val.x.x + val.x.y + val.y.x + + val.y.y; + }) + +COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(float8_, float, UnaryOpType::kSum, DEVICE, + { + return val.x.x + val.x.y + val.y.x + + val.y.y + val.z.x + val.z.y + + val.w.x + val.w.y; + }) + #undef COLOSSAL_UARY_FUNCTOR_SPECIALIZATION } // namespace funcs diff --git a/extensions/csrc/cuda/pybind/inference.cpp b/extensions/csrc/cuda/pybind/inference.cpp index 6a468fcb8..9997cc54c 100644 --- a/extensions/csrc/cuda/pybind/inference.cpp +++ 
b/extensions/csrc/cuda/pybind/inference.cpp @@ -58,6 +58,21 @@ void get_cos_and_sin(at::Tensor& cos_cache, // [max_rotary_position, head_dim] at::Tensor& sequence_lengths, // [batch_size] int max_seq_len_in_batch, bool is_prompts); +void flash_decoding_attention( + torch::Tensor& out, // [num_tokens, num_heads, head_size] + torch::Tensor& query, // [num_tokens, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_kv_heads, block_size, head_size] + torch::Tensor& + value_cache, // [num_blocks, num_kv_heads, block_size, head_size] + torch::Tensor& context_lens, // [num_tokens] + torch::Tensor& block_tables, // [num_tokens, max_num_blocks_per_seq] + int block_size, int max_context_len, + torch::Tensor& + tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size] + torch::Tensor& tmp_out_lse, // [num_tokens, num_heads, max_num_partitions] + float scale); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy, "Copy the GPU memory of kvcache during the decode stage."); @@ -81,4 +96,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "In-place fused Add and RMS Normalization."); m.def("get_cos_and_sin", &get_cos_and_sin, "Get cos and sin from the cache."); + + m.def("flash_decoding_attention", &flash_decoding_attention, + "Compute the attention between an input query and the cached " + "keys/values using PagedAttention."); } diff --git a/extensions/csrc/cuda/rms_layernorm_kernel.cu b/extensions/csrc/cuda/rms_layernorm_kernel.cu index 1b89232f3..9183462ad 100644 --- a/extensions/csrc/cuda/rms_layernorm_kernel.cu +++ b/extensions/csrc/cuda/rms_layernorm_kernel.cu @@ -1,4 +1,4 @@ -/*This code from VLLM: +/*This code from FasterTransformer: * https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/layernorm_kernels.cu * with minor changes. 
*/ @@ -20,6 +20,32 @@ using colossalAI::cuda::funcs::BinaryOpFunctor; using colossalAI::cuda::funcs::BinaryOpType; using colossalAI::cuda::utils::VecTypeTrait; +#define RMSNORM_LAUNCHER(UNROLL_FACTOR, THREADDIM) \ + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( \ + input.element_size(), \ + input.scalar_type(), \ + "rms_layernorm_kernel", \ + rms_layernorm_kernel<<>>( \ + out.data_ptr(), \ + input.data_ptr(), \ + weight.data_ptr(), \ + epsilon, \ + num_tokens, \ + hidden_size);) \ + +#define FUSED_ADD_RMSNORM_LAUNCHER(UNROLL_FACTOR, THREADDIM) \ + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( \ + input.element_size(), \ + input.scalar_type(), \ + "fused_add_rms_layernorm_kernel", \ + fused_add_rms_layernorm_kernel<<>>( \ + input.data_ptr(), \ + residual.data_ptr(), \ + weight.data_ptr(), \ + epsilon, \ + num_tokens, \ + hidden_size);) \ + // optimized for half and bf16 template __global__ void rms_layernorm_kernel( @@ -234,29 +260,9 @@ void rms_layernorm( if (num_tokens >= 512) { if (input.scalar_type() == at::ScalarType::Float) { - DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( - input.element_size(), - input.scalar_type(), - "rms_layernorm_kernel", - rms_layernorm_kernel<<>>( - out.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - epsilon, - num_tokens, - hidden_size);) + RMSNORM_LAUNCHER(8, hidden_size / 8); } else { - DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( - input.element_size(), - input.scalar_type(), - "rms_layernorm_kernel", - rms_layernorm_kernel<<>>( - out.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - epsilon, - num_tokens, - hidden_size);) + RMSNORM_LAUNCHER(4, hidden_size / 8); } } else { int unroll_factor = (hidden_size + block.x - 1) / block.x; @@ -266,56 +272,16 @@ void rms_layernorm( } switch (unroll_factor) { case 1: - DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( - input.element_size(), - input.scalar_type(), - "rms_layernorm_kernel", - rms_layernorm_kernel<<>>( - out.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - epsilon, - num_tokens, - hidden_size);) + RMSNORM_LAUNCHER(1, block); break; case 2: - DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( - input.element_size(), - input.scalar_type(), - "rms_layernorm_kernel", - rms_layernorm_kernel<<>>( - out.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - epsilon, - num_tokens, - hidden_size);) + RMSNORM_LAUNCHER(2, block); break; case 4: - DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( - input.element_size(), - input.scalar_type(), - "rms_layernorm_kernel", - rms_layernorm_kernel<<>>( - out.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - epsilon, - num_tokens, - hidden_size);) + RMSNORM_LAUNCHER(4, block); break; case 8: - DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( - input.element_size(), - input.scalar_type(), - "rms_layernorm_kernel", - rms_layernorm_kernel<<>>( - out.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - epsilon, - num_tokens, - hidden_size);) + RMSNORM_LAUNCHER(8, block); break; default: AT_ERROR("unroll_factor must be 1, 2, 4 or 8"); @@ -338,29 +304,9 @@ void fused_add_rms_layernorm( if (num_tokens >= 512) { if (input.scalar_type() == at::ScalarType::Float) { - DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( - input.element_size(), - input.scalar_type(), - "fused_add_rms_layernorm_kernel", - fused_add_rms_layernorm_kernel<<>>( - input.data_ptr(), - residual.data_ptr(), - weight.data_ptr(), - epsilon, - num_tokens, - hidden_size);) + FUSED_ADD_RMSNORM_LAUNCHER(8, hidden_size / 8); } else { - DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( - input.element_size(), - input.scalar_type(), - "fused_add_rms_layernorm_kernel", - 
fused_add_rms_layernorm_kernel<<>>( - input.data_ptr(), - residual.data_ptr(), - weight.data_ptr(), - epsilon, - num_tokens, - hidden_size);) + FUSED_ADD_RMSNORM_LAUNCHER(4, hidden_size / 8); } } else { int unroll_factor = (hidden_size + block.x - 1) / block.x; @@ -370,56 +316,16 @@ void fused_add_rms_layernorm( } switch (unroll_factor) { case 1: - DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( - input.element_size(), - input.scalar_type(), - "fused_add_rms_layernorm_kernel", - fused_add_rms_layernorm_kernel<<>>( - input.data_ptr(), - residual.data_ptr(), - weight.data_ptr(), - epsilon, - num_tokens, - hidden_size);) + FUSED_ADD_RMSNORM_LAUNCHER(1, block); break; case 2: - DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( - input.element_size(), - input.scalar_type(), - "fused_add_rms_layernorm_kernel", - fused_add_rms_layernorm_kernel<<>>( - input.data_ptr(), - residual.data_ptr(), - weight.data_ptr(), - epsilon, - num_tokens, - hidden_size);) + FUSED_ADD_RMSNORM_LAUNCHER(2, block); break; case 4: - DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( - input.element_size(), - input.scalar_type(), - "fused_add_rms_layernorm_kernel", - fused_add_rms_layernorm_kernel<<>>( - input.data_ptr(), - residual.data_ptr(), - weight.data_ptr(), - epsilon, - num_tokens, - hidden_size);) + FUSED_ADD_RMSNORM_LAUNCHER(4, block); break; case 8: - DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( - input.element_size(), - input.scalar_type(), - "fused_add_rms_layernorm_kernel", - fused_add_rms_layernorm_kernel<<>>( - input.data_ptr(), - residual.data_ptr(), - weight.data_ptr(), - epsilon, - num_tokens, - hidden_size);) + FUSED_ADD_RMSNORM_LAUNCHER(8, block); break; default: AT_ERROR("unroll_factor must be 1, 2, 4 or 8"); diff --git a/extensions/csrc/cuda/utils/vec_type_traits.h b/extensions/csrc/cuda/utils/vec_type_traits.h index 782518936..3a78a93c8 100644 --- a/extensions/csrc/cuda/utils/vec_type_traits.h +++ b/extensions/csrc/cuda/utils/vec_type_traits.h @@ -11,9 +11,45 @@ namespace colossalAI { namespace cuda { namespace utils { +struct bfloat164 { + __nv_bfloat162 x; + __nv_bfloat162 y; +}; +struct bfloat168 { + __nv_bfloat162 x; + __nv_bfloat162 y; + __nv_bfloat162 z; + __nv_bfloat162 w; +}; + +struct half4 { + half2 x; + half2 y; +}; +struct half8 { + half2 x; + half2 y; + half2 z; + half2 w; +}; + +struct float4_ { + float2 x; + float2 y; +}; +struct float8_ { + float2 x; + float2 y; + float2 z; + float2 w; +}; + template struct VecTypeTrait {}; +template +struct FloatVecTypeTrait {}; + #define VEC_TYPE_TRAITS_SPECIALIZATION(T, VEC_SIZE, VECT, ARGS...) \ template \ struct VecTypeTrait { \ @@ -31,13 +67,36 @@ VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 4, float2) VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 8, float4) VEC_TYPE_TRAITS_SPECIALIZATION(float, 2, float2) VEC_TYPE_TRAITS_SPECIALIZATION(float, 4, float4) -VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, float4) +VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, float8_) VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 2, half) VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 4, half2) VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 8, float2) +VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 2, __nv_bfloat162); +VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 4, bfloat164); +VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 8, bfloat168); +VEC_TYPE_TRAITS_SPECIALIZATION(half, 2, half2); +VEC_TYPE_TRAITS_SPECIALIZATION(half, 4, half4); +VEC_TYPE_TRAITS_SPECIALIZATION(half, 8, half8); #undef VEC_TYPE_TRAITS_SPECIALIZATION +#define FLOATVEC_TYPE_TRAITS_SPECIALIZATION(T, FLOATT, ARGS...) 
\
+  template <ARGS>                  \
+  struct FloatVecTypeTrait<T> {    \
+    using Type = FLOATT;           \
+  };
+
+FLOATVEC_TYPE_TRAITS_SPECIALIZATION(float2, float2)
+FLOATVEC_TYPE_TRAITS_SPECIALIZATION(float4, float4)
+FLOATVEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat162, float2);
+FLOATVEC_TYPE_TRAITS_SPECIALIZATION(bfloat164, float4_);
+FLOATVEC_TYPE_TRAITS_SPECIALIZATION(bfloat168, float8_);
+FLOATVEC_TYPE_TRAITS_SPECIALIZATION(half2, float2);
+FLOATVEC_TYPE_TRAITS_SPECIALIZATION(half4, float4_);
+FLOATVEC_TYPE_TRAITS_SPECIALIZATION(half8, float8_);
+
+#undef FLOATVEC_TYPE_TRAITS_SPECIALIZATION
+
 } // namespace utils
 } // namespace cuda
 } // namespace colossalAI
diff --git a/extensions/inference/inference_ops_cuda.py b/extensions/inference/inference_ops_cuda.py
index 09ebfdabd..1ad58f3ea 100644
--- a/extensions/inference/inference_ops_cuda.py
+++ b/extensions/inference/inference_ops_cuda.py
@@ -17,6 +17,7 @@ class InferenceOpsCudaExtension(_CudaExtension):
                 "cuda/activation_kernel.cu",
                 "cuda/rms_layernorm_kernel.cu",
                 "cuda/get_cos_and_sin_kernel.cu",
+                "cuda/flash_decoding_attention_kernel.cu",
             ]
         ]
         return ret
diff --git a/tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py b/tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py
new file mode 100644
index 000000000..a7eb47a76
--- /dev/null
+++ b/tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py
@@ -0,0 +1,274 @@
+from itertools import product
+
+import numpy as np
+import pytest
+import torch
+
+from colossalai.kernel.kernel_loader import InferenceOpsLoader
+from colossalai.utils import get_current_device
+
+inference_ops = InferenceOpsLoader().load()
+
+from tests.test_infer.test_ops.triton.kernel_utils import (
+    convert_kv_unpad_to_padded,
+    create_attention_mask,
+    generate_caches_and_block_tables_v2,
+    generate_caches_and_block_tables_vllm,
+    torch_attn_ref,
+)
+
+q_len = 1
+
+
+def prepare_data(
+    BATCH_SIZE: int,
+    HEAD_SIZE: int,
+    NUM_ATTN_HEADS: int,
+    NUM_KV_HEADS: int,
+    MAX_SEQ_LEN: int,
+    dtype=torch.float16,
+    device="cuda",
+):
+    # Use the provided maximum sequence length for each sequence when testing with the same context length,
+    # otherwise generate random context lengths.
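    # In this test file the random branch is exercised: per-sequence context lengths are drawn
    # uniformly from [1, MAX_SEQ_LEN). NUM_ATTN_HEADS is assumed to be an integer multiple of
    # NUM_KV_HEADS (grouped-query attention); KV_GROUP_NUM == 1 reduces to standard multi-head attention.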
+ # returns + # q [BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE] + # k_unpad/v_unpad [num_tokens, NUM_KV_HEADS, HEAD_SIZE] + kv_lengths = torch.randint(low=1, high=MAX_SEQ_LEN, size=(BATCH_SIZE,), dtype=torch.int32, device=device) + num_tokens = torch.sum(kv_lengths).item() + + q_size = (BATCH_SIZE, q_len, NUM_ATTN_HEADS, HEAD_SIZE) + q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5).transpose(1, 2) + kv_size = (num_tokens, 2 * NUM_KV_HEADS, HEAD_SIZE) + kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) + k_unpad, v_unpad = torch.split(kv_unpad, [NUM_KV_HEADS, NUM_KV_HEADS], dim=-2) + + return q, k_unpad, v_unpad, kv_lengths + + +def numpy_allclose(x, y, rtol, atol): + x_numpy = x.detach().cpu().numpy() + y_numpy = y.detach().cpu().numpy() + + np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("BATCH_SIZE", [1, 4, 7, 32]) +@pytest.mark.parametrize("BLOCK_SIZE", [8, 16, 32]) +@pytest.mark.parametrize("MAX_NUM_BLOCKS_PER_SEQ", [1, 8, 32]) +@pytest.mark.parametrize("HEAD_SIZE", [64, 128]) +@pytest.mark.parametrize("NUM_ATTN_HEADS", [16]) +@pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_flash_decoding_attention( + BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype +): + torch.manual_seed(123) + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + + NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM + assert isinstance(NUM_KV_HEADS, int) and NUM_KV_HEADS > 0, "Invalid number of kv heads." + MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ + device = get_current_device() + + q, k_unpad, v_unpad, kv_seq_lengths = prepare_data( + BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device + ) + + k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2( + k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device + ) + + block_tables = block_tables.to(device=device) + max_seq_len_across_batch = kv_seq_lengths.max().item() + kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE + output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device) + sm_scale = 1.0 / (HEAD_SIZE**0.5) + + k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch) + v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch) + torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device) + + mid_output = torch.empty( + size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device + ) + mid_output_lse = torch.empty( + size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device + ) + + if dtype == torch.float16: + rtol = 1e-3 + atol = 1e-3 + + high_precision_q = q.to(torch.float32) + high_precision_k_torch = k_torch.to(torch.float32) + high_precision_v_torch = v_torch.to(torch.float32) + out_ref = torch_attn_ref( + high_precision_q, + high_precision_k_torch, + high_precision_v_torch, + torch_padding_mask, + BATCH_SIZE, + q_len, + max_seq_len_across_batch, + NUM_ATTN_HEADS, + NUM_KV_HEADS, + HEAD_SIZE, + ).to(torch.float16) + + else: + rtol = 1e-5 + atol = 1e-7 + + out_ref = torch_attn_ref( + q, + k_torch, + v_torch, + torch_padding_mask, + 
BATCH_SIZE, + q_len, + max_seq_len_across_batch, + NUM_ATTN_HEADS, + NUM_KV_HEADS, + HEAD_SIZE, + ) + + inference_ops.flash_decoding_attention( + output, + q.squeeze(2), + k_cache, + v_cache, + kv_seq_lengths, + block_tables, + BLOCK_SIZE, + max_seq_len_across_batch, + mid_output, + mid_output_lse, + sm_scale, + ) + numpy_allclose(out_ref, output, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("BATCH_SIZE", [1, 4, 7, 32]) +@pytest.mark.parametrize("BLOCK_SIZE", [8, 16, 32]) +@pytest.mark.parametrize("MAX_NUM_BLOCKS_PER_SEQ", [1, 8, 32]) +@pytest.mark.parametrize("HEAD_SIZE", [64, 128]) +@pytest.mark.parametrize("NUM_ATTN_HEADS", [16]) +@pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_vllm_flash_decoding_attention( + BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype +): + torch.manual_seed(123) + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + + try: + from vllm._C import ops as vllm_ops + except ImportError: + raise ImportError("Please install vllm from https://github.com/vllm-project/vllm") + + NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM + assert isinstance(NUM_KV_HEADS, int) and NUM_KV_HEADS > 0, "Invalid number of kv heads." + MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ + device = get_current_device() + + q, k_unpad, v_unpad, kv_seq_lengths = prepare_data( + BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device + ) + + k_cache, v_cache, block_tables = generate_caches_and_block_tables_vllm( + k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device + ) + + block_tables = block_tables.to(device=device) + max_seq_len_across_batch = kv_seq_lengths.max().item() + output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device) + sm_scale = 1.0 / (HEAD_SIZE**0.5) + + k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch) + v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch) + torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device) + + if dtype == torch.float16: + rtol = 1e-3 + atol = 1e-3 + + high_precision_q = q.to(torch.float32) + high_precision_k_torch = k_torch.to(torch.float32) + high_precision_v_torch = v_torch.to(torch.float32) + out_ref = torch_attn_ref( + high_precision_q, + high_precision_k_torch, + high_precision_v_torch, + torch_padding_mask, + BATCH_SIZE, + q_len, + max_seq_len_across_batch, + NUM_ATTN_HEADS, + NUM_KV_HEADS, + HEAD_SIZE, + ).to(torch.float16) + + else: + rtol = 1e-5 + atol = 1e-7 + + out_ref = torch_attn_ref( + q, + k_torch, + v_torch, + torch_padding_mask, + BATCH_SIZE, + q_len, + max_seq_len_across_batch, + NUM_ATTN_HEADS, + NUM_KV_HEADS, + HEAD_SIZE, + ) + + alibi_slopes = None + + vllm_ops.paged_attention_v1( + output, + q.squeeze(2), + k_cache, + v_cache, + NUM_KV_HEADS, + sm_scale, + block_tables, + kv_seq_lengths, + BLOCK_SIZE, + max_seq_len_across_batch, + alibi_slopes, + "auto", + ) + numpy_allclose(out_ref, output, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + BATCH_SIZE = [1, 4, 7, 32] + BLOCK_SIZE = [8, 16, 32] + MAX_NUM_BLOCKS_PER_SEQ = [1, 8, 32] + HEAD_SIZE = [64, 128] + NUM_ATTN_HEADS = [16] + KV_GROUP_NUM = [1, 2, 16] + DTYPE = [torch.float16, torch.float32] + test_combinations = list( + product(BATCH_SIZE, BLOCK_SIZE, 
MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, DTYPE) + ) + for ( + batch_size, + block_size, + max_num_blocks_per_seq, + head_size, + num_attn_heads, + kv_group_num, + dtype, + ) in test_combinations: + test_flash_decoding_attention( + batch_size, block_size, max_num_blocks_per_seq, head_size, num_attn_heads, kv_group_num, dtype + ) diff --git a/tests/test_infer/test_ops/triton/kernel_utils.py b/tests/test_infer/test_ops/triton/kernel_utils.py index 7ae5a833b..507c185b5 100644 --- a/tests/test_infer/test_ops/triton/kernel_utils.py +++ b/tests/test_infer/test_ops/triton/kernel_utils.py @@ -150,6 +150,51 @@ def mock_alloc_block_table_and_kvcache_v2( return block_tables +def mock_alloc_block_table_and_kvcache_vllm( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + context_lengths: torch.Tensor, + num_seqs: int, + max_num_blocks_per_seq: int, + block_size: int, +) -> torch.Tensor: + """Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache.""" + block_id = 0 + block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32) + num_tokens_processed = 0 + + _, num_kv_heads, head_dim = k.shape + + x = 16 // torch.tensor([], dtype=k.dtype).element_size() + + for i, seq_len in enumerate(context_lengths.tolist()): + right_bound = (seq_len + block_size - 1) // block_size # open bound + block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32) + # Manually fill kv caches by copying from k and v + for i in range(right_bound): + if i == right_bound - 1: + allocated_locs = seq_len % block_size or block_size + else: + allocated_locs = block_size + # [block_size, num_kv_heads, head_dim/x, x]->[num_kv_heads, head_dim/x, block_size,x] + k_block = ( + k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :] + .reshape(allocated_locs, num_kv_heads, head_dim // x, x) + .permute(1, 2, 0, 3) + ) + # [block_size, num_kv_heads, head_dim]->[num_kv_heads, head_dim, block_size] + v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0) + k_cache[block_id, :, :, :allocated_locs, :] = k_block + v_cache[block_id, :, :, :allocated_locs] = v_block + + num_tokens_processed += allocated_locs + block_id += 1 + + return block_tables + + def mock_alloc_single_token(block_tables: torch.Tensor, context_lengths: torch.Tensor, block_size: int) -> None: # Allocate 1 token on the block table for each seqs in block tables. # It won't change provided context_lengths. 
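The vLLM-layout key cache written by mock_alloc_block_table_and_kvcache_vllm above packs each key head as [head_dim // x, block_size, x] with x = 16 // element_size, so every 16-byte slice of a token's key vector stays contiguous within a block. The short sketch below shows how one token's key can be read back out of that layout to spot-check the mock against k_unpad; gather_key_from_vllm_cache is a hypothetical helper written for illustration, not part of this patch.

import torch


def gather_key_from_vllm_cache(
    k_cache: torch.Tensor,       # [num_blocks, num_kv_heads, head_dim // x, block_size, x]
    block_tables: torch.Tensor,  # [num_seqs, max_num_blocks_per_seq]
    seq_idx: int,
    token_idx: int,
    block_size: int,
) -> torch.Tensor:
    # Locate the physical block holding this token and the slot inside that block.
    block_id = int(block_tables[seq_idx, token_idx // block_size])
    slot = token_idx % block_size
    # [num_kv_heads, head_dim // x, x] -> [num_kv_heads, head_dim]
    chunk = k_cache[block_id, :, :, slot, :]
    return chunk.reshape(chunk.shape[0], -1)

The value cache uses the simpler [num_blocks, num_kv_heads, head_dim, block_size] layout, so the matching read is v_cache[block_id, :, :, slot].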
@@ -206,6 +251,26 @@ def generate_caches_and_block_tables_v2(
     return k_cache, v_cache, block_tables
 
 
+def generate_caches_and_block_tables_vllm(
+    k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
+) -> Tuple[torch.Tensor, ...]:
+    # Mock generation of k/v blocked caches and block tables from the provided kv unpad and seq lengths
+    # k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
+    _, num_kv_heads, head_dim = k_unpad.shape
+
+    x = 16 // torch.tensor([], dtype=dtype).element_size()
+
+    k_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x)
+    v_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size)
+    k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device=device)
+    v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device=device)
+    # Mock allocation on block tables as well as blocked kv caches
+    block_tables = mock_alloc_block_table_and_kvcache_vllm(
+        k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size
+    )
+    return k_cache, v_cache, block_tables
+
+
 def convert_kv_unpad_to_padded(
     k_unpad: torch.Tensor, kv_seq_lengths: torch.Tensor, bsz: int, max_seq_len: int
 ) -> torch.Tensor:
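    # Pads each sequence's keys (or values) from the packed [num_total_tokens, num_kv_heads, head_dim]
    # layout out to max_seq_len per batch entry so the dense torch reference (torch_attn_ref) can be used,
    # with create_attention_mask hiding the padded positions in the tests above.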