[Inference/Refactor] Refactor compilation mechanism and unified multi hw (#5613)

* refactor compilation mechanism and unified multi hw

* fix file path bug

* add init.py to make pybind a module to avoid relative path error caused by softlink

* delete duplicated micros

* fix micros bug in gcc
This commit is contained in:
傅剑寒
2024-04-24 14:17:54 +08:00
committed by GitHub
parent 04863a9b14
commit 279300dc5f
64 changed files with 345 additions and 310 deletions

View File

@@ -0,0 +1,3 @@
from .inference_ops_cuda import InferenceOpsCudaExtension
__all__ = ["InferenceOpsCudaExtension"]

View File

@@ -0,0 +1,103 @@
#include <torch/extension.h>
void decode_kv_cache_memcpy(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size]
torch::Tensor&
value_cache, // [num_blocks, num_heads, block_size, head_size]
torch::Tensor& sequence_lengths, // [batch_size]
torch::Tensor& block_tables); // [batch_size, max_seq_len]
void context_kv_cache_memcpy(
at::Tensor& key, // [num_tokens, head_num, head_dim]
at::Tensor& value, // [num_tokens, head_num, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& cu_seqlens, // [batch_size + 1]
at::Tensor& block_tables, // [batch_size, max_seq_len]
int max_seq_len_in_batch);
void rotary_embedding(
torch::Tensor& query, // [total_tokens, head_num, head_dim]
torch::Tensor& key, // [total_tokens, kv_head_num, head_dim]
torch::Tensor& cos, // [total_tokens, head_dim]
torch::Tensor& sin, // [total_tokens, head_dim]
bool high_precision);
void rotary_embedding_and_cache_copy(
torch::Tensor& query, // [num_tokens, head_num, head_dim]
torch::Tensor& key, // [num_tokens, kv_head_num, head_dim]
torch::Tensor& value, // [num_tokens, num_heads, head_dim]
torch::Tensor& cos, // [num_tokens, head_dim]
torch::Tensor& sin, // [num_tokens, head_dim]
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_dim]
torch::Tensor&
value_cache, // [num_blocks, num_heads, block_size, head_dim]
torch::Tensor& sequence_lengths, // [batch_size]
torch::Tensor& block_tables, // [batch_size, max_seq_len]
bool high_precision);
torch::Tensor silu_and_mul(const torch::Tensor& ins);
void rms_layernorm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
float epsilon);
void fused_add_rms_layernorm(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
float epsilon);
void get_cos_and_sin(at::Tensor& cos_cache, // [max_rotary_position, head_dim]
at::Tensor& sin_cache, // [max_rotary_position, head_dim]
at::Tensor& cos, // [num_tokens, head_dim]
at::Tensor& sin, // [num_tokens, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
int max_seq_len_in_batch, bool is_prompts);
void flash_decoding_attention(
torch::Tensor& out, // [num_tokens, num_heads, head_size]
torch::Tensor& query, // [num_tokens, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_kv_heads, block_size, head_size]
torch::Tensor&
value_cache, // [num_blocks, num_kv_heads, block_size, head_size]
torch::Tensor& context_lens, // [num_tokens]
torch::Tensor& block_tables, // [num_tokens, max_num_blocks_per_seq]
int block_size, int max_context_len,
torch::Tensor&
tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size]
torch::Tensor& tmp_out_lse, // [num_tokens, num_heads, max_num_partitions]
float scale);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
"Copy the GPU memory of kvcache during the decode stage.");
m.def("context_kv_cache_memcpy", &context_kv_cache_memcpy,
"Copy the GPU memory of kvcache during the context stage.");
m.def(
"rotary_embedding_and_cache_copy", &rotary_embedding_and_cache_copy,
"Performing Rotary Embedding-related calculations and KVCache Memcopy.");
m.def("rotary_embedding", &rotary_embedding,
"Performing Rotary Embedding-related calculations.");
m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply");
m.def("rms_layernorm", &rms_layernorm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
m.def("fused_add_rms_layernorm", &fused_add_rms_layernorm,
"In-place fused Add and RMS Normalization.");
m.def("get_cos_and_sin", &get_cos_and_sin, "Get cos and sin from the cache.");
m.def("flash_decoding_attention", &flash_decoding_attention,
"Compute the attention between an input query and the cached "
"keys/values using PagedAttention.");
}

View File

@@ -0,0 +1,31 @@
from ...cuda_extension import _CudaExtension
from ...utils import get_cuda_cc_flag
class InferenceOpsCudaExtension(_CudaExtension):
def __init__(self):
super().__init__(name="inference_ops_cuda")
def sources_files(self):
ret = [
self.csrc_abs_path(fname)
for fname in [
"kernel/cuda/decode_kv_cache_memcpy_kernel.cu",
"kernel/cuda/context_kv_cache_memcpy_kernel.cu",
"kernel/cuda/fused_rotary_emb_and_cache_kernel.cu",
"kernel/cuda/activation_kernel.cu",
"kernel/cuda/rms_layernorm_kernel.cu",
"kernel/cuda/get_cos_and_sin_kernel.cu",
"kernel/cuda/flash_decoding_attention_kernel.cu",
]
] + [self.pybind_abs_path("inference/inference.cpp")]
return ret
def cxx_flags(self):
version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
return ["-O3"] + version_dependent_macros
def nvcc_flags(self):
extra_cuda_flags = ["-lineinfo"]
extra_cuda_flags.extend(get_cuda_cc_flag())
return ["-O3", "--use_fast_math"] + extra_cuda_flags + super().nvcc_flags()