mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-25 06:52:46 +00:00
[Inference/kernel]Add Fused Rotary Embedding and KVCache Memcopy CUDA Kernel (#5418)
* add rotary embedding kernel * add rotary_embedding_kernel * add fused rotary_emb and kvcache memcopy * add fused_rotary_emb_and_cache_kernel.cu * add fused_rotary_emb_and_memcopy * fix bugs in fused_rotary_emb_and_cache_kernel.cu * fix ci bugs * use vec memcopy and opt the gloabl memory access * fix code style * fix test_rotary_embdding_unpad.py * codes revised based on the review comments * fix bugs about include path * rm inline
This commit is contained in:
parent
ed431de4e4
commit
f366a5ea1f
@ -320,7 +320,11 @@ class NopadLlamaAttention(LlamaAttention):
|
|||||||
)
|
)
|
||||||
|
|
||||||
block_size = k_cache.size(-2)
|
block_size = k_cache.size(-2)
|
||||||
|
|
||||||
if is_prompts:
|
if is_prompts:
|
||||||
|
if use_cuda_kernel:
|
||||||
|
inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
||||||
|
else:
|
||||||
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
||||||
attn_output = context_attention_unpadded(
|
attn_output = context_attention_unpadded(
|
||||||
q=query_states,
|
q=query_states,
|
||||||
@ -337,9 +341,16 @@ class NopadLlamaAttention(LlamaAttention):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if use_cuda_kernel:
|
if use_cuda_kernel:
|
||||||
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
inference_ops.rotary_embedding_and_cache_copy(
|
||||||
inference_ops.decode_kv_cache_memcpy(
|
query_states,
|
||||||
key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
|
key_states,
|
||||||
|
value_states,
|
||||||
|
cos_sin[0],
|
||||||
|
cos_sin[1],
|
||||||
|
k_cache,
|
||||||
|
v_cache,
|
||||||
|
sequence_lengths,
|
||||||
|
block_tables,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
decoding_fused_rotary_embedding(
|
decoding_fused_rotary_embedding(
|
||||||
|
@ -47,5 +47,5 @@ def init_to_get_rotary(self, base=10000, use_elem=False):
|
|||||||
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
|
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
|
||||||
freqs = torch.outer(t, inv_freq)
|
freqs = torch.outer(t, inv_freq)
|
||||||
|
|
||||||
self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
|
self._cos_cached = torch.cos(freqs).to(self.dtype).cuda()
|
||||||
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
|
self._sin_cached = torch.sin(freqs).to(self.dtype).cuda()
|
||||||
|
@ -1,8 +1,11 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||||
from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
|
from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
|
||||||
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2, mock_alloc_single_token
|
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2, mock_alloc_single_token
|
||||||
|
|
||||||
|
inference_ops = InferenceOpsLoader().load()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import triton # noqa
|
import triton # noqa
|
||||||
|
|
||||||
@ -16,9 +19,19 @@ configs = [
|
|||||||
x_names=["num_tokens"],
|
x_names=["num_tokens"],
|
||||||
x_vals=[2**i for i in range(4, 11)],
|
x_vals=[2**i for i in range(4, 11)],
|
||||||
line_arg="provider",
|
line_arg="provider",
|
||||||
line_vals=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
|
line_vals=[
|
||||||
line_names=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
|
"no_fused_triton_rotary_emb_func",
|
||||||
styles=[("red", "-"), ("blue", "-")],
|
"fused_triton_rotary_emb_func",
|
||||||
|
"no_fused_cuda_rotary_emb_func",
|
||||||
|
"fused_cuda_rotary_emb_func",
|
||||||
|
],
|
||||||
|
line_names=[
|
||||||
|
"no_fused_triton_rotary_emb_func",
|
||||||
|
"fused_triton_rotary_emb_func",
|
||||||
|
"no_fused_cuda_rotary_emb_func",
|
||||||
|
"fused_cuda_rotary_emb_func",
|
||||||
|
],
|
||||||
|
styles=[("red", "-"), ("blue", "-"), ("green", "-"), ("yellow", "-")],
|
||||||
ylabel="ms",
|
ylabel="ms",
|
||||||
plot_name=f"rotary_emb-batch-{BATCH}",
|
plot_name=f"rotary_emb-batch-{BATCH}",
|
||||||
args={"num_kv_heads": 16},
|
args={"num_kv_heads": 16},
|
||||||
@ -32,7 +45,7 @@ def benchmark_rotary_emb(
|
|||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
):
|
):
|
||||||
BATCH_SIZE = 4
|
BATCH_SIZE = 16
|
||||||
SEQ_LEN = num_tokens // BATCH_SIZE
|
SEQ_LEN = num_tokens // BATCH_SIZE
|
||||||
max_num_blocks_per_seq = 8
|
max_num_blocks_per_seq = 8
|
||||||
block_size = 64
|
block_size = 64
|
||||||
@ -68,7 +81,7 @@ def benchmark_rotary_emb(
|
|||||||
kv_seq_lengths = past_kv_seq_lengths + 1
|
kv_seq_lengths = past_kv_seq_lengths + 1
|
||||||
block_tables = block_tables.to(device="cuda")
|
block_tables = block_tables.to(device="cuda")
|
||||||
|
|
||||||
if provider == "no_fused_rotary_emb_func":
|
if provider == "no_fused_triton_rotary_emb_func":
|
||||||
fn = lambda: [
|
fn = lambda: [
|
||||||
rotary_embedding(new_q, new_k, cos, sin),
|
rotary_embedding(new_q, new_k, cos, sin),
|
||||||
copy_kv_to_blocked_cache(
|
copy_kv_to_blocked_cache(
|
||||||
@ -77,7 +90,16 @@ def benchmark_rotary_emb(
|
|||||||
]
|
]
|
||||||
elif provider == "fused_triton_rotary_emb_func":
|
elif provider == "fused_triton_rotary_emb_func":
|
||||||
fn = lambda: decoding_fused_rotary_embedding(
|
fn = lambda: decoding_fused_rotary_embedding(
|
||||||
new_q, new_k, new_k, cos, sin, k_cache, k_cache, block_tables, kv_seq_lengths
|
new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths
|
||||||
|
)
|
||||||
|
elif provider == "no_fused_cuda_rotary_emb_func":
|
||||||
|
fn = lambda: [
|
||||||
|
inference_ops.rotary_embedding(new_q, new_k, cos, sin),
|
||||||
|
inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables),
|
||||||
|
]
|
||||||
|
elif provider == "fused_cuda_rotary_emb_func":
|
||||||
|
fn = lambda: inference_ops.rotary_embedding_and_cache_copy(
|
||||||
|
new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Undefined provider")
|
raise ValueError("Undefined provider")
|
@ -1,7 +1,11 @@
|
|||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
|
from vllm._C import ops
|
||||||
|
|
||||||
from colossalai.kernel.triton.fused_rotary_embedding import fused_rotary_embedding
|
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||||
|
from colossalai.kernel.triton import rotary_embedding
|
||||||
|
|
||||||
|
inference_ops = InferenceOpsLoader().load()
|
||||||
|
|
||||||
BATCH = 16
|
BATCH = 16
|
||||||
configs = [
|
configs = [
|
||||||
@ -9,9 +13,9 @@ configs = [
|
|||||||
x_names=["num_tokens"],
|
x_names=["num_tokens"],
|
||||||
x_vals=[2**i for i in range(4, 12)],
|
x_vals=[2**i for i in range(4, 12)],
|
||||||
line_arg="provider",
|
line_arg="provider",
|
||||||
line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"],
|
line_vals=["triton_func", "colossal_cuda_func", "vllm_cuda_func"],
|
||||||
line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"],
|
line_names=["triton_func", "colossal_cuda_func", "vllm_cuda_func"],
|
||||||
styles=[("red", "-"), ("blue", "-")],
|
styles=[("red", "-"), ("blue", "-"), ("yellow", "-")],
|
||||||
ylabel="ms",
|
ylabel="ms",
|
||||||
plot_name=f"rotary_emb-batch-{BATCH}",
|
plot_name=f"rotary_emb-batch-{BATCH}",
|
||||||
args={"num_kv_heads": 16},
|
args={"num_kv_heads": 16},
|
||||||
@ -48,12 +52,19 @@ def benchmark_rotary_emb(
|
|||||||
cos_shape = (4096, head_dim // 2)
|
cos_shape = (4096, head_dim // 2)
|
||||||
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||||
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||||
lengths = torch.tensor([3, 4, 6, 7], device="cuda")
|
|
||||||
|
|
||||||
if provider == "torch_rotary_emb_func":
|
cos_sin = torch.stack((cos, sin), dim=1).contiguous()
|
||||||
fn = lambda: torch_rotary_emb(q, cos[:num_tokens], sin[:num_tokens])
|
|
||||||
elif provider == "triton_rotary_emb_func":
|
positions = torch.arange(num_tokens).cuda()
|
||||||
fn = lambda: fused_rotary_embedding(q, k, cos, sin, lengths)
|
|
||||||
|
if provider == "triton_func":
|
||||||
|
fn = lambda: rotary_embedding(q, k, cos, sin)
|
||||||
|
elif provider == "colossal_cuda_func":
|
||||||
|
fn = lambda: inference_ops.rotary_embedding(q, k, cos, sin)
|
||||||
|
elif provider == "vllm_cuda_func":
|
||||||
|
q = q.view(num_tokens, -1)
|
||||||
|
k = k.view(num_tokens, -1)
|
||||||
|
fn = lambda: ops.rotary_embedding(positions, q, k, head_dim, cos_sin, True)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Undefined provider")
|
raise ValueError("Undefined provider")
|
||||||
|
|
54
examples/inference/benchmark_ops/benchmark_xine_copy.py
Normal file
54
examples/inference/benchmark_ops/benchmark_xine_copy.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from colossalai.kernel.triton import get_xine_cache
|
||||||
|
from tests.test_infer.test_ops.triton.test_xine_copy import get_cos_sin
|
||||||
|
|
||||||
|
try:
|
||||||
|
import triton # noqa
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
print("please install triton from https://github.com/openai/triton")
|
||||||
|
|
||||||
|
|
||||||
|
configs = [
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["max_num_tokens"],
|
||||||
|
x_vals=[2**i for i in range(6, 12)],
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["torch_get_cos_sin", "triton_get_cos_sin"],
|
||||||
|
line_names=["torch_get_cos_sin", "triton_get_cos_sin"],
|
||||||
|
styles=[("red", "-"), ("blue", "-")],
|
||||||
|
ylabel="ms",
|
||||||
|
plot_name="Get_cos-sin_func",
|
||||||
|
args={"batch_size": 16, "head_dim": 256},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(configs)
|
||||||
|
def benchmark_get_xine_cache(
|
||||||
|
provider: str,
|
||||||
|
max_num_tokens: int,
|
||||||
|
batch_size: int,
|
||||||
|
head_dim: int,
|
||||||
|
):
|
||||||
|
warmup = 10
|
||||||
|
rep = 1000
|
||||||
|
dtype = torch.float16
|
||||||
|
cos_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda")
|
||||||
|
sin_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda")
|
||||||
|
lengths = torch.randint(2, max_num_tokens, (batch_size,), device="cuda")
|
||||||
|
|
||||||
|
if provider == "torch_get_cos_sin":
|
||||||
|
fn = lambda: get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype)
|
||||||
|
elif provider == "triton_get_cos_sin":
|
||||||
|
fn = lambda: get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=True)
|
||||||
|
else:
|
||||||
|
raise ValueError("Undefined provider")
|
||||||
|
|
||||||
|
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
||||||
|
return ms
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
benchmark_get_xine_cache.run(save_path=".", print_data=True)
|
98
extensions/csrc/common/vector_copy_utils.h
Normal file
98
extensions/csrc/common/vector_copy_utils.h
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
|
||||||
|
#include <c10/macros/Macros.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
|
||||||
|
#include <cfloat>
|
||||||
|
|
||||||
|
#include "string"
|
||||||
|
|
||||||
|
template <typename Datatype, int ELEMENTS_PER_LDG>
|
||||||
|
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(
|
||||||
|
c10::BFloat16 *dst, const c10::BFloat16 *src) {
|
||||||
|
*dst = *src;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ void copy_vector<c10::BFloat16, 2>(
|
||||||
|
c10::BFloat16 *dst, const c10::BFloat16 *src) {
|
||||||
|
*((float *)dst) = *((float *)src);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(
|
||||||
|
c10::BFloat16 *dst, const c10::BFloat16 *src) {
|
||||||
|
*((float2 *)dst) = *((float2 *)src);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ void copy_vector<c10::BFloat16, 8>(
|
||||||
|
c10::BFloat16 *dst, const c10::BFloat16 *src) {
|
||||||
|
*((float4 *)dst) = *((float4 *)src);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst,
|
||||||
|
const c10::Half *src) {
|
||||||
|
*dst = *src;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ void copy_vector<c10::Half, 2>(c10::Half *dst,
|
||||||
|
const c10::Half *src) {
|
||||||
|
*((float *)dst) = *((float *)src);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst,
|
||||||
|
const c10::Half *src) {
|
||||||
|
*((float2 *)dst) = *((float2 *)src);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ void copy_vector<c10::Half, 8>(c10::Half *dst,
|
||||||
|
const c10::Half *src) {
|
||||||
|
*((float4 *)dst) = *((float4 *)src);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) {
|
||||||
|
*dst = *src;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ void copy_vector<float, 2>(float *dst, const float *src) {
|
||||||
|
*((float2 *)dst) = *((float2 *)src);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ void copy_vector<float, 4>(float *dst, const float *src) {
|
||||||
|
*((float4 *)dst) = *((float4 *)src);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
|
||||||
|
// Since the maximum memory alignment length is 128 bits, we choose float4
|
||||||
|
// here.
|
||||||
|
*((float4 *)dst) = *((float4 *)src);
|
||||||
|
*((float4 *)(dst + 4)) = *((float4 *)(src + 4));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
int get_vec_size(const torch::Tensor &tensor) {
|
||||||
|
uint64_t address = reinterpret_cast<uint64_t>(tensor.data_ptr<T>());
|
||||||
|
const int max_aligned_size = 128;
|
||||||
|
const int dtype_size = sizeof(T) * 8;
|
||||||
|
|
||||||
|
const int vec_size = max_aligned_size / sizeof(T) / 8;
|
||||||
|
|
||||||
|
if (address % (dtype_size * 4) == 0) {
|
||||||
|
return std::min(4, vec_size);
|
||||||
|
} else if (address % (dtype_size * 2) == 0) {
|
||||||
|
return std::min(2, vec_size);
|
||||||
|
} else {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
@ -39,6 +39,9 @@ torch::Tensor silu_and_mul(const torch::Tensor& ins)
|
|||||||
auto ins_shape = ins.sizes().vec();
|
auto ins_shape = ins.sizes().vec();
|
||||||
|
|
||||||
ins_shape[0] = ins_shape[0]/2;
|
ins_shape[0] = ins_shape[0]/2;
|
||||||
|
if (ins_shape[0] == 1) {
|
||||||
|
ins_shape.erase(ins_shape.begin());
|
||||||
|
}
|
||||||
auto outs = torch::zeros(ins_shape,ins.options());
|
auto outs = torch::zeros(ins_shape,ins.options());
|
||||||
auto outs_shape = ins.sizes().vec();
|
auto outs_shape = ins.sizes().vec();
|
||||||
|
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <stdio.h>
|
|
||||||
|
|
||||||
|
#include "../common/vector_copy_utils.h"
|
||||||
#include "../common/micros.h"
|
#include "../common/micros.h"
|
||||||
|
|
||||||
template<typename scalar_t>
|
template<typename scalar_t, int VecSize>
|
||||||
__global__ void decode_kv_cache_memcpy_kernel(
|
__global__ void decode_kv_cache_memcpy_kernel(
|
||||||
const scalar_t* __restrict__ key,
|
const scalar_t* __restrict__ key,
|
||||||
const scalar_t* __restrict__ value,
|
const scalar_t* __restrict__ value,
|
||||||
@ -12,79 +12,146 @@ __global__ void decode_kv_cache_memcpy_kernel(
|
|||||||
scalar_t* __restrict__ value_cache,
|
scalar_t* __restrict__ value_cache,
|
||||||
const int* __restrict__ sequence_lengths,
|
const int* __restrict__ sequence_lengths,
|
||||||
const int* __restrict__ block_tables,
|
const int* __restrict__ block_tables,
|
||||||
const int num_heads,
|
const int head_num,
|
||||||
const int head_size,
|
const int head_dim,
|
||||||
const int block_size,
|
const int block_size,
|
||||||
const int key_stride,
|
const int64_t key_stride,
|
||||||
const int value_stride,
|
const int64_t value_stride,
|
||||||
const int block_table_stride
|
const int block_table_stride
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
const int seq_id = blockIdx.x;
|
const int seq_id = blockIdx.x;
|
||||||
const int seq_len = sequence_lengths[seq_id] - 1;
|
const int seq_len = sequence_lengths[seq_id] - 1;
|
||||||
const int seq_id_in_block_table = seq_len / block_size;
|
|
||||||
const int block_offset = seq_len % block_size;
|
const int block_offset = seq_len % block_size;
|
||||||
const int block_id = block_tables[seq_id * block_table_stride + seq_id_in_block_table];
|
const int block_id = block_tables[seq_id * block_table_stride + seq_len / block_size];
|
||||||
const int hidden_size = num_heads * head_size;
|
const int hidden_size = head_num * head_dim;
|
||||||
|
|
||||||
if ( block_id < 0 ) {
|
if ( block_id < 0 ) {
|
||||||
return ;
|
return ;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
for (int i = threadIdx.x * VecSize; i < hidden_size; i += blockDim.x * VecSize) {
|
||||||
const int head_id = i / head_size;
|
const int head_id = i / head_dim;
|
||||||
const int head_offset = i % head_size;
|
const int head_offset = i % head_dim;
|
||||||
const int key_src_id = seq_id * key_stride + i;
|
const int64_t key_src_id = seq_id * key_stride + i;
|
||||||
const int value_src_id = seq_id * value_stride + i;
|
const int64_t value_src_id = seq_id * value_stride + i;
|
||||||
const int target_src_id = block_id * hidden_size * block_size
|
const int64_t target_id = block_id * hidden_size * block_size
|
||||||
+ head_id * block_size * head_size
|
+ head_id * block_size * head_dim
|
||||||
+ block_offset * head_size + head_offset;
|
+ block_offset * head_dim + head_offset;
|
||||||
|
|
||||||
key_cache[target_src_id] = key[key_src_id];
|
copy_vector<scalar_t, VecSize>(key_cache + target_id, key + key_src_id);
|
||||||
value_cache[target_src_id] = value[value_src_id];
|
copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void decode_kv_cache_memcpy(
|
template<typename scalar_t>
|
||||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
void apply_decode_kv_cache_memcpy(
|
||||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
at::Tensor& key, // [num_tokens, head_num, head_dim]
|
||||||
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size]
|
at::Tensor& value, // [num_tokens, head_num, head_dim]
|
||||||
torch::Tensor& value_cache, // [num_blocks, num_heads, block_size, head_size]
|
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||||
torch::Tensor& sequence_lengths, // [batch_size]
|
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||||
torch::Tensor& block_tables) // [batch_size, max_seq_len]
|
at::Tensor& sequence_lengths, // [batch_size]
|
||||||
|
at::Tensor& block_tables) // [batch_size, max_seq_len]
|
||||||
{
|
{
|
||||||
int num_tokens = key.size(0);
|
int num_tokens = key.size(0);
|
||||||
int num_heads = key.size(1);
|
int head_num = key.size(1);
|
||||||
int head_size = key.size(2);
|
int head_dim = key.size(2);
|
||||||
int block_size = key_cache.size(2);
|
int block_size = key_cache.size(2);
|
||||||
|
|
||||||
int key_stride = key.stride(0);
|
int64_t key_stride = key.stride(0);
|
||||||
int value_stride = value.stride(0);
|
int64_t value_stride = value.stride(0);
|
||||||
int block_table_stride = block_tables.stride(0);
|
int block_table_stride = block_tables.stride(0);
|
||||||
|
|
||||||
|
int vec_size = get_vec_size<scalar_t>(key);
|
||||||
|
|
||||||
|
if (head_dim % vec_size != 0) {
|
||||||
|
// Disable vectorized loading optimization when head_dim is not divisible by VecSize.
|
||||||
|
vec_size = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int thread_nums = head_num * head_dim / vec_size;
|
||||||
|
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(num_heads * head_size, 512));
|
dim3 block(std::min(thread_nums, 512));
|
||||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
|
||||||
key.scalar_type(),
|
switch (vec_size) {
|
||||||
"decode_kv_cache_memcpy",
|
case 1:
|
||||||
decode_kv_cache_memcpy_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
decode_kv_cache_memcpy_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
|
||||||
key.data_ptr<scalar_t>(),
|
key.data_ptr<scalar_t>(),
|
||||||
value.data_ptr<scalar_t>(),
|
value.data_ptr<scalar_t>(),
|
||||||
key_cache.data_ptr<scalar_t>(),
|
key_cache.data_ptr<scalar_t>(),
|
||||||
value_cache.data_ptr<scalar_t>(),
|
value_cache.data_ptr<scalar_t>(),
|
||||||
sequence_lengths.data_ptr<int>(),
|
sequence_lengths.data_ptr<int>(),
|
||||||
block_tables.data_ptr<int>(),
|
block_tables.data_ptr<int>(),
|
||||||
num_heads,
|
head_num,
|
||||||
head_size,
|
head_dim,
|
||||||
block_size,
|
block_size,
|
||||||
key_stride,
|
key_stride,
|
||||||
value_stride,
|
value_stride,
|
||||||
block_table_stride
|
block_table_stride
|
||||||
);)
|
);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
decode_kv_cache_memcpy_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
value.data_ptr<scalar_t>(),
|
||||||
|
key_cache.data_ptr<scalar_t>(),
|
||||||
|
value_cache.data_ptr<scalar_t>(),
|
||||||
|
sequence_lengths.data_ptr<int>(),
|
||||||
|
block_tables.data_ptr<int>(),
|
||||||
|
head_num,
|
||||||
|
head_dim,
|
||||||
|
block_size,
|
||||||
|
key_stride,
|
||||||
|
value_stride,
|
||||||
|
block_table_stride
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
decode_kv_cache_memcpy_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
value.data_ptr<scalar_t>(),
|
||||||
|
key_cache.data_ptr<scalar_t>(),
|
||||||
|
value_cache.data_ptr<scalar_t>(),
|
||||||
|
sequence_lengths.data_ptr<int>(),
|
||||||
|
block_tables.data_ptr<int>(),
|
||||||
|
head_num,
|
||||||
|
head_dim,
|
||||||
|
block_size,
|
||||||
|
key_stride,
|
||||||
|
value_stride,
|
||||||
|
block_table_stride
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
AT_ERROR("Unsupported vectorized size ", vec_size);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
AT_CUDA_CHECK(cudaGetLastError());
|
AT_CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void decode_kv_cache_memcpy(
|
||||||
|
at::Tensor& key, // [num_tokens, head_num, head_dim]
|
||||||
|
at::Tensor& value, // [num_tokens, head_num, head_dim]
|
||||||
|
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||||
|
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||||
|
at::Tensor& sequence_lengths, // [batch_size]
|
||||||
|
at::Tensor& block_tables) // [batch_size, max_seq_len]
|
||||||
|
{
|
||||||
|
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
key.scalar_type(),
|
||||||
|
"decode_kv_cache_memcpy",
|
||||||
|
apply_decode_kv_cache_memcpy<scalar_t>(
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
sequence_lengths,
|
||||||
|
block_tables
|
||||||
|
);)
|
||||||
|
}
|
||||||
|
472
extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu
Normal file
472
extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu
Normal file
@ -0,0 +1,472 @@
|
|||||||
|
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
#include "../common/vector_copy_utils.h"
|
||||||
|
#include "../common/micros.h"
|
||||||
|
|
||||||
|
template <typename scalar_t, int VecSize>
|
||||||
|
__device__ void apply_emb_rotary_compute(
|
||||||
|
scalar_t* __restrict__ src, const scalar_t* __restrict__ cos_ptr,
|
||||||
|
const scalar_t* __restrict__ sin_ptr, const int64_t stride,
|
||||||
|
const int token_id, const int shard_block_size, const int half_head_dim,
|
||||||
|
const int head_num, const int head_dim) {
|
||||||
|
scalar_t x[VecSize];
|
||||||
|
scalar_t y[VecSize];
|
||||||
|
scalar_t out_x[VecSize];
|
||||||
|
scalar_t out_y[VecSize];
|
||||||
|
|
||||||
|
for (int i = threadIdx.x * VecSize; i < head_num * half_head_dim;
|
||||||
|
i += blockDim.x * VecSize) {
|
||||||
|
const int head_offset = i % half_head_dim;
|
||||||
|
const int shard_offset =
|
||||||
|
(head_offset / shard_block_size) * shard_block_size +
|
||||||
|
(head_offset % shard_block_size) / VecSize;
|
||||||
|
const int64_t addr_offset =
|
||||||
|
token_id * stride + (i / half_head_dim) * head_dim + head_offset;
|
||||||
|
|
||||||
|
copy_vector<scalar_t, VecSize>(x, src + addr_offset);
|
||||||
|
copy_vector<scalar_t, VecSize>(y, src + addr_offset + half_head_dim);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < VecSize; j++) {
|
||||||
|
out_x[j] = x[j] * cos_ptr[j * 32 + shard_offset] -
|
||||||
|
y[j] * sin_ptr[j * 32 + shard_offset];
|
||||||
|
out_y[j] = y[j] * cos_ptr[j * 32 + shard_offset] +
|
||||||
|
x[j] * sin_ptr[j * 32 + shard_offset];
|
||||||
|
}
|
||||||
|
|
||||||
|
copy_vector<scalar_t, VecSize>(src + addr_offset, out_x);
|
||||||
|
copy_vector<scalar_t, VecSize>(src + addr_offset + half_head_dim, out_y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t, int VecSize>
|
||||||
|
__device__ void apply_kv_memcopy(
|
||||||
|
scalar_t* __restrict__ src, scalar_t* __restrict__ cache,
|
||||||
|
const int64_t stride, const int token_id, const int block_id,
|
||||||
|
const int hidden_size, const int block_size, const int block_offset,
|
||||||
|
const int head_dim, const int half_head_dim) {
|
||||||
|
for (int i = threadIdx.x * VecSize; i < hidden_size / 2;
|
||||||
|
i += blockDim.x * VecSize) {
|
||||||
|
const int head_id = i / half_head_dim;
|
||||||
|
const int head_offset = i % half_head_dim;
|
||||||
|
const int64_t src_id = token_id * stride + head_id * head_dim + head_offset;
|
||||||
|
const int64_t target_id = block_id * hidden_size * block_size +
|
||||||
|
head_id * block_size * head_dim +
|
||||||
|
block_offset * head_dim + head_offset;
|
||||||
|
|
||||||
|
copy_vector<scalar_t, VecSize>(cache + target_id, src + src_id);
|
||||||
|
copy_vector<scalar_t, VecSize>(cache + target_id + half_head_dim,
|
||||||
|
src + src_id + half_head_dim);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t, int VecSize>
|
||||||
|
__device__ void cos_sin_memory_access(
|
||||||
|
const scalar_t* __restrict__ cos, const scalar_t* __restrict__ sin,
|
||||||
|
scalar_t* cos_ptr, scalar_t* sin_ptr, const int token_id,
|
||||||
|
const int shard_block_size, const int cos_stride, const int sin_stride,
|
||||||
|
const int half_head_dim) {
|
||||||
|
for (int i = threadIdx.x; i < half_head_dim; i += blockDim.x) {
|
||||||
|
// We assume that the value of head_dim is less than 128*128.
|
||||||
|
const int shard_offset = (i % shard_block_size) / VecSize;
|
||||||
|
const int shard_head =
|
||||||
|
(i / shard_block_size) * shard_block_size + i % VecSize * 32;
|
||||||
|
cos_ptr[shard_head + shard_offset] = cos[token_id * cos_stride + i];
|
||||||
|
sin_ptr[shard_head + shard_offset] = sin[token_id * sin_stride + i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t, int VecSize>
|
||||||
|
__device__ void apply_k_rotary_emb_compute(
|
||||||
|
scalar_t* __restrict__ key, scalar_t* __restrict__ value,
|
||||||
|
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
|
||||||
|
const scalar_t* __restrict__ cos_ptr, const scalar_t* __restrict__ sin_ptr,
|
||||||
|
const int* __restrict__ sequence_lengths,
|
||||||
|
const int* __restrict__ block_tables, const int64_t key_stride,
|
||||||
|
const int64_t value_stride, const int token_id,
|
||||||
|
const int block_table_stride, const int head_num, const int head_dim,
|
||||||
|
const int kv_head_num, const int block_size, const int half_head_dim,
|
||||||
|
const int shard_block_size) {
|
||||||
|
const int seq_len = sequence_lengths[token_id] - 1;
|
||||||
|
const int block_offset = seq_len % block_size;
|
||||||
|
const int block_id =
|
||||||
|
block_tables[token_id * block_table_stride + seq_len / block_size];
|
||||||
|
|
||||||
|
if (block_id < 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
scalar_t x[VecSize];
|
||||||
|
scalar_t y[VecSize];
|
||||||
|
scalar_t out_x[VecSize];
|
||||||
|
scalar_t out_y[VecSize];
|
||||||
|
|
||||||
|
for (int i = threadIdx.x * VecSize; i < kv_head_num * half_head_dim;
|
||||||
|
i += blockDim.x * VecSize) {
|
||||||
|
const int head_offset = i % half_head_dim;
|
||||||
|
const int shard_offset =
|
||||||
|
(head_offset / shard_block_size) * shard_block_size +
|
||||||
|
(head_offset % shard_block_size) / VecSize;
|
||||||
|
const int64_t addr_offset =
|
||||||
|
token_id * key_stride + (i / half_head_dim) * head_dim + head_offset;
|
||||||
|
const int64_t target_id = block_id * head_num * head_dim * block_size +
|
||||||
|
(i / half_head_dim) * block_size * head_dim +
|
||||||
|
block_offset * head_dim + head_offset;
|
||||||
|
|
||||||
|
copy_vector<scalar_t, VecSize>(x, key + addr_offset);
|
||||||
|
copy_vector<scalar_t, VecSize>(y, key + addr_offset + half_head_dim);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < VecSize; j++) {
|
||||||
|
out_x[j] = x[j] * cos_ptr[j * 32 + shard_offset] -
|
||||||
|
y[j] * sin_ptr[j * 32 + shard_offset];
|
||||||
|
out_y[j] = y[j] * cos_ptr[j * 32 + shard_offset] +
|
||||||
|
x[j] * sin_ptr[j * 32 + shard_offset];
|
||||||
|
}
|
||||||
|
|
||||||
|
copy_vector<scalar_t, VecSize>(key_cache + target_id, out_x);
|
||||||
|
copy_vector<scalar_t, VecSize>(key_cache + target_id + half_head_dim,
|
||||||
|
out_y);
|
||||||
|
}
|
||||||
|
|
||||||
|
// apply value memcopy
|
||||||
|
apply_kv_memcopy<scalar_t, VecSize>(
|
||||||
|
value, value_cache, value_stride, token_id, block_id, head_num * head_dim,
|
||||||
|
block_size, block_offset, head_dim, half_head_dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename scalar_t, int VecSize>
|
||||||
|
__global__ void rotary_embedding_and_cache_copy_kernel(
|
||||||
|
scalar_t* __restrict__ query,
|
||||||
|
scalar_t* __restrict__ key,
|
||||||
|
scalar_t* __restrict__ value,
|
||||||
|
const scalar_t* __restrict__ cos,
|
||||||
|
const scalar_t* __restrict__ sin,
|
||||||
|
scalar_t* __restrict__ key_cache,
|
||||||
|
scalar_t* __restrict__ value_cache,
|
||||||
|
const int* __restrict__ sequence_lengths,
|
||||||
|
const int* __restrict__ block_tables,
|
||||||
|
const int64_t query_stride,
|
||||||
|
const int64_t key_stride,
|
||||||
|
const int64_t value_stride,
|
||||||
|
const int64_t half_shard_element_num,
|
||||||
|
const int cos_stride,
|
||||||
|
const int sin_stride,
|
||||||
|
const int block_table_stride,
|
||||||
|
const int head_num,
|
||||||
|
const int head_dim,
|
||||||
|
const int kv_head_num,
|
||||||
|
const int block_size
|
||||||
|
) {
|
||||||
|
|
||||||
|
const int token_id = blockIdx.x;
|
||||||
|
const int half_head_dim = head_dim / 2;
|
||||||
|
const int shard_block_size = VecSize * 32;
|
||||||
|
|
||||||
|
extern __shared__ char shard_ptr[];
|
||||||
|
|
||||||
|
scalar_t *cos_ptr = (scalar_t*)shard_ptr;
|
||||||
|
scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
|
||||||
|
|
||||||
|
// apply cos_sin memcopy
|
||||||
|
cos_sin_memory_access<scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
//compute query
|
||||||
|
apply_emb_rotary_compute<scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
|
||||||
|
|
||||||
|
//compute key and copy kv
|
||||||
|
apply_k_rotary_emb_compute<scalar_t, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, half_head_dim, shard_block_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename scalar_t, int VecSize>
|
||||||
|
__global__ void rotary_embedding_kernel(
|
||||||
|
scalar_t* __restrict__ query,
|
||||||
|
scalar_t* __restrict__ key,
|
||||||
|
const scalar_t* __restrict__ cos,
|
||||||
|
const scalar_t* __restrict__ sin,
|
||||||
|
const int64_t query_stride,
|
||||||
|
const int64_t key_stride,
|
||||||
|
const int64_t half_shard_element_num,
|
||||||
|
const int cos_stride,
|
||||||
|
const int sin_stride,
|
||||||
|
const int head_num,
|
||||||
|
const int head_dim,
|
||||||
|
const int kv_head_num
|
||||||
|
) {
|
||||||
|
const int token_id = blockIdx.x;
|
||||||
|
const int half_head_dim = head_dim / 2;
|
||||||
|
const int shard_block_size = VecSize * 32;
|
||||||
|
|
||||||
|
extern __shared__ char shard_ptr[];
|
||||||
|
|
||||||
|
scalar_t *cos_ptr = (scalar_t*)shard_ptr;
|
||||||
|
scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
|
||||||
|
|
||||||
|
// apply cos_sin memcopy
|
||||||
|
cos_sin_memory_access<scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
//compute query
|
||||||
|
apply_emb_rotary_compute<scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
|
||||||
|
|
||||||
|
//compute key
|
||||||
|
apply_emb_rotary_compute<scalar_t, VecSize>(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename scalar_t>
|
||||||
|
void apply_rotary_embedding_and_cache_copy(
|
||||||
|
at::Tensor& query, // [num_tokens, head_num, head_dim]
|
||||||
|
at::Tensor& key, // [num_tokens, kv_head_num, head_dim]
|
||||||
|
at::Tensor& value, // [num_tokens, kv_head_num, head_dim]
|
||||||
|
at::Tensor& cos, // [num_tokens, head_dim]
|
||||||
|
at::Tensor& sin, // [num_tokens, head_dim]
|
||||||
|
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||||
|
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||||
|
at::Tensor& sequence_lengths, // [batch_size]
|
||||||
|
at::Tensor& block_tables) // [batch_size, max_seq_len]
|
||||||
|
{
|
||||||
|
int num_tokens = query.size(0);
|
||||||
|
int head_num = query.size(1);
|
||||||
|
int head_dim = query.size(2);
|
||||||
|
int kv_head_num = key.size(1);
|
||||||
|
int block_size = key_cache.size(2);
|
||||||
|
|
||||||
|
int64_t query_stride = query.stride(0);
|
||||||
|
int64_t key_stride = key.stride(0);
|
||||||
|
int64_t value_stride = value.stride(0);
|
||||||
|
int cos_stride = cos.stride(0);
|
||||||
|
int sin_stride = sin.stride(0);
|
||||||
|
int block_table_stride = block_tables.stride(0);
|
||||||
|
|
||||||
|
int vec_size = get_vec_size<scalar_t>(query);
|
||||||
|
|
||||||
|
if ((head_dim / 2) % vec_size != 0) {
|
||||||
|
// Disable vectorized loading optimization when head_dim is not divisible by VecSize.
|
||||||
|
vec_size = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
|
int thread_nums = head_num * head_dim / vec_size / 2;
|
||||||
|
const int shard_block_size = vec_size * 32 * 2;
|
||||||
|
|
||||||
|
dim3 grid(num_tokens);
|
||||||
|
dim3 block(std::min(thread_nums, 512));
|
||||||
|
int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size ;
|
||||||
|
|
||||||
|
switch (vec_size) {
|
||||||
|
case 1:
|
||||||
|
rotary_embedding_and_cache_copy_kernel<scalar_t, 1><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
|
||||||
|
query.data_ptr<scalar_t>(),
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
value.data_ptr<scalar_t>(),
|
||||||
|
cos.data_ptr<scalar_t>(),
|
||||||
|
sin.data_ptr<scalar_t>(),
|
||||||
|
key_cache.data_ptr<scalar_t>(),
|
||||||
|
value_cache.data_ptr<scalar_t>(),
|
||||||
|
sequence_lengths.data_ptr<int>(),
|
||||||
|
block_tables.data_ptr<int>(),
|
||||||
|
query_stride,
|
||||||
|
key_stride,
|
||||||
|
value_stride,
|
||||||
|
shard_element_num / 2,
|
||||||
|
cos_stride,
|
||||||
|
sin_stride,
|
||||||
|
block_table_stride,
|
||||||
|
head_num,
|
||||||
|
head_dim,
|
||||||
|
kv_head_num,
|
||||||
|
block_size
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
rotary_embedding_and_cache_copy_kernel<scalar_t, 2><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
|
||||||
|
query.data_ptr<scalar_t>(),
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
value.data_ptr<scalar_t>(),
|
||||||
|
cos.data_ptr<scalar_t>(),
|
||||||
|
sin.data_ptr<scalar_t>(),
|
||||||
|
key_cache.data_ptr<scalar_t>(),
|
||||||
|
value_cache.data_ptr<scalar_t>(),
|
||||||
|
sequence_lengths.data_ptr<int>(),
|
||||||
|
block_tables.data_ptr<int>(),
|
||||||
|
query_stride,
|
||||||
|
key_stride,
|
||||||
|
value_stride,
|
||||||
|
shard_element_num / 2,
|
||||||
|
cos_stride,
|
||||||
|
sin_stride,
|
||||||
|
block_table_stride,
|
||||||
|
head_num,
|
||||||
|
head_dim,
|
||||||
|
kv_head_num,
|
||||||
|
block_size
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
rotary_embedding_and_cache_copy_kernel<scalar_t, 4><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
|
||||||
|
query.data_ptr<scalar_t>(),
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
value.data_ptr<scalar_t>(),
|
||||||
|
cos.data_ptr<scalar_t>(),
|
||||||
|
sin.data_ptr<scalar_t>(),
|
||||||
|
key_cache.data_ptr<scalar_t>(),
|
||||||
|
value_cache.data_ptr<scalar_t>(),
|
||||||
|
sequence_lengths.data_ptr<int>(),
|
||||||
|
block_tables.data_ptr<int>(),
|
||||||
|
query_stride,
|
||||||
|
key_stride,
|
||||||
|
value_stride,
|
||||||
|
shard_element_num / 2,
|
||||||
|
cos_stride,
|
||||||
|
sin_stride,
|
||||||
|
block_table_stride,
|
||||||
|
head_num,
|
||||||
|
head_dim,
|
||||||
|
kv_head_num,
|
||||||
|
block_size
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
AT_ERROR("Unsupported vectorized size ", vec_size);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
AT_CUDA_CHECK(cudaGetLastError());
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename scalar_t>
|
||||||
|
void apply_rotary_embedding(
|
||||||
|
at::Tensor& query, // [total_tokens, head_num, head_dim]
|
||||||
|
at::Tensor& key, // [total_tokens, kv_head_num, head_dim]
|
||||||
|
at::Tensor& cos, // [total_tokens, head_dim]
|
||||||
|
at::Tensor& sin // [total_tokens, head_dim]
|
||||||
|
){
|
||||||
|
int num_tokens = query.size(0);
|
||||||
|
int head_num = query.size(1);
|
||||||
|
int head_dim = query.size(2);
|
||||||
|
int kv_head_num = key.size(1);
|
||||||
|
|
||||||
|
int query_stride = query.stride(0);
|
||||||
|
int key_stride = key.stride(0);
|
||||||
|
int cos_stride = cos.stride(0);
|
||||||
|
int sin_stride = sin.stride(0);
|
||||||
|
|
||||||
|
int vec_size = get_vec_size<scalar_t>(query);
|
||||||
|
|
||||||
|
if ((head_dim / 2) % vec_size != 0) {
|
||||||
|
// Disable vectorized loading optimization when head_dim is not divisible by VecSize.
|
||||||
|
vec_size = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
|
int thread_nums = head_num * head_dim / vec_size / 2;
|
||||||
|
const int shard_block_size = vec_size * 32 * 2;
|
||||||
|
|
||||||
|
dim3 grid(num_tokens);
|
||||||
|
dim3 block(std::min(thread_nums, 512));
|
||||||
|
int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size ;
|
||||||
|
|
||||||
|
switch (vec_size) {
|
||||||
|
case 1:
|
||||||
|
rotary_embedding_kernel<scalar_t, 1><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
|
||||||
|
query.data_ptr<scalar_t>(),
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
cos.data_ptr<scalar_t>(),
|
||||||
|
sin.data_ptr<scalar_t>(),
|
||||||
|
query_stride,
|
||||||
|
key_stride,
|
||||||
|
shard_element_num / 2,
|
||||||
|
cos_stride,
|
||||||
|
sin_stride,
|
||||||
|
head_num,
|
||||||
|
head_dim,
|
||||||
|
kv_head_num
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
rotary_embedding_kernel<scalar_t, 2><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
|
||||||
|
query.data_ptr<scalar_t>(),
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
cos.data_ptr<scalar_t>(),
|
||||||
|
sin.data_ptr<scalar_t>(),
|
||||||
|
query_stride,
|
||||||
|
key_stride,
|
||||||
|
shard_element_num / 2,
|
||||||
|
cos_stride,
|
||||||
|
sin_stride,
|
||||||
|
head_num,
|
||||||
|
head_dim,
|
||||||
|
kv_head_num
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
rotary_embedding_kernel<scalar_t, 4><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
|
||||||
|
query.data_ptr<scalar_t>(),
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
cos.data_ptr<scalar_t>(),
|
||||||
|
sin.data_ptr<scalar_t>(),
|
||||||
|
query_stride,
|
||||||
|
key_stride,
|
||||||
|
shard_element_num / 2,
|
||||||
|
cos_stride,
|
||||||
|
sin_stride,
|
||||||
|
head_num,
|
||||||
|
head_dim,
|
||||||
|
kv_head_num
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
AT_ERROR("Unsupported vectorized size ", vec_size);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
AT_CUDA_CHECK(cudaGetLastError());
|
||||||
|
}
|
||||||
|
|
||||||
|
void rotary_embedding_and_cache_copy(
|
||||||
|
at::Tensor& query, // [num_tokens, head_num, head_dim]
|
||||||
|
at::Tensor& key, // [num_tokens, kv_head_num, head_dim]
|
||||||
|
at::Tensor& value, // [num_tokens, kv_head_num, head_dim]
|
||||||
|
at::Tensor& cos, // [num_tokens, head_dim]
|
||||||
|
at::Tensor& sin, // [num_tokens, head_dim]
|
||||||
|
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||||
|
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||||
|
at::Tensor& sequence_lengths, // [batch_size]
|
||||||
|
at::Tensor& block_tables) // [batch_size, max_seq_len]
|
||||||
|
{
|
||||||
|
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
query.scalar_type(),
|
||||||
|
"rotary_embedding_and_cache_copy",
|
||||||
|
apply_rotary_embedding_and_cache_copy<scalar_t>(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
sequence_lengths,
|
||||||
|
block_tables
|
||||||
|
);)
|
||||||
|
}
|
||||||
|
|
||||||
|
void rotary_embedding(
|
||||||
|
at::Tensor& query, // [total_tokens, head_num, head_dim]
|
||||||
|
at::Tensor& key, // [total_tokens, kv_head_num, head_dim]
|
||||||
|
at::Tensor& cos, // [total_tokens, head_dim]
|
||||||
|
at::Tensor& sin // [total_tokens, head_dim]
|
||||||
|
){
|
||||||
|
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
query.scalar_type(),
|
||||||
|
"rotary_embedding",
|
||||||
|
apply_rotary_embedding<scalar_t>(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
cos,
|
||||||
|
sin
|
||||||
|
);)
|
||||||
|
}
|
@ -9,6 +9,23 @@ void decode_kv_cache_memcpy(
|
|||||||
torch::Tensor& sequence_lengths, // [batch_size]
|
torch::Tensor& sequence_lengths, // [batch_size]
|
||||||
torch::Tensor& block_tables); // [batch_size, max_seq_len]
|
torch::Tensor& block_tables); // [batch_size, max_seq_len]
|
||||||
|
|
||||||
|
void rotary_embedding(
|
||||||
|
torch::Tensor& query, // [total_tokens, head_num, head_dim]
|
||||||
|
torch::Tensor& key, // [total_tokens, kv_head_num, head_dim]
|
||||||
|
torch::Tensor& cos, // [total_tokens, head_dim]
|
||||||
|
torch::Tensor& sin); // [total_tokens, head_dim]
|
||||||
|
|
||||||
|
void rotary_embedding_and_cache_copy(
|
||||||
|
torch::Tensor& query, // [num_tokens, head_num, head_dim]
|
||||||
|
torch::Tensor& key, // [num_tokens, kv_head_num, head_dim]
|
||||||
|
torch::Tensor& value, // [num_tokens, num_heads, head_dim]
|
||||||
|
torch::Tensor& cos, // [num_tokens, head_dim]
|
||||||
|
torch::Tensor& sin, // [num_tokens, head_dim]
|
||||||
|
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_dim]
|
||||||
|
torch::Tensor&
|
||||||
|
value_cache, // [num_blocks, num_heads, block_size, head_dim]
|
||||||
|
torch::Tensor& sequence_lengths, // [batch_size]
|
||||||
|
torch::Tensor& block_tables); // [batch_size, max_seq_len]
|
||||||
torch::Tensor silu_and_mul(const torch::Tensor& ins);
|
torch::Tensor silu_and_mul(const torch::Tensor& ins);
|
||||||
|
|
||||||
void rms_layernorm(torch::Tensor& out, // [..., hidden_size]
|
void rms_layernorm(torch::Tensor& out, // [..., hidden_size]
|
||||||
@ -25,6 +42,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||||||
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
|
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
|
||||||
"Copy the GPU memory of kvcache during the decode stage.");
|
"Copy the GPU memory of kvcache during the decode stage.");
|
||||||
|
|
||||||
|
m.def(
|
||||||
|
"rotary_embedding_and_cache_copy", &rotary_embedding_and_cache_copy,
|
||||||
|
"performing Rotary Embedding-related calculations and KVCache Memcopy.");
|
||||||
|
|
||||||
|
m.def("rotary_embedding", &rotary_embedding,
|
||||||
|
"performing Rotary Embedding-related calculations.");
|
||||||
|
|
||||||
m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply");
|
m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply");
|
||||||
|
|
||||||
m.def("rms_layernorm", &rms_layernorm,
|
m.def("rms_layernorm", &rms_layernorm,
|
||||||
|
@ -12,6 +12,7 @@ class InferenceOpsCudaExtension(_CudaExtension):
|
|||||||
for fname in [
|
for fname in [
|
||||||
"cuda/pybind/inference.cpp",
|
"cuda/pybind/inference.cpp",
|
||||||
"cuda/decode_kv_cache_memcpy_kernel.cu",
|
"cuda/decode_kv_cache_memcpy_kernel.cu",
|
||||||
|
"cuda/fused_rotary_emb_and_cache_kernel.cu",
|
||||||
"cuda/activation_kernel.cu",
|
"cuda/activation_kernel.cu",
|
||||||
"cuda/rms_layernorm_kernel.cu",
|
"cuda/rms_layernorm_kernel.cu",
|
||||||
]
|
]
|
||||||
|
@ -22,15 +22,11 @@ def setup_seed(seed):
|
|||||||
def check_inference_engine(use_engine=False, prompt_template=None):
|
def check_inference_engine(use_engine=False, prompt_template=None):
|
||||||
setup_seed(20)
|
setup_seed(20)
|
||||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||||
model = (
|
model = LlamaForCausalLM(
|
||||||
LlamaForCausalLM(
|
|
||||||
LlamaConfig(
|
LlamaConfig(
|
||||||
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
|
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
|
||||||
)
|
)
|
||||||
)
|
).cuda()
|
||||||
.cuda()
|
|
||||||
.half()
|
|
||||||
)
|
|
||||||
model = model.eval()
|
model = model.eval()
|
||||||
|
|
||||||
inputs = [
|
inputs = [
|
||||||
@ -44,7 +40,7 @@ def check_inference_engine(use_engine=False, prompt_template=None):
|
|||||||
top_k = 50
|
top_k = 50
|
||||||
|
|
||||||
if use_engine:
|
if use_engine:
|
||||||
inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template)
|
inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
|
||||||
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
|
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
|
||||||
assert inference_engine.generation_config.max_new_tokens == output_len
|
assert inference_engine.generation_config.max_new_tokens == output_len
|
||||||
inference_engine.add_request(prompts=inputs)
|
inference_engine.add_request(prompts=inputs)
|
||||||
|
91
tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py
Normal file
91
tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
|
||||||
|
|
||||||
|
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||||
|
|
||||||
|
inference_ops = InferenceOpsLoader().load()
|
||||||
|
|
||||||
|
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2
|
||||||
|
from tests.test_infer.test_ops.triton.test_rotary_embdding_unpad import torch_rotary_emb
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("BATCH_SIZE", [4])
|
||||||
|
@pytest.mark.parametrize("SEQ_LEN", [64])
|
||||||
|
@pytest.mark.parametrize("H", [32])
|
||||||
|
@pytest.mark.parametrize("D", [64])
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.float16])
|
||||||
|
def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
|
||||||
|
torch.manual_seed(10)
|
||||||
|
TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
|
||||||
|
# our crafted op equals to Transformers
|
||||||
|
x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
|
||||||
|
x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
|
||||||
|
emb = LlamaRotaryEmbedding(D)
|
||||||
|
cos, sin = emb(x0, TOTAL_TOKENS)
|
||||||
|
cos_2 = cos[:, : D // 2]
|
||||||
|
sin_2 = sin[:, : D // 2]
|
||||||
|
position_ids = torch.arange(TOTAL_TOKENS)
|
||||||
|
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
|
||||||
|
embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
|
||||||
|
assert torch.allclose(embd_x0, embd_stimulated_x)
|
||||||
|
|
||||||
|
# create data
|
||||||
|
block_size = 32
|
||||||
|
max_blocks_per_sequence = (TOTAL_TOKENS + block_size - 1) // block_size
|
||||||
|
q_shape = (TOTAL_TOKENS, H, D)
|
||||||
|
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
|
||||||
|
k_shape = (TOTAL_TOKENS, H, D)
|
||||||
|
k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
|
||||||
|
cos_shape = (TOTAL_TOKENS, D // 2)
|
||||||
|
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||||
|
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||||
|
cache_shape = (BATCH_SIZE * max_blocks_per_sequence, H, block_size, D)
|
||||||
|
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
|
||||||
|
v = torch.randn_like(k)
|
||||||
|
v_cache = torch.zeros_like(k_cache)
|
||||||
|
past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
|
||||||
|
block_tables = mock_alloc_block_table_and_kvcache_v2(
|
||||||
|
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_blocks_per_sequence, block_size
|
||||||
|
)
|
||||||
|
new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
|
||||||
|
new_q = torch.randn_like(new_k)
|
||||||
|
new_v = torch.randn_like(new_k)
|
||||||
|
|
||||||
|
kv_seq_lengths = past_kv_seq_lengths + 1
|
||||||
|
block_tables = block_tables.to(device="cuda")
|
||||||
|
q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
|
||||||
|
k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
|
||||||
|
|
||||||
|
new_q_copy = new_q.clone()
|
||||||
|
new_k_copy = new_k.clone()
|
||||||
|
|
||||||
|
inference_ops.rotary_embedding_and_cache_copy(
|
||||||
|
new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables
|
||||||
|
)
|
||||||
|
|
||||||
|
inference_ops.rotary_embedding(new_q_copy, new_k_copy, cos, sin)
|
||||||
|
|
||||||
|
past_kv_seq_len = kv_seq_lengths - 1
|
||||||
|
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
|
||||||
|
offsets_in_block = past_kv_seq_len % block_size
|
||||||
|
k_target = k_cache[target_block_ids, :, offsets_in_block, :].squeeze()
|
||||||
|
k_source = new_k_copy.squeeze()
|
||||||
|
v_target = v_cache[target_block_ids, :, offsets_in_block, :].squeeze()
|
||||||
|
v_source = new_v.squeeze()
|
||||||
|
|
||||||
|
assert torch.allclose(new_q, q_ref, atol=1e-6, rtol=1e-6)
|
||||||
|
assert torch.allclose(k_target, k_ref, atol=1e-6, rtol=1e-6)
|
||||||
|
|
||||||
|
assert torch.allclose(new_q_copy, q_ref, atol=1e-6, rtol=1e-6)
|
||||||
|
assert torch.allclose(new_k_copy, k_ref, atol=1e-6, rtol=1e-6)
|
||||||
|
|
||||||
|
assert k_target.shape == k_source.shape
|
||||||
|
assert torch.allclose(k_target, k_source, atol=1e-6, rtol=1e-6)
|
||||||
|
|
||||||
|
assert v_target.shape == v_source.shape
|
||||||
|
assert torch.equal(v_target, v_source)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_rotary_emb(16, 512, 4, 128, torch.float16)
|
Loading…
Reference in New Issue
Block a user