mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[Inference/Feat] Add kvcache quant support for fused_rotary_embedding_cache_copy (#5680)
This commit is contained in:
@@ -56,6 +56,11 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMin, HOSTDEVICE,
|
||||
typename T)
|
||||
|
||||
#if defined(COLOSSAL_WITH_CUDA)
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kMinus,
|
||||
DEVICE, STMTS_WRAPPER({
|
||||
return __hsub(lhs, rhs);
|
||||
}))
|
||||
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kAdd,
|
||||
DEVICE, STMTS_WRAPPER({
|
||||
return __hadd(lhs, rhs);
|
||||
@@ -71,6 +76,13 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16,
|
||||
DEVICE, STMTS_WRAPPER({
|
||||
return __hadd(lhs, rhs);
|
||||
}))
|
||||
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16,
|
||||
__nv_bfloat16, BinaryOpType::kMinus,
|
||||
DEVICE, STMTS_WRAPPER({
|
||||
return __hsub(lhs, rhs);
|
||||
}))
|
||||
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, __nv_bfloat162,
|
||||
__nv_bfloat162, BinaryOpType::kAdd,
|
||||
DEVICE, STMTS_WRAPPER({
|
||||
@@ -82,6 +94,13 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
||||
STMTS_WRAPPER({
|
||||
return __float2bfloat16(__bfloat162float(lhs) + __bfloat162float(rhs));
|
||||
}))
|
||||
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
||||
__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, BinaryOpType::kMinus, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
return __float2bfloat16(__bfloat162float(lhs) - __bfloat162float(rhs));
|
||||
}))
|
||||
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
||||
__nv_bfloat162, __nv_bfloat162, __nv_bfloat162, BinaryOpType::kAdd, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
|
@@ -94,6 +94,10 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat16, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
return __float2bfloat16_rn(val);
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, float, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
return __bfloat162float(val);
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::bfloat164, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::bfloat164 dst;
|
||||
|
Reference in New Issue
Block a user