[Inference/Feat] Add kvcache quant support for fused_rotary_embedding_cache_copy (#5680)

傅剑寒
2024-04-30 18:33:53 +08:00
committed by GitHub
parent 5cd75ce4c7
commit ef8e4ffe31
7 changed files with 226 additions and 125 deletions


@@ -380,12 +380,6 @@ void flash_decoding_attention(
const c10::optional<torch::Tensor>& alibi_slopes,
float scale) {
TORCH_CHECK(query.scalar_type() == at::ScalarType::Float || query.scalar_type() == at::ScalarType::Half || query.scalar_type() == at::ScalarType::BFloat16,
"Dtype of query should be float, half or bfloat16!");
TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == query.scalar_type(),
"Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");
  if (key_cache.scalar_type() == at::ScalarType::Byte) {
switch (query.scalar_type()) {
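
For context, here is a minimal, self-contained sketch of the dispatch pattern this hunk begins: when the kv-cache is stored as Byte (fp8), the query dtype still selects the compute type, while the cache type is pinned to uint8_t. This is not code from the commit; `dispatch_flash_decoding` and `launch_flash_decoding` are hypothetical placeholder names for illustration.

```cpp
#include <torch/extension.h>

// Hypothetical templated launcher: a real implementation would configure
// grid/block dimensions and launch the CUDA kernel specialized for
// (QueryT query/output dtype, CacheT kv-cache storage dtype).
template <typename QueryT, typename CacheT>
void launch_flash_decoding(const torch::Tensor& query,
                           const torch::Tensor& key_cache) {
  // Kernel launch elided in this sketch.
}

void dispatch_flash_decoding(const torch::Tensor& query,
                             const torch::Tensor& key_cache) {
  if (key_cache.scalar_type() == at::ScalarType::Byte) {
    // fp8 kv-cache: elements are stored as raw bytes and dequantized
    // inside the kernel, so the cache template type is fixed to uint8_t.
    switch (query.scalar_type()) {
      case at::ScalarType::Float:
        launch_flash_decoding<float, uint8_t>(query, key_cache);
        break;
      case at::ScalarType::Half:
        launch_flash_decoding<at::Half, uint8_t>(query, key_cache);
        break;
      case at::ScalarType::BFloat16:
        launch_flash_decoding<at::BFloat16, uint8_t>(query, key_cache);
        break;
      default:
        TORCH_CHECK(false, "Unsupported query dtype for fp8 kv-cache!");
    }
  } else {
    // Unquantized path: query and kv-cache share the same dtype,
    // as enforced by the TORCH_CHECK above.
  }
}
```

This two-level dispatch (runtime dtype check to compile-time template parameters) is the standard way PyTorch extensions bridge dynamic tensor dtypes to statically typed CUDA kernels.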