mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
[Inference/Feat] Feat quant kvcache step2 (#5674)
This commit is contained in:
@@ -4,16 +4,17 @@
|
||||
#include "utils/vec_copy.h"
|
||||
#include "common/micros.h"
|
||||
|
||||
using colossalAI::cuda::utils::copy_vector;
|
||||
using colossalAI::cuda::utils::get_vec_size;
|
||||
using colossalAI::cuda::utils::copy;
|
||||
using colossalAI::funcs::CastFunctor;
|
||||
|
||||
|
||||
template<typename scalar_t, bool Aligned, int VecSize>
|
||||
template<typename T, typename CacheT, bool Aligned, int VecSize>
|
||||
__global__ void context_kv_cache_memcpy_kernel(
|
||||
const scalar_t* __restrict__ key,
|
||||
const scalar_t* __restrict__ value,
|
||||
scalar_t* __restrict__ key_cache,
|
||||
scalar_t* __restrict__ value_cache,
|
||||
const T* __restrict__ key,
|
||||
const T* __restrict__ value,
|
||||
CacheT* __restrict__ key_cache,
|
||||
CacheT* __restrict__ value_cache,
|
||||
const int* __restrict__ sequence_lengths,
|
||||
const int* __restrict__ cu_seqlens,
|
||||
const int* __restrict__ block_tables,
|
||||
@@ -54,8 +55,8 @@ __global__ void context_kv_cache_memcpy_kernel(
|
||||
+ head_id * block_size * head_dim
|
||||
+ block_offset * head_dim + head_offset;
|
||||
|
||||
copy_vector<scalar_t, VecSize>(key_cache + target_id, key + key_src_id);
|
||||
copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
|
||||
copy<T, CacheT, VecSize>(key + key_src_id, key_cache + target_id);
|
||||
copy<T, CacheT, VecSize>(value + value_src_id, value_cache + target_id);
|
||||
}
|
||||
|
||||
// tail process
|
||||
@@ -69,22 +70,22 @@ __global__ void context_kv_cache_memcpy_kernel(
|
||||
+ head_id * block_size * head_dim
|
||||
+ block_offset * head_dim + head_offset;
|
||||
|
||||
key_cache[target_id] = key[key_src_id];
|
||||
value_cache[target_id] = value[value_src_id];
|
||||
key_cache[target_id] = CastFunctor<T, CacheT>()(key[key_src_id]);
|
||||
value_cache[target_id] = CastFunctor<T, CacheT>()(value[value_src_id]);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
template<typename T, typename CacheT>
|
||||
void apply_context_kv_cache_memcpy(
|
||||
at::Tensor& key, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& value, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& sequence_lengths, // [batch_size]
|
||||
at::Tensor& cu_seqlens, // [batch_size + 1]
|
||||
at::Tensor& block_tables, // [batch_size, max_seq_len]
|
||||
torch::Tensor& key, // [num_tokens, head_num, head_dim]
|
||||
torch::Tensor& value, // [num_tokens, head_num, head_dim]
|
||||
torch::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
torch::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
torch::Tensor& sequence_lengths, // [batch_size]
|
||||
torch::Tensor& cu_seqlens, // [batch_size + 1]
|
||||
torch::Tensor& block_tables, // [batch_size, max_seq_len]
|
||||
int max_seq_len_in_batch)
|
||||
{
|
||||
int num_tokens = key.size(0);
|
||||
@@ -97,7 +98,7 @@ void apply_context_kv_cache_memcpy(
|
||||
int64_t value_stride = value.stride(0);
|
||||
int block_table_stride = block_tables.stride(0);
|
||||
|
||||
int vec_size = get_vec_size<scalar_t>(key);
|
||||
int vec_size = get_vec_size<T>(key);
|
||||
|
||||
bool aligned = true;
|
||||
if (head_dim % vec_size != 0) {
|
||||
@@ -112,11 +113,11 @@ void apply_context_kv_cache_memcpy(
|
||||
|
||||
#define CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, __vec_size) \
|
||||
do { \
|
||||
context_kv_cache_memcpy_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
|
||||
key.data_ptr<scalar_t>(), \
|
||||
value.data_ptr<scalar_t>(), \
|
||||
key_cache.data_ptr<scalar_t>(), \
|
||||
value_cache.data_ptr<scalar_t>(), \
|
||||
context_kv_cache_memcpy_kernel<T, CacheT, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<T*>(key.data_ptr()), \
|
||||
reinterpret_cast<T*>(value.data_ptr()), \
|
||||
reinterpret_cast<CacheT*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CacheT*>(value_cache.data_ptr()), \
|
||||
sequence_lengths.data_ptr<int>(), \
|
||||
cu_seqlens.data_ptr<int>(), \
|
||||
block_tables.data_ptr<int>(), \
|
||||
@@ -161,26 +162,63 @@ void apply_context_kv_cache_memcpy(
|
||||
}
|
||||
|
||||
void context_kv_cache_memcpy(
|
||||
at::Tensor& key, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& value, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& sequence_lengths, // [batch_size]
|
||||
at::Tensor& cu_seqlens, // [batch_size + 1]
|
||||
at::Tensor& block_tables, // [batch_size, max_seq_len]
|
||||
torch::Tensor& key, // [num_tokens, head_num, head_dim]
|
||||
torch::Tensor& value, // [num_tokens, head_num, head_dim]
|
||||
torch::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
torch::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
torch::Tensor& sequence_lengths, // [batch_size]
|
||||
torch::Tensor& cu_seqlens, // [batch_size + 1]
|
||||
torch::Tensor& block_tables, // [batch_size, max_seq_len]
|
||||
int max_seq_len_in_batch)
|
||||
{
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
key.scalar_type(),
|
||||
"context_kv_cache_memcpy",
|
||||
apply_context_kv_cache_memcpy<scalar_t>(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
sequence_lengths,
|
||||
cu_seqlens,
|
||||
block_tables,
|
||||
max_seq_len_in_batch
|
||||
);)
|
||||
|
||||
TORCH_CHECK(key.scalar_type() == at::ScalarType::Float || key.scalar_type() == at::ScalarType::Half || key.scalar_type() == at::ScalarType::BFloat16,
|
||||
"Dtype of key should be float, half or bfloat16!");
|
||||
TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == key.scalar_type(),
|
||||
"Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");
|
||||
|
||||
|
||||
#define _(T, CacheT) \
|
||||
apply_context_kv_cache_memcpy<T, CacheT>( \
|
||||
key, \
|
||||
value, \
|
||||
key_cache, \
|
||||
value_cache, \
|
||||
sequence_lengths, \
|
||||
cu_seqlens, \
|
||||
block_tables, \
|
||||
max_seq_len_in_batch \
|
||||
)
|
||||
|
||||
if(key_cache.scalar_type() == at::ScalarType::Byte)
|
||||
{
|
||||
switch (key.scalar_type())
|
||||
{
|
||||
case at::ScalarType::Float:
|
||||
_(float, uint8_t);
|
||||
break;
|
||||
case at::ScalarType::Half:
|
||||
_(half, uint8_t);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
_(__nv_bfloat16, uint8_t);
|
||||
break;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
switch (key.scalar_type())
|
||||
{
|
||||
case at::ScalarType::Float:
|
||||
_(float, float);
|
||||
break;
|
||||
case at::ScalarType::Half:
|
||||
_(half, half);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
_(__nv_bfloat16, __nv_bfloat16);
|
||||
break;
|
||||
}
|
||||
}
|
||||
#undef _
|
||||
}
|
||||
|
@@ -372,7 +372,7 @@ void flash_decoding_attention(
|
||||
|
||||
TORCH_CHECK(query.scalar_type() == at::ScalarType::Float || query.scalar_type() == at::ScalarType::Half || query.scalar_type() == at::ScalarType::BFloat16,
|
||||
"Dtype of query should be float, half or bfloat16!");
|
||||
TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == key_cache.scalar_type(),
|
||||
TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == query.scalar_type(),
|
||||
"Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");
|
||||
|
||||
if(key_cache.scalar_type() == at::ScalarType::Byte)
|
||||
|
@@ -11,10 +11,9 @@ namespace colossalAI {
|
||||
namespace cuda {
|
||||
namespace utils {
|
||||
|
||||
template <typename T, int VecSize>
|
||||
template <typename T, int vec_size>
|
||||
__device__ __inline__ void copy_vector(T *dst, const T *src) {
|
||||
using VT = typename common::VecTypeTrait<T, VecSize>::Type;
|
||||
// Note(LiuYang): Here static_cast can't be used for cast between two pointer
|
||||
using VT = typename common::VecTypeTrait<T, vec_size>::Type;
|
||||
*(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
|
||||
}
|
||||
|
||||
@@ -33,9 +32,33 @@ __device__ __inline__ void copy_zero_vector(T *dst) {
|
||||
*(reinterpret_cast<VT *>(dst)) = funcs::CastFunctor<float, VT>()(0.0f);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, int vec_size>
|
||||
__device__ __inline__ void copy(const SrcT *src, DstT *dst) {
|
||||
using SrcVT = typename common::VecTypeTrait<SrcT, vec_size>::Type;
|
||||
using DstVT = typename common::VecTypeTrait<DstT, vec_size>::Type;
|
||||
// Note(LiuYang): Here static_cast can't be used for cast between two pointer
|
||||
*(reinterpret_cast<DstVT *>(dst)) = funcs::CastFunctor<SrcVT, DstVT>()(
|
||||
*(reinterpret_cast<const SrcVT *>(src)));
|
||||
}
|
||||
|
||||
template <typename T, int vec_size>
|
||||
__device__ __inline__ void copy<T, T, vec_size>(const T *src, T *dst) {
|
||||
using VT = typename common::VecTypeTrait<T, vec_size>::Type;
|
||||
*(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy<float, float, 8>(const float *src, float *dst) {
|
||||
// Since the maximum memory alignment length is 128 bits, we choose float4
|
||||
// here.
|
||||
*(reinterpret_cast<float4 *>(dst)) = *(reinterpret_cast<const float4 *>(src));
|
||||
*(reinterpret_cast<float4 *>(dst + 4)) =
|
||||
*(reinterpret_cast<const float4 *>(src + 4));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int get_vec_size(const torch::Tensor &tensor) {
|
||||
uint64_t address = reinterpret_cast<uint64_t>(tensor.data_ptr<T>());
|
||||
uint64_t address = reinterpret_cast<uint64_t>(tensor.data_ptr());
|
||||
const int max_aligned_size = 128;
|
||||
const int dtype_size = sizeof(T) * 8;
|
||||
|
||||
|
Reference in New Issue
Block a user