[Inference/Feat] Add kvcache quant support for fused_rotary_embedding_cache_copy (#5680)

Author: 傅剑寒
Date: 2024-04-30 18:33:53 +08:00
Committed by: GitHub
Parent: 5cd75ce4c7
Commit: ef8e4ffe31
7 changed files with 226 additions and 125 deletions
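
For context: KV-cache quantization typically shifts each cached value by a running minimum and rescales it before storing, and the reverse path widens the stored value back to float. That pattern is what the new BinaryOpType::kMinus specializations and the __nv_bfloat16 -> float cast below support. The device helper that follows is only an illustrative sketch of that pattern; the function name, the asymmetric uint8 scheme, and the per-element interface are assumptions for illustration, not code from this commit.

#include <cuda_fp16.h>
#include <cstdint>

// Illustrative sketch only (not from this commit): asymmetric quantization of
// one half-precision KV-cache element to uint8, mirroring the
// subtract-then-widen pattern the new functor specializations enable.
__device__ __forceinline__ uint8_t quantize_kv_elem(half val, half v_min,
                                                    float inv_scale) {
  half shifted = __hsub(val, v_min);            // BinaryOpType::kMinus analogue
  float q = __half2float(shifted) * inv_scale;  // widen to float, then scale
  q = fminf(fmaxf(q, 0.0f), 255.0f);            // clamp to the uint8 range
  return static_cast<uint8_t>(__float2uint_rn(q));  // round to nearest
}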

Changed file (binary functor specializations):

@@ -56,6 +56,11 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMin, HOSTDEVICE, typename T)
#if defined(COLOSSAL_WITH_CUDA)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kMinus,
                                       DEVICE, STMTS_WRAPPER({
                                         return __hsub(lhs, rhs);
                                       }))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kAdd,
                                       DEVICE, STMTS_WRAPPER({
                                         return __hadd(lhs, rhs);
@@ -71,6 +76,13 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16,
                                       DEVICE, STMTS_WRAPPER({
                                         return __hadd(lhs, rhs);
                                       }))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16,
                                       __nv_bfloat16, BinaryOpType::kMinus,
                                       DEVICE, STMTS_WRAPPER({
                                         return __hsub(lhs, rhs);
                                       }))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, __nv_bfloat162,
                                       __nv_bfloat162, BinaryOpType::kAdd,
                                       DEVICE, STMTS_WRAPPER({
@@ -82,6 +94,13 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    STMTS_WRAPPER({
      return __float2bfloat16(__bfloat162float(lhs) + __bfloat162float(rhs));
    }))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, BinaryOpType::kMinus, DEVICE,
    STMTS_WRAPPER({
      return __float2bfloat16(__bfloat162float(lhs) - __bfloat162float(rhs));
    }))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    __nv_bfloat162, __nv_bfloat162, __nv_bfloat162, BinaryOpType::kAdd, DEVICE,
    STMTS_WRAPPER({
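
For reference, COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION presumably expands to a functor-template specialization whose operator() body is the STMTS_WRAPPER argument. The sketch below is a simplified, hypothetical expansion plus a call site; the BinaryOpFunctor name, the enum contents, and the omitted namespaces and qualifiers are assumptions based on the macro naming, not verbatim from this header.

#include <cuda_fp16.h>

// Hypothetical, simplified expansion of one specialization; the real header's
// names, namespaces, and qualifiers may differ.
enum class BinaryOpType { kAdd, kMinus, kMin };

template <typename LT, typename RT, typename RET, BinaryOpType Op>
struct BinaryOpFunctor;

template <>
struct BinaryOpFunctor<half, half, half, BinaryOpType::kMinus> {
  __device__ half operator()(half lhs, half rhs) const {
    return __hsub(lhs, rhs);  // body supplied via the STMTS_WRAPPER argument
  }
};

// Usage inside a kernel: the functor keeps per-dtype intrinsics out of the
// kernel body, e.g. when computing (val - min) during KV-cache quantization.
__device__ __forceinline__ half half_sub(half a, half b) {
  return BinaryOpFunctor<half, half, half, BinaryOpType::kMinus>()(a, b);
}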

Changed file (cast functor specializations):

@@ -94,6 +94,10 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat16, DEVICE,
                                     STMTS_WRAPPER({
                                       return __float2bfloat16_rn(val);
                                     }))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, float, DEVICE,
                                     STMTS_WRAPPER({
                                       return __bfloat162float(val);
                                     }))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::bfloat164, DEVICE,
                                     STMTS_WRAPPER({
                                       dtype::bfloat164 dst;
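
Likewise, COLOSSAL_CAST_FUNCTOR_SPECIALIZATION presumably defines a CastFunctor<From, To> specialization, and the new __nv_bfloat16 -> float entry gives kernels a uniform way to widen cached bf16 values (for example when dequantizing) without calling intrinsics directly. The sketch below is a simplified, hypothetical expansion plus a call site; the names and qualifiers are assumptions, not verbatim from this header.

#include <cuda_bf16.h>

// Hypothetical, simplified expansion of the new bf16 -> float cast; the real
// header's names, namespaces, and qualifiers may differ.
template <typename From, typename To>
struct CastFunctor;

template <>
struct CastFunctor<__nv_bfloat16, float> {
  __device__ float operator()(__nv_bfloat16 val) const {
    return __bfloat162float(val);  // body supplied via the STMTS_WRAPPER argument
  }
};

// Usage: widen a cached bf16 element before accumulating in float.
__device__ __forceinline__ float widen_bf16(__nv_bfloat16 v) {
  return CastFunctor<__nv_bfloat16, float>()(v);
}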