diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py
index 50546271e..302f379f9 100644
--- a/colossalai/inference/kv_cache/kvcache_manager.py
+++ b/colossalai/inference/kv_cache/kvcache_manager.py
@@ -15,14 +15,6 @@ __all__ = ["KVCacheManager"]
 GIGABYTE = 1024**3
 
 
-def get_model_config_attr(config: PretrainedConfig, attr_name: str):
-    if hasattr(config, attr_name):
-        return getattr(config, attr_name)
-    elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map[attr_name]):
-        return getattr(config, config.attribute_map[attr_name])
-    raise AttributeError(f"{attr_name} is not found in config")
-
-
 class KVCacheManager:
     """KVCacheManager manages both the logical cache blocks and physical KV cache (tensors).
 
@@ -53,7 +45,7 @@ class KVCacheManager:
         And it's possible to have a batch of sequences with different lengths of block tables.
     """
 
-    def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None:
+    def __init__(self, config: InferenceConfig, model_config: PretrainedConfig) -> None:
         self.logger = get_dist_logger(__name__)
         self.device = get_current_device()
 
@@ -62,14 +54,11 @@ class KVCacheManager:
         # Model settings
         self.dtype = config.dtype
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
-        self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
-        self.head_num = get_model_config_attr(model_config, "num_attention_heads")
-        self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num
-
-        if hasattr(config, "num_key_value_heads"):
-            self.kv_head_num = getattr(config, "num_key_value_heads")
-        elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map["num_key_value_heads"]):
-            self.kv_head_num = getattr(config, config.attribute_map["num_key_value_heads"])
+        self.num_layers = model_config.num_hidden_layers
+        self.head_num = model_config.num_attention_heads
+        self.head_size = model_config.hidden_size // self.head_num
+        if hasattr(model_config, "num_key_value_heads"):
+            self.kv_head_num = model_config.num_key_value_heads
         else:
             self.kv_head_num = self.head_num
 
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 6e541f792..713175c6c 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -141,9 +141,11 @@ class LlamaPolicy(Policy):
             assert (
                 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
            ), f"The number of attention heads must be divisible by tensor parallel size."
-            assert (
-                self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
-            ), f"The number of key_value heads must be divisible by tensor parallel size."
+            if hasattr(self.model.config, "num_key_value_heads"):
+                assert (
+                    self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size
+                    and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+                ), f"The number of key_value heads must be divisible by, and must not be less than, tensor parallel size."
             decoder_attribute_replacement = {
                 "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,