[Inference/Feat] Feat quant kvcache step2 (#5674)

Author: 傅剑寒
Date: 2024-04-30 11:26:36 +08:00
Committed by: GitHub
Parent: 8ccb6714e7
Commit: 808ee6e4ad
4 changed files with 208 additions and 71 deletions


@@ -372,7 +372,7 @@ void flash_decoding_attention(
   TORCH_CHECK(query.scalar_type() == at::ScalarType::Float || query.scalar_type() == at::ScalarType::Half || query.scalar_type() == at::ScalarType::BFloat16,
     "Dtype of query should be float, half or bfloat16!");
-  TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == key_cache.scalar_type(),
+  TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == query.scalar_type(),
     "Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");
   if(key_cache.scalar_type() == at::ScalarType::Byte)
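
The patch fixes a tautological comparison: the old check compared key_cache.scalar_type() with itself, so a KV cache whose dtype differed from the query would never be rejected. The corrected rule is that the KV cache must either match the query dtype or be stored as uint8 (at::ScalarType::Byte), which is how the fp8-quantized cache is represented. Below is a minimal standalone sketch of that rule; the helper name check_kv_cache_dtype, the tensor shapes, and the main() driver are illustrative assumptions, not part of this patch.

#include <torch/torch.h>

// Hypothetical helper mirroring the dtype rule enforced in flash_decoding_attention:
// the query must be fp32/fp16/bf16, and the KV cache must either be fp8 (stored as
// uint8 / at::ScalarType::Byte) or share the query's dtype.
void check_kv_cache_dtype(const at::Tensor& query, const at::Tensor& key_cache) {
  TORCH_CHECK(query.scalar_type() == at::ScalarType::Float ||
                  query.scalar_type() == at::ScalarType::Half ||
                  query.scalar_type() == at::ScalarType::BFloat16,
              "Dtype of query should be float, half or bfloat16!");
  TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte ||
                  key_cache.scalar_type() == query.scalar_type(),
              "Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");
}

int main() {
  // fp16 query with a uint8 (fp8-quantized) KV cache: passes the check.
  auto q  = torch::zeros({1, 8, 64}, torch::dtype(torch::kHalf));
  auto kc = torch::zeros({2, 8, 16, 64}, torch::dtype(torch::kByte));
  check_kv_cache_dtype(q, kc);

  // fp16 query with an fp32 KV cache: TORCH_CHECK throws a c10::Error.
  auto kc_bad = torch::zeros({2, 8, 16, 64}, torch::dtype(torch::kFloat));
  try {
    check_kv_cache_dtype(q, kc_bad);
  } catch (const c10::Error& e) {
    // Mismatched dtypes are rejected unless the cache is Byte (fp8).
  }
  return 0;
}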