[Inference/Feat] Add quant kvcache interface (#5700)

* add quant kvcache interface

* delete unused output

* complete args comments
Author: 傅剑寒
Date: 2024-05-09 18:03:24 +08:00
Committed by: GitHub
Parent: 492520dbdb
Commit: bfad39357b
2 changed files with 16 additions and 2 deletions

@@ -53,6 +53,12 @@ class KVCacheManager:
         self.tp_size = config.tp_size
         # Model settings
         self.dtype = config.dtype
+        if config.kv_cache_dtype is None:
+            self.kv_cache_dtype = config.dtype
+        else:
+            self.kv_cache_dtype = config.kv_cache_dtype
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
         self.num_layers = model_config.num_hidden_layers
         self.head_num = model_config.num_attention_heads
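
The branch added above makes the cache dtype configurable: an explicit kv_cache_dtype (e.g. an fp8 type for a quantized cache) takes precedence, otherwise the cache falls back to the model dtype. A minimal standalone sketch of that selection logic, using a stub config rather than the real inference config (field names follow the diff; the fp8 value is only an example):

import torch

class _StubConfig:
    # Stand-in for the inference config referenced in the diff; only the two
    # fields the new branch reads are modeled here.
    def __init__(self, dtype=torch.float16, kv_cache_dtype=None):
        self.dtype = dtype                    # model compute dtype
        self.kv_cache_dtype = kv_cache_dtype  # None -> reuse the model dtype

def resolve_kv_cache_dtype(config) -> torch.dtype:
    # Mirrors the branch added in __init__: explicit cache dtype wins,
    # otherwise fall back to the model dtype.
    if config.kv_cache_dtype is None:
        return config.dtype
    return config.kv_cache_dtype

print(resolve_kv_cache_dtype(_StubConfig()))                                  # torch.float16
print(resolve_kv_cache_dtype(_StubConfig(kv_cache_dtype=torch.float8_e5m2)))  # torch.float8_e5m2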
@@ -488,6 +494,6 @@ class KVCacheManager:
         k_cache: List[torch.Tensor] = []
         v_cache: List[torch.Tensor] = []
         for _ in range(self.num_layers):
-            k_cache.append(torch.zeros(kalloc_shape, dtype=self.dtype, device=self.device))
-            v_cache.append(torch.zeros(valloc_shape, dtype=self.dtype, device=self.device))
+            k_cache.append(torch.zeros(kalloc_shape, dtype=self.kv_cache_dtype, device=self.device))
+            v_cache.append(torch.zeros(valloc_shape, dtype=self.kv_cache_dtype, device=self.device))
         return k_cache, v_cache
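
Allocating the per-layer tensors in self.kv_cache_dtype is what realizes the memory saving: an 8-bit cache element occupies half the space of an fp16 one. A rough standalone sketch of the allocation pattern and the resulting footprint (the block shape, layer count, and fp8 dtype are illustrative assumptions, not values computed by KVCacheManager):

from typing import List
import torch

num_layers = 2
alloc_shape = (8, 4, 64, 16)            # assumed (blocks, heads, head_dim, block_size)
kv_cache_dtype = torch.float8_e5m2      # example quantized cache dtype

k_cache: List[torch.Tensor] = []
v_cache: List[torch.Tensor] = []
for _ in range(num_layers):
    # Same pattern as the diff: zero-initialized cache blocks in the cache dtype.
    k_cache.append(torch.zeros(alloc_shape, dtype=kv_cache_dtype, device="cpu"))
    v_cache.append(torch.zeros(alloc_shape, dtype=kv_cache_dtype, device="cpu"))

per_layer_bytes = k_cache[0].numel() * k_cache[0].element_size()
print(f"K cache per layer: {per_layer_bytes} bytes")  # half the fp16 footprint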