mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-19 04:02:17 +00:00)
Optimize the tail-processing code and the macro-based kernel-dispatch logic. (#5519)
This commit is contained in:
parent e6496dd371
commit 934e31afb2
@@ -2,7 +2,7 @@ ROOT=$(realpath $(dirname $0))
 echo $ROOT
 PY_SCRIPT=${ROOT}/benchmark_llama.py
 GPU=$(nvidia-smi -L | head -1 | cut -d' ' -f4 | cut -d'-' -f1)
-mode="colossalai"
+mode=$1
 
 mkdir -p logs
 
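With mode=$1, the benchmark mode now comes from the script's first command-line argument instead of being hard-coded; passing colossalai as the first argument presumably reproduces the old default.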
@@ -56,21 +56,14 @@
     AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
   }
 
 #define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION,  \
                                                            TYPE, NAME, ...) \
-  switch (HIGH_PRECISION) {                                                 \
-    case false: {                                                           \
-      const bool high_precision = false;                                    \
-      DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__);              \
-      break;                                                                \
-    }                                                                       \
-    case true: {                                                            \
-      const bool high_precision = true;                                     \
-      DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__);              \
-      break;                                                                \
-    }                                                                       \
-    default:                                                                \
-      AT_ERROR("HIGH_PRECISION must be bool, but get ", HIGH_PRECISION, "."); \
+  if (HIGH_PRECISION) {                                                     \
+    const bool high_precision = true;                                       \
+    DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__);                \
+  } else {                                                                  \
+    const bool high_precision = false;                                      \
+    DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__);                \
   }
 
 #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
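Why the if/else can replace the switch: a bool has only two values, so the old default: branch raising AT_ERROR was unreachable, and a const bool initialized from a literal is still a constant expression, so the dispatched body can keep using high_precision as a template argument. A minimal, self-contained sketch with hypothetical names, not the repository's exact macro:

    #include <cstdio>

    template <bool HighPrecision>
    void kernel_body() {
      // Stands in for the DISPATCH_FLOAT_HALF_AND_BFLOAT(...) body.
      std::printf("instantiated with high_precision=%d\n", HighPrecision);
    }

    // Hypothetical stand-in for the WITH_HIGH_PRECISION dispatch macro.
    #define DISPATCH_WITH_HIGH_PRECISION(FLAG, ...) \
      if (FLAG) {                                   \
        const bool high_precision = true;           \
        __VA_ARGS__;                                \
      } else {                                      \
        const bool high_precision = false;          \
        __VA_ARGS__;                                \
      }

    int main() {
      bool flag = true;  // runtime value
      DISPATCH_WITH_HIGH_PRECISION(flag, kernel_body<high_precision>());
      return 0;
    }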
@@ -27,17 +27,11 @@ struct MPTypeTrait<at::BFloat16> {
   using Type = float;
 };
 
-template <bool high_precision, typename scalar_t>
-struct ScalarTypeTrait;
-
-template <typename T>
-struct ScalarTypeTrait<true, T> {
-  using Type = typename MPTypeTrait<T>::Type;
-};
-
-template <typename T>
-struct ScalarTypeTrait<false, T> {
-  using Type = T;
+template <bool high_precision, typename T>
+struct ScalarTypeTrait {
+  using Type =
+      typename std::conditional<high_precision, typename MPTypeTrait<T>::Type,
+                                T>::type;
 };
 
 }  // namespace common
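The two partial specializations collapse into one std::conditional. A compilable sketch with a simplified stand-in for MPTypeTrait (the repository maps at::Half and at::BFloat16 to float), showing the single definition reproduces the old behavior:

    #include <type_traits>

    // Simplified stand-in: maps a storage type to its higher-precision math type.
    template <typename T> struct MPTypeTrait { using Type = float; };

    template <bool high_precision, typename T>
    struct ScalarTypeTrait {
      using Type =
          typename std::conditional<high_precision, typename MPTypeTrait<T>::Type,
                                    T>::type;
    };

    // The collapsed trait behaves exactly like the two old specializations.
    static_assert(std::is_same<ScalarTypeTrait<true, short>::Type, float>::value,
                  "high_precision=true selects the math type");
    static_assert(std::is_same<ScalarTypeTrait<false, short>::Type, short>::value,
                  "high_precision=false keeps the storage type");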
@@ -4,7 +4,7 @@
 #include "utils/vector_copy_utils.h"
 #include "../common/micros.h"
 
-template<typename scalar_t, int VecSize>
+template<typename scalar_t, bool Aligned, int VecSize>
 __global__ void context_kv_cache_memcpy_kernel(
     const scalar_t* __restrict__ key,
     const scalar_t* __restrict__ value,
@@ -55,17 +55,19 @@ __global__ void context_kv_cache_memcpy_kernel(
     }
 
     // tail process
-    for (; i < hidden_size; ++i ) {
-        head_id = i / head_dim;
-        head_offset = i % head_dim;
-        key_src_id = total_token_id * key_stride + i;
-        value_src_id = total_token_id * value_stride + i;
-        target_id = block_id * hidden_size * block_size
-                    + head_id * block_size * head_dim
-                    + block_offset * head_dim + head_offset;
-
-        key_cache[target_id] = key[key_src_id];
-        value_cache[target_id] = value[value_src_id];
+    if (!Aligned) {
+        for (; i < hidden_size; ++i ) {
+            head_id = i / head_dim;
+            head_offset = i % head_dim;
+            key_src_id = total_token_id * key_stride + i;
+            value_src_id = total_token_id * value_stride + i;
+            target_id = block_id * hidden_size * block_size
+                        + head_id * block_size * head_dim
+                        + block_offset * head_dim + head_offset;
+
+            key_cache[target_id] = key[key_src_id];
+            value_cache[target_id] = value[value_src_id];
+        }
     }
 
 }
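Since Aligned is a compile-time template parameter, guarding the tail loop with if (!Aligned) lets the compiler drop the tail entirely from the Aligned = true instantiation. A hypothetical CPU-side analogue of the same structure:

    // Hypothetical CPU analogue of the kernel's copy structure. Callers must
    // only instantiate Aligned = true when n is divisible by VecSize.
    template <bool Aligned, int VecSize>
    void copy_row(const float* src, float* dst, int n) {
      int i = 0;
      // Vectorized main loop: VecSize elements per step.
      for (; i + VecSize <= n; i += VecSize)
        for (int j = 0; j < VecSize; ++j) dst[i + j] = src[i + j];
      // Tail: statically dead when Aligned is true, so the compiler
      // eliminates it from that instantiation.
      if (!Aligned) {
        for (; i < n; ++i) dst[i] = src[i];
      }
    }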
@@ -93,76 +95,61 @@ void apply_context_kv_cache_memcpy(
 
     int vec_size = get_vec_size<scalar_t>(key);
 
+    bool aligned = true;
     if (head_dim % vec_size != 0) {
-        // Disable vectorized loading optimization when head_dim is not divisible by VecSize.
-        vec_size = 1;
+        aligned = false;
     }
 
     int thread_nums = head_num * head_dim / vec_size;
 
     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
     dim3 grid(max_seq_len_in_batch, batch_size);
     dim3 block(std::min(thread_nums, 512));
 
-    switch (vec_size) {
-        case 1:
-            context_kv_cache_memcpy_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
-                key.data_ptr<scalar_t>(),
-                value.data_ptr<scalar_t>(),
-                key_cache.data_ptr<scalar_t>(),
-                value_cache.data_ptr<scalar_t>(),
-                sequence_lengths.data_ptr<int>(),
-                cu_seqlens.data_ptr<int>(),
-                block_tables.data_ptr<int>(),
-                head_num,
-                head_dim,
-                block_size,
-                batch_size,
-                block_table_stride,
-                key_stride,
-                value_stride
-            );
-            break;
-        case 2:
-            context_kv_cache_memcpy_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
-                key.data_ptr<scalar_t>(),
-                value.data_ptr<scalar_t>(),
-                key_cache.data_ptr<scalar_t>(),
-                value_cache.data_ptr<scalar_t>(),
-                sequence_lengths.data_ptr<int>(),
-                cu_seqlens.data_ptr<int>(),
-                block_tables.data_ptr<int>(),
-                head_num,
-                head_dim,
-                block_size,
-                batch_size,
-                block_table_stride,
-                key_stride,
-                value_stride
-            );
-            break;
-        case 4:
-            context_kv_cache_memcpy_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
-                key.data_ptr<scalar_t>(),
-                value.data_ptr<scalar_t>(),
-                key_cache.data_ptr<scalar_t>(),
-                value_cache.data_ptr<scalar_t>(),
-                sequence_lengths.data_ptr<int>(),
-                cu_seqlens.data_ptr<int>(),
-                block_tables.data_ptr<int>(),
-                head_num,
-                head_dim,
-                block_size,
-                batch_size,
-                block_table_stride,
-                key_stride,
-                value_stride
-            );
-            break;
-        default:
-            AT_ERROR("Unsupported vectorized size ", vec_size);
-            break;
-    }
+#define CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, __vec_size)                                \
+    do {                                                                                              \
+        context_kv_cache_memcpy_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
+            key.data_ptr<scalar_t>(),                                                                 \
+            value.data_ptr<scalar_t>(),                                                               \
+            key_cache.data_ptr<scalar_t>(),                                                           \
+            value_cache.data_ptr<scalar_t>(),                                                         \
+            sequence_lengths.data_ptr<int>(),                                                         \
+            cu_seqlens.data_ptr<int>(),                                                               \
+            block_tables.data_ptr<int>(),                                                             \
+            head_num,                                                                                 \
+            head_dim,                                                                                 \
+            block_size,                                                                               \
+            batch_size,                                                                               \
+            block_table_stride,                                                                       \
+            key_stride,                                                                               \
+            value_stride                                                                              \
+        );                                                                                            \
+    } while(0)
+
+#define CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(__aligned)     \
+    do {                                                                    \
+        switch (vec_size) {                                                 \
+            case 1:                                                         \
+                CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 1);       \
+                break;                                                      \
+            case 2:                                                         \
+                CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 2);       \
+                break;                                                      \
+            case 4:                                                         \
+                CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 4);       \
+                break;                                                      \
+            default:                                                        \
+                AT_ERROR("Unsupported vectorized size ", vec_size);         \
+                break;                                                      \
+        }                                                                   \
+    } while(0)
+
+    if (aligned) {
+        CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(true);
+    }
+    else {
+        CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(false);
+    }
 
     AT_CUDA_CHECK(cudaGetLastError());
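Both launch macros wrap their bodies in do { ... } while(0), the usual idiom for making a multi-statement macro expand to a single statement, so it composes with if/else and swallows the caller's trailing semicolon. A minimal illustration with hypothetical names:

    void log_launch(int x);
    void launch(int x);

    // Unsafe: expands to two statements, so only the first is guarded by an
    // unbraced if, and a following else becomes a syntax error.
    #define LAUNCH_UNSAFE(x) log_launch(x); launch(x)

    // Safe: the do/while makes the body one statement and consumes the
    // semicolon at the call site.
    #define LAUNCH_SAFE(x) \
      do {                 \
        log_launch(x);     \
        launch(x);         \
      } while (0)

    void dispatch(bool aligned, int x) {
      if (aligned)
        LAUNCH_SAFE(x);  // one statement; the else still binds correctly
      else
        LAUNCH_SAFE(x + 1);
    }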
@@ -4,7 +4,7 @@
 #include "utils/vector_copy_utils.h"
 #include "../common/micros.h"
 
-template<typename scalar_t, int VecSize>
+template<typename scalar_t, bool Aligned, int VecSize>
 __global__ void decode_kv_cache_memcpy_kernel(
     const scalar_t* __restrict__ key,
     const scalar_t* __restrict__ value,
@@ -45,17 +45,19 @@ __global__ void decode_kv_cache_memcpy_kernel(
         copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
     }
 
-    for (; i < hidden_size; ++i ) {
-        const int head_id = i / head_dim;
-        const int head_offset = i % head_dim;
-        const int64_t key_src_id = seq_id * key_stride + i;
-        const int64_t value_src_id = seq_id * value_stride + i;
-        const int64_t target_id = block_id * hidden_size * block_size
-                                  + head_id * block_size * head_dim
-                                  + block_offset * head_dim + head_offset;
-
-        key_cache[target_id] = key[key_src_id];
-        value_cache[target_id] = value[value_src_id];
+    if (!Aligned) {
+        for (; i < hidden_size; ++i ) {
+            const int head_id = i / head_dim;
+            const int head_offset = i % head_dim;
+            const int64_t key_src_id = seq_id * key_stride + i;
+            const int64_t value_src_id = seq_id * value_stride + i;
+            const int64_t target_id = block_id * hidden_size * block_size
+                                      + head_id * block_size * head_dim
+                                      + block_offset * head_dim + head_offset;
+
+            key_cache[target_id] = key[key_src_id];
+            value_cache[target_id] = value[value_src_id];
+        }
     }
 
 }
@@ -80,70 +82,58 @@ void apply_decode_kv_cache_memcpy(
 
     int vec_size = get_vec_size<scalar_t>(key);
 
+    bool aligned = true;
     if (head_dim % vec_size != 0) {
-        // Disable vectorized loading optimization when head_dim is not divisible by VecSize.
-        vec_size = 1;
+        aligned = false;
     }
 
     int thread_nums = head_num * head_dim / vec_size;
 
     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
     dim3 grid(num_tokens);
     dim3 block(std::min(thread_nums, 512));
 
-    switch (vec_size) {
-        case 1:
-            decode_kv_cache_memcpy_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
-                key.data_ptr<scalar_t>(),
-                value.data_ptr<scalar_t>(),
-                key_cache.data_ptr<scalar_t>(),
-                value_cache.data_ptr<scalar_t>(),
-                sequence_lengths.data_ptr<int>(),
-                block_tables.data_ptr<int>(),
-                head_num,
-                head_dim,
-                block_size,
-                key_stride,
-                value_stride,
-                block_table_stride
-            );
-            break;
-        case 2:
-            decode_kv_cache_memcpy_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
-                key.data_ptr<scalar_t>(),
-                value.data_ptr<scalar_t>(),
-                key_cache.data_ptr<scalar_t>(),
-                value_cache.data_ptr<scalar_t>(),
-                sequence_lengths.data_ptr<int>(),
-                block_tables.data_ptr<int>(),
-                head_num,
-                head_dim,
-                block_size,
-                key_stride,
-                value_stride,
-                block_table_stride
-            );
-            break;
-        case 4:
-            decode_kv_cache_memcpy_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
-                key.data_ptr<scalar_t>(),
-                value.data_ptr<scalar_t>(),
-                key_cache.data_ptr<scalar_t>(),
-                value_cache.data_ptr<scalar_t>(),
-                sequence_lengths.data_ptr<int>(),
-                block_tables.data_ptr<int>(),
-                head_num,
-                head_dim,
-                block_size,
-                key_stride,
-                value_stride,
-                block_table_stride
-            );
-            break;
-        default:
-            AT_ERROR("Unsupported vectorized size ", vec_size);
-            break;
-    }
+#define DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, __vec_size)                                \
+    do {                                                                                             \
+        decode_kv_cache_memcpy_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
+            key.data_ptr<scalar_t>(),                                                                \
+            value.data_ptr<scalar_t>(),                                                              \
+            key_cache.data_ptr<scalar_t>(),                                                          \
+            value_cache.data_ptr<scalar_t>(),                                                        \
+            sequence_lengths.data_ptr<int>(),                                                        \
+            block_tables.data_ptr<int>(),                                                            \
+            head_num,                                                                                \
+            head_dim,                                                                                \
+            block_size,                                                                              \
+            key_stride,                                                                              \
+            value_stride,                                                                            \
+            block_table_stride                                                                       \
+        );                                                                                           \
+    } while(0)
+
+#define DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(__aligned, __vec_size)  \
+    do {                                                                            \
+        switch (__vec_size) {                                                       \
+            case 1:                                                                 \
+                DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 1);                \
+                break;                                                              \
+            case 2:                                                                 \
+                DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 2);                \
+                break;                                                              \
+            case 4:                                                                 \
+                DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 4);                \
+                break;                                                              \
+            default:                                                                \
+                AT_ERROR("Unsupported vectorized size ", __vec_size);               \
+                break;                                                              \
+        }                                                                           \
+    } while(0)
+
+    if (aligned) {
+        DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(true, vec_size);
+    }
+    else {
+        DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(false, vec_size);
+    }
 
     AT_CUDA_CHECK(cudaGetLastError());
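Note the behavioral refinement shared by both launchers: the old code disabled vectorized loading outright (vec_size = 1) whenever head_dim was not divisible by the vector width, while the new code keeps the wide vector width and routes only the head_dim % vec_size remainder through the scalar tail loop of the Aligned = false instantiation. A hypothetical helper capturing that decision:

    struct LaunchConfig {
      bool aligned;  // selects the kernel instantiation
      int vec_size;  // kept as-is; no longer reset to 1 on misalignment
    };

    // Hypothetical helper; the real launchers inline this logic.
    LaunchConfig pick_launch_config(int head_dim, int vec_size) {
      return {head_dim % vec_size == 0, vec_size};
    }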