Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 03:52:01 +00:00)
[Inference]Adapt to baichuan2 13B (#5614)
* adapt to baichuan2 13B
* adapt to baichuan2 13B
* change BAICHUAN_MODEL_NAME_OR_PATH
* fix test_decoding_attn.py
* Modifications based on review comments.
* change BAICHUAN_MODEL_NAME_OR_PATH
* mv attn mask processes to test flash decoding
* mv get_alibi_slopes to baichuan modeling
* fix bugs in test_baichuan.py
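One of the items above moves get_alibi_slopes into the Baichuan modeling code: Baichuan2 13B uses ALiBi attention bias rather than rotary position embeddings, so inference needs a per-head slope tensor. Below is a minimal sketch of the standard ALiBi slope recipe (a geometric sequence with ratio 2^(-8/n), plus an interleaved second sequence when the head count is not a power of two). Only the function name is taken from the commit message; the signature and exact placement here are assumptions, not the actual ColossalAI implementation.

import math

import torch


def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
    # Standard ALiBi recipe: for a power-of-two head count the slopes are a
    # geometric sequence with ratio 2^(-8 / num_heads).
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = torch.tensor(
        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device
    )
    slopes = torch.pow(base, torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device))
    if closest_power_of_2 != num_heads:
        # Non-power-of-two head counts interleave a second sequence built from
        # the next power of two.
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
        )
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


# Baichuan2 13B has 40 attention heads, so this yields a (40,) slope tensor.
slopes = get_alibi_slopes(40, device=torch.device("cpu"))

For 40 heads the non-power-of-two branch is exercised: 32 slopes come from the base sequence and the remaining 8 from the interleaved one.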
@@ -64,8 +64,15 @@ class KVCacheManager:
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
         self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
         self.head_num = get_model_config_attr(model_config, "num_attention_heads")
-        self.kv_head_num = get_model_config_attr(model_config, "num_key_value_heads")
         self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num
+
+        if hasattr(config, "num_key_value_heads"):
+            self.kv_head_num = getattr(config, "num_key_value_heads")
+        elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map["num_key_value_heads"]):
+            self.kv_head_num = getattr(config, config.attribute_map["num_key_value_heads"])
+        else:
+            self.kv_head_num = self.head_num
+
         assert (
             self.kv_head_num % self.tp_size == 0
         ), f"Cannot shard {self.kv_head_num} heads with tp size {self.tp_size}"