fix bugs in request_handler.py and engine.py

This commit is contained in:
yuehuayingxueluo
2024-01-10 10:38:53 +08:00
committed by FrankLeeeee
parent 10e3c9f923
commit d40eb26029
4 changed files with 21 additions and 11 deletions

View File

@@ -58,7 +58,12 @@ class KVCacheManager:
# NOTE(review): fragment of KVCacheManager.__init__ from a diff view — the
# enclosing `def` is not visible here and the original indentation was
# stripped by the diff extraction, so the code bytes are kept exactly as-is.
# Parallel settings
self.tp_size = config.tp_size
# Model settings
# NOTE(review): the next line appears to be the pre-fix code this commit
# removes (its value is immediately overwritten by the normalization below).
self.dtype = config.dtype
# Normalize config.dtype — which may be a string alias ("fp32"/"fp16") or an
# actual torch.dtype — to a concrete torch.dtype.
if config.dtype == "fp32" or config.dtype == torch.float32:
self.dtype = torch.float32
elif config.dtype == "fp16" or config.dtype == torch.float16:
self.dtype = torch.float16
else:
# Any unrecognized value silently falls through to bfloat16 — TODO confirm
# this is intended rather than raising on unsupported dtypes.
self.dtype = torch.bfloat16
# Bytes per element of the chosen dtype, obtained from an empty tensor;
# presumably used later for KV-cache memory-size accounting — verify.
self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
# For now we focus on MHA only, TODO add handling for MQA and GQA