add support for bloom (#5008)

Bin Jia authored on 2023-11-06 09:35:33 +08:00; committed by FoolPlayer
parent f747d13040
commit 48d0a58d10
8 changed files with 721 additions and 25 deletions


@@ -16,6 +16,7 @@ PP_AXIS, TP_AXIS = 0, 1
 _supported_models = [
     "LlamaForCausalLM",
+    "BloomForCausalLM",
 ]
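For context, here is a minimal guard sketch (not part of this commit) showing how an engine can check an incoming model against _supported_models before wiring it up. The transformers imports are real; the guard itself and its placement are illustrative assumptions:

# Hypothetical support check, assuming _supported_models from the hunk above is in scope.
from transformers import BloomConfig, BloomForCausalLM

model = BloomForCausalLM(BloomConfig())  # tiny random-weight Bloom, for illustration only
arch = model.__class__.__name__          # "BloomForCausalLM"
if arch not in _supported_models:
    raise NotImplementedError(f"{arch} is not supported by this engine")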
@@ -155,12 +156,20 @@ class CaiInferEngine:
     def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_output_len: int) -> None:
         max_total_token_num = max_batch_size * (max_input_len + max_output_len)
-        head_dim = model.config.hidden_size // model.config.num_attention_heads
-        head_num = model.config.num_attention_heads
-        num_hidden_layers = (
-            model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers
-        )
-        layer_num = num_hidden_layers // self.pp_size
+        if model.config.model_type == "llama":
+            head_dim = model.config.hidden_size // model.config.num_attention_heads
+            head_num = model.config.num_attention_heads // self.tp_size
+            num_hidden_layers = (
+                model.config.num_hidden_layers
+                if hasattr(model.config, "num_hidden_layers")
+                else model.config.num_layers
+            )
+            layer_num = num_hidden_layers // self.pp_size
+        elif model.config.model_type == "bloom":
+            head_dim = model.config.hidden_size // model.config.n_head
+            head_num = model.config.n_head // self.tp_size
+            num_hidden_layers = model.config.n_layer
+            layer_num = num_hidden_layers // self.pp_size
         cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
         return cache_manager
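To make the sizing concrete, the standalone sketch below (not part of the diff) replays the Bloom branch's arithmetic. The config values approximate bigscience/bloom-560m and the parallel sizes are assumed for illustration:

# Illustrative KV-cache sizing, mirroring the Bloom branch of _init_manager above.
hidden_size, n_head, n_layer = 1024, 16, 24   # roughly bigscience/bloom-560m
tp_size, pp_size = 2, 2                       # assumed tensor/pipeline parallel layout
max_batch_size, max_input_len, max_output_len = 8, 1024, 256

max_total_token_num = max_batch_size * (max_input_len + max_output_len)  # 8 * 1280 = 10240 token slots
head_dim = hidden_size // n_head              # 1024 // 16 = 64
head_num = n_head // tp_size                  # 16 // 2 = 8 heads per tensor-parallel rank
layer_num = n_layer // pp_size                # 24 // 2 = 12 layers per pipeline stage

print(max_total_token_num, head_dim, head_num, layer_num)  # 10240 64 8 12

These four values, plus the dtype, are exactly what the diff passes to MemoryManager to pre-allocate the per-rank KV cache.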