Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 20:10:17 +00:00)
[Inference/Refactor] Refactor compilation mechanism and unified multi hw (#5613)
* refactor the compilation mechanism and unify multi-hardware support (see the sketch below)
* fix a file path bug
* add __init__.py to make pybind a package, avoiding the relative-path error caused by the softlink
* delete duplicated macros
* fix a macro bug under gcc
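As a gloss on the "unify multi-hardware support" point, the idea is that callers ask for inference ops once and a loader picks the extension class matching the available hardware. The helper below is a minimal sketch under that assumption; load_inference_ops and its dispatch logic are illustrative, not the repository's actual loader API.

import torch

def load_inference_ops():
    # Hypothetical dispatch: return the extension descriptor matching the
    # detected accelerator. Only the CUDA backend appears in this commit;
    # other backends would slot in as additional branches.
    if torch.cuda.is_available():
        from extensions.pybind.inference import InferenceOpsCudaExtension
        return InferenceOpsCudaExtension()
    raise RuntimeError("no supported accelerator found for inference ops")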
extensions/pybind/inference/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .inference_ops_cuda import InferenceOpsCudaExtension

__all__ = ["InferenceOpsCudaExtension"]
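With this __init__.py in place, extensions.pybind.inference is importable as a regular package, so the relative import inside it resolves without depending on a softlinked file path. A minimal usage sketch, assuming the repository root is on sys.path; compiling the ops is driven elsewhere by the _CudaExtension machinery:

from extensions.pybind.inference import InferenceOpsCudaExtension

ext = InferenceOpsCudaExtension()  # construct the extension descriptor; compilation happens later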
extensions/pybind/inference/inference.cpp (new file, 103 lines)
@@ -0,0 +1,103 @@
#include <torch/extension.h>

void decode_kv_cache_memcpy(
    torch::Tensor& key,               // [num_tokens, num_heads, head_size]
    torch::Tensor& value,             // [num_tokens, num_heads, head_size]
    torch::Tensor& key_cache,         // [num_blocks, num_heads, block_size, head_size]
    torch::Tensor& value_cache,       // [num_blocks, num_heads, block_size, head_size]
    torch::Tensor& sequence_lengths,  // [batch_size]
    torch::Tensor& block_tables);     // [batch_size, max_seq_len]

void context_kv_cache_memcpy(
    at::Tensor& key,               // [num_tokens, head_num, head_dim]
    at::Tensor& value,             // [num_tokens, head_num, head_dim]
    at::Tensor& key_cache,         // [num_blocks, head_num, block_size, head_dim]
    at::Tensor& value_cache,       // [num_blocks, head_num, block_size, head_dim]
    at::Tensor& sequence_lengths,  // [batch_size]
    at::Tensor& cu_seqlens,        // [batch_size + 1]
    at::Tensor& block_tables,      // [batch_size, max_seq_len]
    int max_seq_len_in_batch);

void rotary_embedding(
    torch::Tensor& query,  // [total_tokens, head_num, head_dim]
    torch::Tensor& key,    // [total_tokens, kv_head_num, head_dim]
    torch::Tensor& cos,    // [total_tokens, head_dim]
    torch::Tensor& sin,    // [total_tokens, head_dim]
    bool high_precision);

void rotary_embedding_and_cache_copy(
    torch::Tensor& query,             // [num_tokens, head_num, head_dim]
    torch::Tensor& key,               // [num_tokens, kv_head_num, head_dim]
    torch::Tensor& value,             // [num_tokens, num_heads, head_dim]
    torch::Tensor& cos,               // [num_tokens, head_dim]
    torch::Tensor& sin,               // [num_tokens, head_dim]
    torch::Tensor& key_cache,         // [num_blocks, num_heads, block_size, head_dim]
    torch::Tensor& value_cache,       // [num_blocks, num_heads, block_size, head_dim]
    torch::Tensor& sequence_lengths,  // [batch_size]
    torch::Tensor& block_tables,      // [batch_size, max_seq_len]
    bool high_precision);

torch::Tensor silu_and_mul(const torch::Tensor& ins);

void rms_layernorm(torch::Tensor& out,     // [..., hidden_size]
                   torch::Tensor& input,   // [..., hidden_size]
                   torch::Tensor& weight,  // [hidden_size]
                   float epsilon);

void fused_add_rms_layernorm(torch::Tensor& input,     // [..., hidden_size]
                             torch::Tensor& residual,  // [..., hidden_size]
                             torch::Tensor& weight,    // [hidden_size]
                             float epsilon);

void get_cos_and_sin(at::Tensor& cos_cache,         // [max_rotary_position, head_dim]
                     at::Tensor& sin_cache,         // [max_rotary_position, head_dim]
                     at::Tensor& cos,               // [num_tokens, head_dim]
                     at::Tensor& sin,               // [num_tokens, head_dim]
                     at::Tensor& sequence_lengths,  // [batch_size]
                     int max_seq_len_in_batch, bool is_prompts);

void flash_decoding_attention(
    torch::Tensor& out,           // [num_tokens, num_heads, head_size]
    torch::Tensor& query,         // [num_tokens, num_heads, head_size]
    torch::Tensor& key_cache,     // [num_blocks, num_kv_heads, block_size, head_size]
    torch::Tensor& value_cache,   // [num_blocks, num_kv_heads, block_size, head_size]
    torch::Tensor& context_lens,  // [num_tokens]
    torch::Tensor& block_tables,  // [num_tokens, max_num_blocks_per_seq]
    int block_size, int max_context_len,
    torch::Tensor& tmp_out,       // [num_tokens, num_heads, max_num_partitions, head_size]
    torch::Tensor& tmp_out_lse,   // [num_tokens, num_heads, max_num_partitions]
    float scale);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
        "Copy the GPU memory of kvcache during the decode stage.");

  m.def("context_kv_cache_memcpy", &context_kv_cache_memcpy,
        "Copy the GPU memory of kvcache during the context stage.");

  m.def(
      "rotary_embedding_and_cache_copy", &rotary_embedding_and_cache_copy,
      "Performing Rotary Embedding-related calculations and KVCache Memcopy.");

  m.def("rotary_embedding", &rotary_embedding,
        "Performing Rotary Embedding-related calculations.");

  m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply");

  m.def("rms_layernorm", &rms_layernorm,
        "Apply Root Mean Square (RMS) Normalization to the input tensor.");

  m.def("fused_add_rms_layernorm", &fused_add_rms_layernorm,
        "In-place fused Add and RMS Normalization.");

  m.def("get_cos_and_sin", &get_cos_and_sin, "Get cos and sin from the cache.");

  m.def("flash_decoding_attention", &flash_decoding_attention,
        "Compute the attention between an input query and the cached "
        "keys/values using PagedAttention.");
}
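For orientation, a hedged example of driving one of these bindings from Python once the module has been compiled. The import name inference_ops_cuda follows the extension name declared in the next file, but the exact module path depends on how the build registers it, so treat it as an assumption:

import torch
import inference_ops_cuda as ops  # assumed import name, set via TORCH_EXTENSION_NAME at build time

hidden_size = 4096
x = torch.randn(8, hidden_size, device="cuda", dtype=torch.float16)   # input,  [..., hidden_size]
weight = torch.ones(hidden_size, device="cuda", dtype=torch.float16)  # scale,  [hidden_size]
out = torch.empty_like(x)

# Writes x / sqrt(mean(x**2, dim=-1) + epsilon) * weight into `out`.
ops.rms_layernorm(out, x, weight, 1e-6)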
extensions/pybind/inference/inference_ops_cuda.py (new file, 31 lines)
@@ -0,0 +1,31 @@
from ...cuda_extension import _CudaExtension
from ...utils import get_cuda_cc_flag


class InferenceOpsCudaExtension(_CudaExtension):
    def __init__(self):
        super().__init__(name="inference_ops_cuda")

    def sources_files(self):
        ret = [
            self.csrc_abs_path(fname)
            for fname in [
                "kernel/cuda/decode_kv_cache_memcpy_kernel.cu",
                "kernel/cuda/context_kv_cache_memcpy_kernel.cu",
                "kernel/cuda/fused_rotary_emb_and_cache_kernel.cu",
                "kernel/cuda/activation_kernel.cu",
                "kernel/cuda/rms_layernorm_kernel.cu",
                "kernel/cuda/get_cos_and_sin_kernel.cu",
                "kernel/cuda/flash_decoding_attention_kernel.cu",
            ]
        ] + [self.pybind_abs_path("inference/inference.cpp")]
        return ret

    def cxx_flags(self):
        version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
        return ["-O3"] + version_dependent_macros

    def nvcc_flags(self):
        extra_cuda_flags = ["-lineinfo"]
        extra_cuda_flags.extend(get_cuda_cc_flag())
        return ["-O3", "--use_fast_math"] + extra_cuda_flags + super().nvcc_flags()
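A quick sanity check is to instantiate the class and inspect what it would hand to the compilers. A minimal sketch, assuming construction itself needs no CUDA toolchain; note that nvcc_flags() calls get_cuda_cc_flag(), which probes the GPU, so that line does need a CUDA-capable environment:

ext = InferenceOpsCudaExtension()
print(ext.sources_files())  # absolute paths: the seven .cu kernels plus inference/inference.cpp
print(ext.cxx_flags())      # ['-O3', '-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']
print(ext.nvcc_flags())     # '-O3', '--use_fast_math', '-lineinfo', plus arch flags from get_cuda_cc_flag()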