From f366a5ea1f2626a7870acaf8866f21d5fb49c388 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Wed, 13 Mar 2024 17:20:03 +0800 Subject: [PATCH] [Inference/kernel]Add Fused Rotary Embedding and KVCache Memcopy CUDA Kernel (#5418) * add rotary embedding kernel * add rotary_embedding_kernel * add fused rotary_emb and kvcache memcopy * add fused_rotary_emb_and_cache_kernel.cu * add fused_rotary_emb_and_memcopy * fix bugs in fused_rotary_emb_and_cache_kernel.cu * fix ci bugs * use vec memcopy and opt the gloabl memory access * fix code style * fix test_rotary_embdding_unpad.py * codes revised based on the review comments * fix bugs about include path * rm inline --- .../modeling/models/nopadding_llama.py | 19 +- colossalai/inference/utils.py | 4 +- ... benchmark_fused_rotary_embdding_unpad.py} | 34 +- ...dding.py => benchmark_rotary_embedding.py} | 29 +- .../benchmark_ops/benchmark_xine_copy.py | 54 ++ extensions/csrc/common/vector_copy_utils.h | 98 ++++ extensions/csrc/cuda/activation_kernel.cu | 3 + .../cuda/decode_kv_cache_memcpy_kernel.cu | 163 ++++-- .../cuda/fused_rotary_emb_and_cache_kernel.cu | 472 ++++++++++++++++++ extensions/csrc/cuda/pybind/inference.cpp | 24 + extensions/inference/inference_ops_cuda.py | 1 + tests/test_infer/test_inference_engine.py | 14 +- .../cuda/test_rotary_embdding_unpad.py | 91 ++++ 13 files changed, 928 insertions(+), 78 deletions(-) rename examples/inference/benchmark_ops/{benchmark_rotary_embdding_unpad.py => benchmark_fused_rotary_embdding_unpad.py} (70%) rename examples/inference/benchmark_ops/{benchmark_fused_rotary_embedding.py => benchmark_rotary_embedding.py} (62%) create mode 100644 examples/inference/benchmark_ops/benchmark_xine_copy.py create mode 100644 extensions/csrc/common/vector_copy_utils.h create mode 100644 extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu create mode 100644 tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index f84abab4b..12de4802b 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -320,8 +320,12 @@ class NopadLlamaAttention(LlamaAttention): ) block_size = k_cache.size(-2) + if is_prompts: - rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) + if use_cuda_kernel: + inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) + else: + rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) attn_output = context_attention_unpadded( q=query_states, k=key_states, @@ -337,9 +341,16 @@ class NopadLlamaAttention(LlamaAttention): ) else: if use_cuda_kernel: - rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) - inference_ops.decode_kv_cache_memcpy( - key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables + inference_ops.rotary_embedding_and_cache_copy( + query_states, + key_states, + value_states, + cos_sin[0], + cos_sin[1], + k_cache, + v_cache, + sequence_lengths, + block_tables, ) else: decoding_fused_rotary_embedding( diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py index 990864813..a97b9c9d6 100644 --- a/colossalai/inference/utils.py +++ b/colossalai/inference/utils.py @@ -47,5 +47,5 @@ def init_to_get_rotary(self, base=10000, use_elem=False): t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor freqs = torch.outer(t, 
inv_freq) - self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() - self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + self._cos_cached = torch.cos(freqs).to(self.dtype).cuda() + self._sin_cached = torch.sin(freqs).to(self.dtype).cuda() diff --git a/examples/inference/benchmark_ops/benchmark_rotary_embdding_unpad.py b/examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py similarity index 70% rename from examples/inference/benchmark_ops/benchmark_rotary_embdding_unpad.py rename to examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py index 0e22ed7d2..f11630dff 100644 --- a/examples/inference/benchmark_ops/benchmark_rotary_embdding_unpad.py +++ b/examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py @@ -1,8 +1,11 @@ import torch +from colossalai.kernel.kernel_loader import InferenceOpsLoader from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2, mock_alloc_single_token +inference_ops = InferenceOpsLoader().load() + try: import triton # noqa @@ -16,9 +19,19 @@ configs = [ x_names=["num_tokens"], x_vals=[2**i for i in range(4, 11)], line_arg="provider", - line_vals=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"], - line_names=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"], - styles=[("red", "-"), ("blue", "-")], + line_vals=[ + "no_fused_triton_rotary_emb_func", + "fused_triton_rotary_emb_func", + "no_fused_cuda_rotary_emb_func", + "fused_cuda_rotary_emb_func", + ], + line_names=[ + "no_fused_triton_rotary_emb_func", + "fused_triton_rotary_emb_func", + "no_fused_cuda_rotary_emb_func", + "fused_cuda_rotary_emb_func", + ], + styles=[("red", "-"), ("blue", "-"), ("green", "-"), ("yellow", "-")], ylabel="ms", plot_name=f"rotary_emb-batch-{BATCH}", args={"num_kv_heads": 16}, @@ -32,7 +45,7 @@ def benchmark_rotary_emb( num_tokens: int, num_kv_heads: int, ): - BATCH_SIZE = 4 + BATCH_SIZE = 16 SEQ_LEN = num_tokens // BATCH_SIZE max_num_blocks_per_seq = 8 block_size = 64 @@ -68,7 +81,7 @@ def benchmark_rotary_emb( kv_seq_lengths = past_kv_seq_lengths + 1 block_tables = block_tables.to(device="cuda") - if provider == "no_fused_rotary_emb_func": + if provider == "no_fused_triton_rotary_emb_func": fn = lambda: [ rotary_embedding(new_q, new_k, cos, sin), copy_kv_to_blocked_cache( @@ -77,7 +90,16 @@ def benchmark_rotary_emb( ] elif provider == "fused_triton_rotary_emb_func": fn = lambda: decoding_fused_rotary_embedding( - new_q, new_k, new_k, cos, sin, k_cache, k_cache, block_tables, kv_seq_lengths + new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths + ) + elif provider == "no_fused_cuda_rotary_emb_func": + fn = lambda: [ + inference_ops.rotary_embedding(new_q, new_k, cos, sin), + inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables), + ] + elif provider == "fused_cuda_rotary_emb_func": + fn = lambda: inference_ops.rotary_embedding_and_cache_copy( + new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables ) else: raise ValueError("Undefined provider") diff --git a/examples/inference/benchmark_ops/benchmark_fused_rotary_embedding.py b/examples/inference/benchmark_ops/benchmark_rotary_embedding.py similarity index 62% rename from examples/inference/benchmark_ops/benchmark_fused_rotary_embedding.py rename to 
examples/inference/benchmark_ops/benchmark_rotary_embedding.py index 9b44ef791..97cf2e0b2 100644 --- a/examples/inference/benchmark_ops/benchmark_fused_rotary_embedding.py +++ b/examples/inference/benchmark_ops/benchmark_rotary_embedding.py @@ -1,7 +1,11 @@ import torch import triton +from vllm._C import ops -from colossalai.kernel.triton.fused_rotary_embedding import fused_rotary_embedding +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.kernel.triton import rotary_embedding + +inference_ops = InferenceOpsLoader().load() BATCH = 16 configs = [ @@ -9,9 +13,9 @@ configs = [ x_names=["num_tokens"], x_vals=[2**i for i in range(4, 12)], line_arg="provider", - line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"], - line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"], - styles=[("red", "-"), ("blue", "-")], + line_vals=["triton_func", "colossal_cuda_func", "vllm_cuda_func"], + line_names=["triton_func", "colossal_cuda_func", "vllm_cuda_func"], + styles=[("red", "-"), ("blue", "-"), ("yellow", "-")], ylabel="ms", plot_name=f"rotary_emb-batch-{BATCH}", args={"num_kv_heads": 16}, @@ -48,12 +52,19 @@ def benchmark_rotary_emb( cos_shape = (4096, head_dim // 2) cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") - lengths = torch.tensor([3, 4, 6, 7], device="cuda") - if provider == "torch_rotary_emb_func": - fn = lambda: torch_rotary_emb(q, cos[:num_tokens], sin[:num_tokens]) - elif provider == "triton_rotary_emb_func": - fn = lambda: fused_rotary_embedding(q, k, cos, sin, lengths) + cos_sin = torch.stack((cos, sin), dim=1).contiguous() + + positions = torch.arange(num_tokens).cuda() + + if provider == "triton_func": + fn = lambda: rotary_embedding(q, k, cos, sin) + elif provider == "colossal_cuda_func": + fn = lambda: inference_ops.rotary_embedding(q, k, cos, sin) + elif provider == "vllm_cuda_func": + q = q.view(num_tokens, -1) + k = k.view(num_tokens, -1) + fn = lambda: ops.rotary_embedding(positions, q, k, head_dim, cos_sin, True) else: raise ValueError("Undefined provider") diff --git a/examples/inference/benchmark_ops/benchmark_xine_copy.py b/examples/inference/benchmark_ops/benchmark_xine_copy.py new file mode 100644 index 000000000..b15232b91 --- /dev/null +++ b/examples/inference/benchmark_ops/benchmark_xine_copy.py @@ -0,0 +1,54 @@ +import torch + +from colossalai.kernel.triton import get_xine_cache +from tests.test_infer.test_ops.triton.test_xine_copy import get_cos_sin + +try: + import triton # noqa + +except ImportError: + print("please install triton from https://github.com/openai/triton") + + +configs = [ + triton.testing.Benchmark( + x_names=["max_num_tokens"], + x_vals=[2**i for i in range(6, 12)], + line_arg="provider", + line_vals=["torch_get_cos_sin", "triton_get_cos_sin"], + line_names=["torch_get_cos_sin", "triton_get_cos_sin"], + styles=[("red", "-"), ("blue", "-")], + ylabel="ms", + plot_name="Get_cos-sin_func", + args={"batch_size": 16, "head_dim": 256}, + ) +] + + +@triton.testing.perf_report(configs) +def benchmark_get_xine_cache( + provider: str, + max_num_tokens: int, + batch_size: int, + head_dim: int, +): + warmup = 10 + rep = 1000 + dtype = torch.float16 + cos_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda") + sin_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda") + lengths = torch.randint(2, max_num_tokens, (batch_size,), device="cuda") + + if provider == "torch_get_cos_sin": + fn = lambda: 
get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype) + elif provider == "triton_get_cos_sin": + fn = lambda: get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=True) + else: + raise ValueError("Undefined provider") + + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + + +if __name__ == "__main__": + benchmark_get_xine_cache.run(save_path=".", print_data=True) diff --git a/extensions/csrc/common/vector_copy_utils.h b/extensions/csrc/common/vector_copy_utils.h new file mode 100644 index 000000000..456440cf6 --- /dev/null +++ b/extensions/csrc/common/vector_copy_utils.h @@ -0,0 +1,98 @@ + +#include +#include + +#include + +#include "string" + +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template <> +__device__ __inline__ void copy_vector( + c10::BFloat16 *dst, const c10::BFloat16 *src) { + *dst = *src; +} + +template <> +__device__ __inline__ void copy_vector( + c10::BFloat16 *dst, const c10::BFloat16 *src) { + *((float *)dst) = *((float *)src); +} + +template <> +__device__ __inline__ void copy_vector( + c10::BFloat16 *dst, const c10::BFloat16 *src) { + *((float2 *)dst) = *((float2 *)src); +} + +template <> +__device__ __inline__ void copy_vector( + c10::BFloat16 *dst, const c10::BFloat16 *src) { + *((float4 *)dst) = *((float4 *)src); +} + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, + const c10::Half *src) { + *dst = *src; +} + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, + const c10::Half *src) { + *((float *)dst) = *((float *)src); +} + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, + const c10::Half *src) { + *((float2 *)dst) = *((float2 *)src); +} + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, + const c10::Half *src) { + *((float4 *)dst) = *((float4 *)src); +} + +template <> +__device__ __inline__ void copy_vector(float *dst, const float *src) { + *dst = *src; +} + +template <> +__device__ __inline__ void copy_vector(float *dst, const float *src) { + *((float2 *)dst) = *((float2 *)src); +} + +template <> +__device__ __inline__ void copy_vector(float *dst, const float *src) { + *((float4 *)dst) = *((float4 *)src); +} + +template <> +__device__ __inline__ void copy_vector(float *dst, const float *src) { + // Since the maximum memory alignment length is 128 bits, we choose float4 + // here. 
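+  // Note: this 8-float copy is therefore issued as two float4 (16-byte)
+  // transactions, and assumes both dst and src are 16-byte aligned.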
+ *((float4 *)dst) = *((float4 *)src); + *((float4 *)(dst + 4)) = *((float4 *)(src + 4)); +} + +template +int get_vec_size(const torch::Tensor &tensor) { + uint64_t address = reinterpret_cast(tensor.data_ptr()); + const int max_aligned_size = 128; + const int dtype_size = sizeof(T) * 8; + + const int vec_size = max_aligned_size / sizeof(T) / 8; + + if (address % (dtype_size * 4) == 0) { + return std::min(4, vec_size); + } else if (address % (dtype_size * 2) == 0) { + return std::min(2, vec_size); + } else { + return 1; + } +} diff --git a/extensions/csrc/cuda/activation_kernel.cu b/extensions/csrc/cuda/activation_kernel.cu index 5213a2313..e9dc01753 100644 --- a/extensions/csrc/cuda/activation_kernel.cu +++ b/extensions/csrc/cuda/activation_kernel.cu @@ -39,6 +39,9 @@ torch::Tensor silu_and_mul(const torch::Tensor& ins) auto ins_shape = ins.sizes().vec(); ins_shape[0] = ins_shape[0]/2; + if (ins_shape[0] == 1) { + ins_shape.erase(ins_shape.begin()); + } auto outs = torch::zeros(ins_shape,ins.options()); auto outs_shape = ins.sizes().vec(); diff --git a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu index 15e613e35..7eb44ecd0 100644 --- a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu +++ b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu @@ -1,10 +1,10 @@ #include #include -#include +#include "../common/vector_copy_utils.h" #include "../common/micros.h" -template +template __global__ void decode_kv_cache_memcpy_kernel( const scalar_t* __restrict__ key, const scalar_t* __restrict__ value, @@ -12,79 +12,146 @@ __global__ void decode_kv_cache_memcpy_kernel( scalar_t* __restrict__ value_cache, const int* __restrict__ sequence_lengths, const int* __restrict__ block_tables, - const int num_heads, - const int head_size, + const int head_num, + const int head_dim, const int block_size, - const int key_stride, - const int value_stride, + const int64_t key_stride, + const int64_t value_stride, const int block_table_stride ) { const int seq_id = blockIdx.x; const int seq_len = sequence_lengths[seq_id] - 1; - const int seq_id_in_block_table = seq_len / block_size; const int block_offset = seq_len % block_size; - const int block_id = block_tables[seq_id * block_table_stride + seq_id_in_block_table]; - const int hidden_size = num_heads * head_size; + const int block_id = block_tables[seq_id * block_table_stride + seq_len / block_size]; + const int hidden_size = head_num * head_dim; if ( block_id < 0 ) { return ; } - for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { - const int head_id = i / head_size; - const int head_offset = i % head_size; - const int key_src_id = seq_id * key_stride + i; - const int value_src_id = seq_id * value_stride + i; - const int target_src_id = block_id * hidden_size * block_size - + head_id * block_size * head_size - + block_offset * head_size + head_offset; + for (int i = threadIdx.x * VecSize; i < hidden_size; i += blockDim.x * VecSize) { + const int head_id = i / head_dim; + const int head_offset = i % head_dim; + const int64_t key_src_id = seq_id * key_stride + i; + const int64_t value_src_id = seq_id * value_stride + i; + const int64_t target_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + block_offset * head_dim + head_offset; - key_cache[target_src_id] = key[key_src_id]; - value_cache[target_src_id] = value[value_src_id]; + copy_vector(key_cache + target_id, key + key_src_id); + copy_vector(value_cache + target_id, value + value_src_id); } } -void 
decode_kv_cache_memcpy( - torch::Tensor& key, // [num_tokens, num_heads, head_size] - torch::Tensor& value, // [num_tokens, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size] - torch::Tensor& value_cache, // [num_blocks, num_heads, block_size, head_size] - torch::Tensor& sequence_lengths, // [batch_size] - torch::Tensor& block_tables) // [batch_size, max_seq_len] +template +void apply_decode_kv_cache_memcpy( + at::Tensor& key, // [num_tokens, head_num, head_dim] + at::Tensor& value, // [num_tokens, head_num, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& block_tables) // [batch_size, max_seq_len] { int num_tokens = key.size(0); - int num_heads = key.size(1); - int head_size = key.size(2); + int head_num = key.size(1); + int head_dim = key.size(2); int block_size = key_cache.size(2); - int key_stride = key.stride(0); - int value_stride = value.stride(0); + int64_t key_stride = key.stride(0); + int64_t value_stride = value.stride(0); int block_table_stride = block_tables.stride(0); + int vec_size = get_vec_size(key); + + if (head_dim % vec_size != 0) { + // Disable vectorized loading optimization when head_dim is not divisible by VecSize. + vec_size = 1; + } + + int thread_nums = head_num * head_dim / vec_size; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); dim3 grid(num_tokens); - dim3 block(std::min(num_heads * head_size, 512)); - DISPATCH_FLOAT_HALF_AND_BFLOAT( - key.scalar_type(), - "decode_kv_cache_memcpy", - decode_kv_cache_memcpy_kernel<<>>( - key.data_ptr(), - value.data_ptr(), - key_cache.data_ptr(), - value_cache.data_ptr(), - sequence_lengths.data_ptr(), - block_tables.data_ptr(), - num_heads, - head_size, - block_size, - key_stride, - value_stride, - block_table_stride - );) + dim3 block(std::min(thread_nums, 512)); + + switch (vec_size) { + case 1: + decode_kv_cache_memcpy_kernel<<>>( + key.data_ptr(), + value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + sequence_lengths.data_ptr(), + block_tables.data_ptr(), + head_num, + head_dim, + block_size, + key_stride, + value_stride, + block_table_stride + ); + break; + case 2: + decode_kv_cache_memcpy_kernel<<>>( + key.data_ptr(), + value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + sequence_lengths.data_ptr(), + block_tables.data_ptr(), + head_num, + head_dim, + block_size, + key_stride, + value_stride, + block_table_stride + ); + break; + case 4: + decode_kv_cache_memcpy_kernel<<>>( + key.data_ptr(), + value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + sequence_lengths.data_ptr(), + block_tables.data_ptr(), + head_num, + head_dim, + block_size, + key_stride, + value_stride, + block_table_stride + ); + break; + default: + AT_ERROR("Unsupported vectorized size ", vec_size); + break; + } AT_CUDA_CHECK(cudaGetLastError()); } + +void decode_kv_cache_memcpy( + at::Tensor& key, // [num_tokens, head_num, head_dim] + at::Tensor& value, // [num_tokens, head_num, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& block_tables) // [batch_size, max_seq_len] +{ + DISPATCH_FLOAT_HALF_AND_BFLOAT( + key.scalar_type(), + "decode_kv_cache_memcpy", + apply_decode_kv_cache_memcpy( + key, + value, + 
key_cache, + value_cache, + sequence_lengths, + block_tables + );) +} diff --git a/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu b/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu new file mode 100644 index 000000000..c1db06d3f --- /dev/null +++ b/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu @@ -0,0 +1,472 @@ + +#include +#include + +#include "../common/vector_copy_utils.h" +#include "../common/micros.h" + +template +__device__ void apply_emb_rotary_compute( + scalar_t* __restrict__ src, const scalar_t* __restrict__ cos_ptr, + const scalar_t* __restrict__ sin_ptr, const int64_t stride, + const int token_id, const int shard_block_size, const int half_head_dim, + const int head_num, const int head_dim) { + scalar_t x[VecSize]; + scalar_t y[VecSize]; + scalar_t out_x[VecSize]; + scalar_t out_y[VecSize]; + + for (int i = threadIdx.x * VecSize; i < head_num * half_head_dim; + i += blockDim.x * VecSize) { + const int head_offset = i % half_head_dim; + const int shard_offset = + (head_offset / shard_block_size) * shard_block_size + + (head_offset % shard_block_size) / VecSize; + const int64_t addr_offset = + token_id * stride + (i / half_head_dim) * head_dim + head_offset; + + copy_vector(x, src + addr_offset); + copy_vector(y, src + addr_offset + half_head_dim); + +#pragma unroll + for (int j = 0; j < VecSize; j++) { + out_x[j] = x[j] * cos_ptr[j * 32 + shard_offset] - + y[j] * sin_ptr[j * 32 + shard_offset]; + out_y[j] = y[j] * cos_ptr[j * 32 + shard_offset] + + x[j] * sin_ptr[j * 32 + shard_offset]; + } + + copy_vector(src + addr_offset, out_x); + copy_vector(src + addr_offset + half_head_dim, out_y); + } +} + +template +__device__ void apply_kv_memcopy( + scalar_t* __restrict__ src, scalar_t* __restrict__ cache, + const int64_t stride, const int token_id, const int block_id, + const int hidden_size, const int block_size, const int block_offset, + const int head_dim, const int half_head_dim) { + for (int i = threadIdx.x * VecSize; i < hidden_size / 2; + i += blockDim.x * VecSize) { + const int head_id = i / half_head_dim; + const int head_offset = i % half_head_dim; + const int64_t src_id = token_id * stride + head_id * head_dim + head_offset; + const int64_t target_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + block_offset * head_dim + head_offset; + + copy_vector(cache + target_id, src + src_id); + copy_vector(cache + target_id + half_head_dim, + src + src_id + half_head_dim); + } +} + +template +__device__ void cos_sin_memory_access( + const scalar_t* __restrict__ cos, const scalar_t* __restrict__ sin, + scalar_t* cos_ptr, scalar_t* sin_ptr, const int token_id, + const int shard_block_size, const int cos_stride, const int sin_stride, + const int half_head_dim) { + for (int i = threadIdx.x; i < half_head_dim; i += blockDim.x) { + // We assume that the value of head_dim is less than 128*128. 
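+    // Layout note: cos/sin values are swizzled in shared memory so that the
+    // j-th element of each VecSize-wide vector lands 32 slots apart; the
+    // rotary compute helpers read them back as cos_ptr[j * 32 + shard_offset],
+    // so this write pattern and those reads must stay consistent.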
+ const int shard_offset = (i % shard_block_size) / VecSize; + const int shard_head = + (i / shard_block_size) * shard_block_size + i % VecSize * 32; + cos_ptr[shard_head + shard_offset] = cos[token_id * cos_stride + i]; + sin_ptr[shard_head + shard_offset] = sin[token_id * sin_stride + i]; + } +} + +template +__device__ void apply_k_rotary_emb_compute( + scalar_t* __restrict__ key, scalar_t* __restrict__ value, + scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache, + const scalar_t* __restrict__ cos_ptr, const scalar_t* __restrict__ sin_ptr, + const int* __restrict__ sequence_lengths, + const int* __restrict__ block_tables, const int64_t key_stride, + const int64_t value_stride, const int token_id, + const int block_table_stride, const int head_num, const int head_dim, + const int kv_head_num, const int block_size, const int half_head_dim, + const int shard_block_size) { + const int seq_len = sequence_lengths[token_id] - 1; + const int block_offset = seq_len % block_size; + const int block_id = + block_tables[token_id * block_table_stride + seq_len / block_size]; + + if (block_id < 0) { + return; + } + + scalar_t x[VecSize]; + scalar_t y[VecSize]; + scalar_t out_x[VecSize]; + scalar_t out_y[VecSize]; + + for (int i = threadIdx.x * VecSize; i < kv_head_num * half_head_dim; + i += blockDim.x * VecSize) { + const int head_offset = i % half_head_dim; + const int shard_offset = + (head_offset / shard_block_size) * shard_block_size + + (head_offset % shard_block_size) / VecSize; + const int64_t addr_offset = + token_id * key_stride + (i / half_head_dim) * head_dim + head_offset; + const int64_t target_id = block_id * head_num * head_dim * block_size + + (i / half_head_dim) * block_size * head_dim + + block_offset * head_dim + head_offset; + + copy_vector(x, key + addr_offset); + copy_vector(y, key + addr_offset + half_head_dim); + +#pragma unroll + for (int j = 0; j < VecSize; j++) { + out_x[j] = x[j] * cos_ptr[j * 32 + shard_offset] - + y[j] * sin_ptr[j * 32 + shard_offset]; + out_y[j] = y[j] * cos_ptr[j * 32 + shard_offset] + + x[j] * sin_ptr[j * 32 + shard_offset]; + } + + copy_vector(key_cache + target_id, out_x); + copy_vector(key_cache + target_id + half_head_dim, + out_y); + } + + // apply value memcopy + apply_kv_memcopy( + value, value_cache, value_stride, token_id, block_id, head_num * head_dim, + block_size, block_offset, head_dim, half_head_dim); +} + +template +__global__ void rotary_embedding_and_cache_copy_kernel( + scalar_t* __restrict__ query, + scalar_t* __restrict__ key, + scalar_t* __restrict__ value, + const scalar_t* __restrict__ cos, + const scalar_t* __restrict__ sin, + scalar_t* __restrict__ key_cache, + scalar_t* __restrict__ value_cache, + const int* __restrict__ sequence_lengths, + const int* __restrict__ block_tables, + const int64_t query_stride, + const int64_t key_stride, + const int64_t value_stride, + const int64_t half_shard_element_num, + const int cos_stride, + const int sin_stride, + const int block_table_stride, + const int head_num, + const int head_dim, + const int kv_head_num, + const int block_size +) { + + const int token_id = blockIdx.x; + const int half_head_dim = head_dim / 2; + const int shard_block_size = VecSize * 32; + + extern __shared__ char shard_ptr[]; + + scalar_t *cos_ptr = (scalar_t*)shard_ptr; + scalar_t *sin_ptr = cos_ptr + half_shard_element_num; + + // apply cos_sin memcopy + cos_sin_memory_access(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim); + __syncthreads(); + + 
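+  // cos/sin for this token are now staged in shared memory. What follows:
+  // (1) rotate the query heads in place, (2) rotate the key heads and write
+  // them into the paged key_cache, (3) copy the value heads into value_cache.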
//compute query + apply_emb_rotary_compute(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim); + + //compute key and copy kv + apply_k_rotary_emb_compute(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, half_head_dim, shard_block_size); +} + +template +__global__ void rotary_embedding_kernel( + scalar_t* __restrict__ query, + scalar_t* __restrict__ key, + const scalar_t* __restrict__ cos, + const scalar_t* __restrict__ sin, + const int64_t query_stride, + const int64_t key_stride, + const int64_t half_shard_element_num, + const int cos_stride, + const int sin_stride, + const int head_num, + const int head_dim, + const int kv_head_num +) { + const int token_id = blockIdx.x; + const int half_head_dim = head_dim / 2; + const int shard_block_size = VecSize * 32; + + extern __shared__ char shard_ptr[]; + + scalar_t *cos_ptr = (scalar_t*)shard_ptr; + scalar_t *sin_ptr = cos_ptr + half_shard_element_num; + + // apply cos_sin memcopy + cos_sin_memory_access(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim); + __syncthreads(); + + //compute query + apply_emb_rotary_compute(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim); + + //compute key + apply_emb_rotary_compute(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim); +} + +template +void apply_rotary_embedding_and_cache_copy( + at::Tensor& query, // [num_tokens, head_num, head_dim] + at::Tensor& key, // [num_tokens, kv_head_num, head_dim] + at::Tensor& value, // [num_tokens, kv_head_num, head_dim] + at::Tensor& cos, // [num_tokens, head_dim] + at::Tensor& sin, // [num_tokens, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& block_tables) // [batch_size, max_seq_len] +{ + int num_tokens = query.size(0); + int head_num = query.size(1); + int head_dim = query.size(2); + int kv_head_num = key.size(1); + int block_size = key_cache.size(2); + + int64_t query_stride = query.stride(0); + int64_t key_stride = key.stride(0); + int64_t value_stride = value.stride(0); + int cos_stride = cos.stride(0); + int sin_stride = sin.stride(0); + int block_table_stride = block_tables.stride(0); + + int vec_size = get_vec_size(query); + + if ((head_dim / 2) % vec_size != 0) { + // Disable vectorized loading optimization when head_dim is not divisible by VecSize. 
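+    // (More precisely, the check is on head_dim / 2: each half of a head is
+    // processed in VecSize-wide chunks, so vec_size must divide half_head_dim.)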
+ vec_size = 1; + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int thread_nums = head_num * head_dim / vec_size / 2; + const int shard_block_size = vec_size * 32 * 2; + + dim3 grid(num_tokens); + dim3 block(std::min(thread_nums, 512)); + int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size ; + + switch (vec_size) { + case 1: + rotary_embedding_and_cache_copy_kernel<<>>( + query.data_ptr(), + key.data_ptr(), + value.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + sequence_lengths.data_ptr(), + block_tables.data_ptr(), + query_stride, + key_stride, + value_stride, + shard_element_num / 2, + cos_stride, + sin_stride, + block_table_stride, + head_num, + head_dim, + kv_head_num, + block_size + ); + break; + case 2: + rotary_embedding_and_cache_copy_kernel<<>>( + query.data_ptr(), + key.data_ptr(), + value.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + sequence_lengths.data_ptr(), + block_tables.data_ptr(), + query_stride, + key_stride, + value_stride, + shard_element_num / 2, + cos_stride, + sin_stride, + block_table_stride, + head_num, + head_dim, + kv_head_num, + block_size + ); + break; + case 4: + rotary_embedding_and_cache_copy_kernel<<>>( + query.data_ptr(), + key.data_ptr(), + value.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + sequence_lengths.data_ptr(), + block_tables.data_ptr(), + query_stride, + key_stride, + value_stride, + shard_element_num / 2, + cos_stride, + sin_stride, + block_table_stride, + head_num, + head_dim, + kv_head_num, + block_size + ); + break; + default: + AT_ERROR("Unsupported vectorized size ", vec_size); + break; + } + + AT_CUDA_CHECK(cudaGetLastError()); +} + +template +void apply_rotary_embedding( + at::Tensor& query, // [total_tokens, head_num, head_dim] + at::Tensor& key, // [total_tokens, kv_head_num, head_dim] + at::Tensor& cos, // [total_tokens, head_dim] + at::Tensor& sin // [total_tokens, head_dim] +){ + int num_tokens = query.size(0); + int head_num = query.size(1); + int head_dim = query.size(2); + int kv_head_num = key.size(1); + + int query_stride = query.stride(0); + int key_stride = key.stride(0); + int cos_stride = cos.stride(0); + int sin_stride = sin.stride(0); + + int vec_size = get_vec_size(query); + + if ((head_dim / 2) % vec_size != 0) { + // Disable vectorized loading optimization when head_dim is not divisible by VecSize. 
+ vec_size = 1; + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int thread_nums = head_num * head_dim / vec_size / 2; + const int shard_block_size = vec_size * 32 * 2; + + dim3 grid(num_tokens); + dim3 block(std::min(thread_nums, 512)); + int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size ; + + switch (vec_size) { + case 1: + rotary_embedding_kernel<<>>( + query.data_ptr(), + key.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), + query_stride, + key_stride, + shard_element_num / 2, + cos_stride, + sin_stride, + head_num, + head_dim, + kv_head_num + ); + break; + case 2: + rotary_embedding_kernel<<>>( + query.data_ptr(), + key.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), + query_stride, + key_stride, + shard_element_num / 2, + cos_stride, + sin_stride, + head_num, + head_dim, + kv_head_num + ); + break; + case 4: + rotary_embedding_kernel<<>>( + query.data_ptr(), + key.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), + query_stride, + key_stride, + shard_element_num / 2, + cos_stride, + sin_stride, + head_num, + head_dim, + kv_head_num + ); + break; + default: + AT_ERROR("Unsupported vectorized size ", vec_size); + break; + } + AT_CUDA_CHECK(cudaGetLastError()); +} + +void rotary_embedding_and_cache_copy( + at::Tensor& query, // [num_tokens, head_num, head_dim] + at::Tensor& key, // [num_tokens, kv_head_num, head_dim] + at::Tensor& value, // [num_tokens, kv_head_num, head_dim] + at::Tensor& cos, // [num_tokens, head_dim] + at::Tensor& sin, // [num_tokens, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& block_tables) // [batch_size, max_seq_len] +{ + DISPATCH_FLOAT_HALF_AND_BFLOAT( + query.scalar_type(), + "rotary_embedding_and_cache_copy", + apply_rotary_embedding_and_cache_copy( + query, + key, + value, + cos, + sin, + key_cache, + value_cache, + sequence_lengths, + block_tables + );) +} + +void rotary_embedding( + at::Tensor& query, // [total_tokens, head_num, head_dim] + at::Tensor& key, // [total_tokens, kv_head_num, head_dim] + at::Tensor& cos, // [total_tokens, head_dim] + at::Tensor& sin // [total_tokens, head_dim] +){ + DISPATCH_FLOAT_HALF_AND_BFLOAT( + query.scalar_type(), + "rotary_embedding", + apply_rotary_embedding( + query, + key, + cos, + sin + );) +} diff --git a/extensions/csrc/cuda/pybind/inference.cpp b/extensions/csrc/cuda/pybind/inference.cpp index 73ed49e6c..4282f5382 100644 --- a/extensions/csrc/cuda/pybind/inference.cpp +++ b/extensions/csrc/cuda/pybind/inference.cpp @@ -9,6 +9,23 @@ void decode_kv_cache_memcpy( torch::Tensor& sequence_lengths, // [batch_size] torch::Tensor& block_tables); // [batch_size, max_seq_len] +void rotary_embedding( + torch::Tensor& query, // [total_tokens, head_num, head_dim] + torch::Tensor& key, // [total_tokens, kv_head_num, head_dim] + torch::Tensor& cos, // [total_tokens, head_dim] + torch::Tensor& sin); // [total_tokens, head_dim] + +void rotary_embedding_and_cache_copy( + torch::Tensor& query, // [num_tokens, head_num, head_dim] + torch::Tensor& key, // [num_tokens, kv_head_num, head_dim] + torch::Tensor& value, // [num_tokens, num_heads, head_dim] + torch::Tensor& cos, // [num_tokens, head_dim] + torch::Tensor& sin, // [num_tokens, head_dim] + torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_dim] + torch::Tensor& + value_cache, // [num_blocks, num_heads, block_size, 
head_dim] + torch::Tensor& sequence_lengths, // [batch_size] + torch::Tensor& block_tables); // [batch_size, max_seq_len] torch::Tensor silu_and_mul(const torch::Tensor& ins); void rms_layernorm(torch::Tensor& out, // [..., hidden_size] @@ -25,6 +42,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy, "Copy the GPU memory of kvcache during the decode stage."); + m.def( + "rotary_embedding_and_cache_copy", &rotary_embedding_and_cache_copy, + "performing Rotary Embedding-related calculations and KVCache Memcopy."); + + m.def("rotary_embedding", &rotary_embedding, + "performing Rotary Embedding-related calculations."); + m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply"); m.def("rms_layernorm", &rms_layernorm, diff --git a/extensions/inference/inference_ops_cuda.py b/extensions/inference/inference_ops_cuda.py index f465fe600..ae3754ca7 100644 --- a/extensions/inference/inference_ops_cuda.py +++ b/extensions/inference/inference_ops_cuda.py @@ -12,6 +12,7 @@ class InferenceOpsCudaExtension(_CudaExtension): for fname in [ "cuda/pybind/inference.cpp", "cuda/decode_kv_cache_memcpy_kernel.cu", + "cuda/fused_rotary_emb_and_cache_kernel.cu", "cuda/activation_kernel.cu", "cuda/rms_layernorm_kernel.cu", ] diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py index edd92bb96..25b2c2f43 100644 --- a/tests/test_infer/test_inference_engine.py +++ b/tests/test_infer/test_inference_engine.py @@ -22,15 +22,11 @@ def setup_seed(seed): def check_inference_engine(use_engine=False, prompt_template=None): setup_seed(20) tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") - model = ( - LlamaForCausalLM( - LlamaConfig( - vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16 - ) + model = LlamaForCausalLM( + LlamaConfig( + vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16 ) - .cuda() - .half() - ) + ).cuda() model = model.eval() inputs = [ @@ -44,7 +40,7 @@ def check_inference_engine(use_engine=False, prompt_template=None): top_k = 50 if use_engine: - inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template) + inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32") inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) assert inference_engine.generation_config.max_new_tokens == output_len inference_engine.add_request(prompts=inputs) diff --git a/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py b/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py new file mode 100644 index 000000000..b9c0a3269 --- /dev/null +++ b/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py @@ -0,0 +1,91 @@ +import pytest +import torch +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb + +from colossalai.kernel.kernel_loader import InferenceOpsLoader + +inference_ops = InferenceOpsLoader().load() + +from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2 +from tests.test_infer.test_ops.triton.test_rotary_embdding_unpad import torch_rotary_emb + + +@pytest.mark.parametrize("BATCH_SIZE", [4]) +@pytest.mark.parametrize("SEQ_LEN", [64]) +@pytest.mark.parametrize("H", [32]) +@pytest.mark.parametrize("D", [64]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def 
test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype): + torch.manual_seed(10) + TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN + # our crafted op equals to Transformers + x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype) + x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype) + emb = LlamaRotaryEmbedding(D) + cos, sin = emb(x0, TOTAL_TOKENS) + cos_2 = cos[:, : D // 2] + sin_2 = sin[:, : D // 2] + position_ids = torch.arange(TOTAL_TOKENS) + embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids) + embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2) + assert torch.allclose(embd_x0, embd_stimulated_x) + + # create data + block_size = 32 + max_blocks_per_sequence = (TOTAL_TOKENS + block_size - 1) // block_size + q_shape = (TOTAL_TOKENS, H, D) + q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda") + k_shape = (TOTAL_TOKENS, H, D) + k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda") + cos_shape = (TOTAL_TOKENS, D // 2) + cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + cache_shape = (BATCH_SIZE * max_blocks_per_sequence, H, block_size, D) + k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda") + v = torch.randn_like(k) + v_cache = torch.zeros_like(k_cache) + past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda") + block_tables = mock_alloc_block_table_and_kvcache_v2( + k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_blocks_per_sequence, block_size + ) + new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda") + new_q = torch.randn_like(new_k) + new_v = torch.randn_like(new_k) + + kv_seq_lengths = past_kv_seq_lengths + 1 + block_tables = block_tables.to(device="cuda") + q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE]) + k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE]) + + new_q_copy = new_q.clone() + new_k_copy = new_k.clone() + + inference_ops.rotary_embedding_and_cache_copy( + new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables + ) + + inference_ops.rotary_embedding(new_q_copy, new_k_copy, cos, sin) + + past_kv_seq_len = kv_seq_lengths - 1 + target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size] + offsets_in_block = past_kv_seq_len % block_size + k_target = k_cache[target_block_ids, :, offsets_in_block, :].squeeze() + k_source = new_k_copy.squeeze() + v_target = v_cache[target_block_ids, :, offsets_in_block, :].squeeze() + v_source = new_v.squeeze() + + assert torch.allclose(new_q, q_ref, atol=1e-6, rtol=1e-6) + assert torch.allclose(k_target, k_ref, atol=1e-6, rtol=1e-6) + + assert torch.allclose(new_q_copy, q_ref, atol=1e-6, rtol=1e-6) + assert torch.allclose(new_k_copy, k_ref, atol=1e-6, rtol=1e-6) + + assert k_target.shape == k_source.shape + assert torch.allclose(k_target, k_source, atol=1e-6, rtol=1e-6) + + assert v_target.shape == v_source.shape + assert torch.equal(v_target, v_source) + + +if __name__ == "__main__": + test_rotary_emb(16, 512, 4, 128, torch.float16)
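For reference, a minimal decode-step sketch of the two call paths exercised by the benchmark and test above, assuming the extension has been built and a CUDA device is available. All sizes and tensor contents here are illustrative; shapes follow the test (q/k/v: [num_tokens, head_num, head_dim], cos/sin: [num_tokens, head_dim // 2], caches: [num_blocks, head_num, block_size, head_dim]).

import torch

from colossalai.kernel.kernel_loader import InferenceOpsLoader

inference_ops = InferenceOpsLoader().load()

num_tokens, head_num, head_dim, block_size, max_blocks_per_seq = 4, 32, 64, 32, 8
dtype, device = torch.float16, "cuda"

q = torch.randn(num_tokens, head_num, head_dim, dtype=dtype, device=device)
k = torch.randn_like(q)
v = torch.randn_like(q)
cos = torch.randn(num_tokens, head_dim // 2, dtype=dtype, device=device)
sin = torch.randn_like(cos)
k_cache = torch.zeros(
    num_tokens * max_blocks_per_seq, head_num, block_size, head_dim, dtype=dtype, device=device
)
v_cache = torch.zeros_like(k_cache)
kv_seq_lengths = torch.full((num_tokens,), block_size + 1, dtype=torch.int32, device=device)
block_tables = torch.arange(
    num_tokens * max_blocks_per_seq, dtype=torch.int32, device=device
).view(num_tokens, max_blocks_per_seq)

# Keep untouched copies so the unfused path starts from the same inputs.
q_ref, k_ref, v_ref = q.clone(), k.clone(), v.clone()
k_cache_ref, v_cache_ref = torch.zeros_like(k_cache), torch.zeros_like(v_cache)

# Fused path: one kernel rotates q in place and writes the rotated k plus the
# raw v directly into the paged caches.
inference_ops.rotary_embedding_and_cache_copy(
    q, k, v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables
)

# Unfused path: rotate q/k in place, then copy k/v into the caches.
inference_ops.rotary_embedding(q_ref, k_ref, cos, sin)
inference_ops.decode_kv_cache_memcpy(
    k_ref, v_ref, k_cache_ref, v_cache_ref, kv_seq_lengths, block_tables
)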