Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-01 09:07:51 +00:00
[Inference/Feat] Add convert_fp8 op for fp8 test in the future (#5706)
* add convert_fp8 op for fp8 test in the future
* rerun ci
extensions/csrc/kernel/cuda/convert_fp8_kernel.cu (new file, 127 lines)
@@ -0,0 +1,127 @@
#include <torch/extension.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/CUDAContext.h>

#include <cmath>

#include "common/micros.h"
#include "utils/vec_copy.h"
#include "funcs/cast_functor.h"

using colossalAI::cuda::utils::copy;
using colossalAI::cuda::utils::get_vec_size;
using colossalAI::funcs::CastFunctor;

template <typename InT, typename OutT, int VecSize>
__global__ void convert_fp8_kernel(const InT* ins_data, OutT* outs_data,
                                   int numel, int tail) {
  int64_t idx = static_cast<int64_t>(threadIdx.x) +
                static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
  const int64_t grid_size = blockDim.x * gridDim.x;
  if (idx > numel + tail) {
    return;
  }

  // Grid-stride loop over the vectorized body: each iteration casts VecSize
  // contiguous elements through a single vectorized load/store.
  for (int64_t i = idx; i < numel; i += grid_size) {
    copy<InT, OutT, VecSize>(ins_data + i * VecSize, outs_data + i * VecSize);
  }

  // Tail process: the remaining numel % VecSize elements are cast one by one.
  // Thread 0 of every surviving block writes the same values, which is
  // redundant but harmless.
  if (threadIdx.x == 0) {
    for (int i = 0; i < tail; ++i) {
      outs_data[i + numel * VecSize] =
          CastFunctor<InT, OutT>()(ins_data[i + numel * VecSize]);
    }
  }
}

template <typename InT, typename OutT>
void apply_convert_fp8(torch::Tensor& input, torch::Tensor& output) {
  const int kVecSize = get_vec_size<InT>(input);
  const int kNumel = torch::numel(input);

  // Split the tensor into a vectorized body and a scalar tail; kVecSize is a
  // power of two, so the shift and mask compute numel / VecSize and
  // numel % VecSize respectively.
  const int kVecNumel = (kNumel >> static_cast<int>(std::log2(kVecSize)));
  const int kTail = kNumel & (kVecSize - 1);
  int grid_size = kVecNumel ? (kVecNumel + 255) / 256 : 1;

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 grid(grid_size);
  dim3 block(256);

#define _(VEC_SIZE)                                         \
  convert_fp8_kernel<InT, OutT, VEC_SIZE>                   \
      <<<grid, block, 0, stream>>>(                         \
          reinterpret_cast<const InT*>(input.data_ptr()),   \
          reinterpret_cast<OutT*>(output.data_ptr()),       \
          kVecNumel, kTail)

  switch (kVecSize) {
    case 1:
      _(1);
      break;
    case 2:
      _(2);
      break;
    case 4:
      _(4);
      break;
  }
#undef _
  AT_CUDA_CHECK(cudaGetLastError());
}
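To make the shift/mask index math above concrete, here is a small sanity check that simply mirrors it in Python (a sketch; the numbers are illustrative, not from the source):

# Mirrors the body/tail split in apply_convert_fp8 for numel = 1030, vec_size = 4.
numel, vec_size = 1030, 4
vec_numel = numel >> (vec_size.bit_length() - 1)  # kVecNumel: numel >> log2(vec_size) = 257
tail = numel & (vec_size - 1)                     # kTail: 2 leftover elements, cast one by one
grid_size = (vec_numel + 255) // 256              # 2 blocks of 256 threads
assert vec_numel * vec_size + tail == numel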
void convert_fp8(torch::Tensor& input, torch::Tensor& output) {
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Byte ||
                  output.scalar_type() == at::ScalarType::Byte,
              "Data type of input or output should be torch.uint8 for convert_fp8!");
  TORCH_CHECK(input.scalar_type() != output.scalar_type(),
              "Data type of input and output should not be the same!");
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Byte ||
                  input.scalar_type() == at::ScalarType::Float ||
                  input.scalar_type() == at::ScalarType::Half ||
                  input.scalar_type() == at::ScalarType::BFloat16,
              "Unsupported dtype of input!");
  TORCH_CHECK(output.scalar_type() == at::ScalarType::Byte ||
                  output.scalar_type() == at::ScalarType::Float ||
                  output.scalar_type() == at::ScalarType::Half ||
                  output.scalar_type() == at::ScalarType::BFloat16,
              "Unsupported dtype of output!");
  TORCH_CHECK(input.sizes() == output.sizes(),
              "Shape of input and output should be the same!");

#define _(InT, OutT) apply_convert_fp8<InT, OutT>(input, output)

  // A uint8 input holds fp8 values to be up-cast to the output dtype;
  // otherwise the input is down-cast to fp8 and stored as uint8.
  if (input.scalar_type() == at::ScalarType::Byte) {
    if (output.scalar_type() == at::ScalarType::Float) {
      _(uint8_t, float);
    } else if (output.scalar_type() == at::ScalarType::Half) {
      _(uint8_t, half);
    } else if (output.scalar_type() == at::ScalarType::BFloat16) {
      _(uint8_t, __nv_bfloat16);
    }
  } else {
    if (input.scalar_type() == at::ScalarType::Float) {
      _(float, uint8_t);
    } else if (input.scalar_type() == at::ScalarType::Half) {
      _(half, uint8_t);
    } else if (input.scalar_type() == at::ScalarType::BFloat16) {
      _(__nv_bfloat16, uint8_t);
    }
  }

#undef _
}
utils/vec_copy.h

@@ -1,9 +1,6 @@
 #pragma once

 #include <cuda_fp16.h>
 #include <stdint.h>

 #include "common/vec_type_traits.h"
 #include "funcs/cast_functor.h"

@@ -12,9 +9,9 @@ namespace cuda {
 namespace utils {

 // Note(LiuYang): Deprecated
-template <typename T, int vec_size>
+template <typename T, int VecSize>
 __device__ __inline__ void copy_vector(T *dst, const T *src) {
-  using VT = typename common::VecTypeTrait<T, vec_size>::Type;
+  using VT = typename common::VecTypeTrait<T, VecSize>::Type;
   *(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
 }

@@ -34,17 +31,17 @@ __device__ __inline__ void copy_zero_vector(T *dst) {
   *(reinterpret_cast<VT *>(dst)) = funcs::CastFunctor<float, VT>()(0.0f);
 }

-template <typename SrcT, typename DstT, int vec_size>
+template <typename SrcT, typename DstT, int VecSize>
 __device__ __inline__ void copy(const SrcT *src, DstT *dst) {
-  using SrcVT = typename common::VecTypeTrait<SrcT, vec_size>::Type;
-  using DstVT = typename common::VecTypeTrait<DstT, vec_size>::Type;
+  using SrcVT = typename common::VecTypeTrait<SrcT, VecSize>::Type;
+  using DstVT = typename common::VecTypeTrait<DstT, VecSize>::Type;
   *(reinterpret_cast<DstVT *>(dst)) = funcs::CastFunctor<SrcVT, DstVT>()(
       *(reinterpret_cast<const SrcVT *>(src)));
 }

-template <typename T, int vec_size>
+template <typename T, int VecSize>
 __device__ __inline__ void copy(const T *src, T *dst) {
-  using VT = typename common::VecTypeTrait<T, vec_size>::Type;
+  using VT = typename common::VecTypeTrait<T, VecSize>::Type;
   *(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
 }
extensions/pybind/inference/inference.cpp

@@ -75,6 +75,8 @@ void flash_decoding_attention(
     torch::Tensor& tmp_out_lse,  // [num_tokens, num_heads, max_num_partitions]
     const c10::optional<torch::Tensor>& alibi_slopes, float scale);

+void convert_fp8(torch::Tensor& input, torch::Tensor& output);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
         "Copy the GPU memory of kvcache during the decode stage.");

@@ -102,4 +104,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("flash_decoding_attention", &flash_decoding_attention,
         "Compute the attention between an input query and the cached "
         "keys/values using PagedAttention.");
+
+  m.def("convert_fp8", &convert_fp8,
+        "Convert input to fp8 output or convert fp8 input to output.");
 }
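With the binding in place, calling the op from Python might look like the sketch below. This is a hedged example: `inference_ops` stands for the compiled extension module (see the loader sketch after the last hunk), and the tolerance is illustrative since this diff does not specify which fp8 format CastFunctor uses.

import torch

# Round-trip fp16 -> fp8 (stored as uint8 bytes) -> fp16. Per the checks in
# convert_fp8, exactly one side of each call is torch.uint8 and shapes match.
x = torch.randn(16, 64, dtype=torch.half, device="cuda")
x_fp8 = torch.empty_like(x, dtype=torch.uint8)  # fp8 payload
y = torch.empty_like(x)

inference_ops.convert_fp8(x, x_fp8)   # down-cast to fp8
inference_ops.convert_fp8(x_fp8, y)   # up-cast back to fp16
assert torch.allclose(x, y, atol=0.5, rtol=0.25)  # fp8 is lossy; bounds illustrative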
@@ -17,6 +17,7 @@ class InferenceOpsCudaExtension(_CudaExtension):
                 "kernel/cuda/rms_layernorm_kernel.cu",
                 "kernel/cuda/get_cos_and_sin_kernel.cu",
                 "kernel/cuda/flash_decoding_attention_kernel.cu",
+                "kernel/cuda/convert_fp8_kernel.cu",
             ]
         ] + [self.pybind_abs_path("inference/inference.cpp")]
         return ret
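Finally, a hedged sketch of how the newly listed source might be built and loaded. The loader name and import path below are assumptions (they do not appear in this diff); the loader is expected to JIT-compile the sources returned by InferenceOpsCudaExtension on first use:

# Assumed API: ColossalAI's kernel loader for the inference extension.
from colossalai.kernel.kernel_loader import InferenceOpsLoader

inference_ops = InferenceOpsLoader().load()  # builds the listed .cu files plus inference.cpp
inference_ops.convert_fp8  # the op bound in inference.cpp is now available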