Mirror of https://github.com/hpcaitech/ColossalAI.git
add support for bloom (#5008)
@@ -16,6 +16,7 @@ PP_AXIS, TP_AXIS = 0, 1
 
 _supported_models = [
     "LlamaForCausalLM",
+    "BloomForCausalLM",
 ]
 
 
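The hunk above only registers the new architecture name in the whitelist. For reference, a minimal sketch of how an engine might gate on such a whitelist is shown below; the check_model_supported helper and its error message are illustrative assumptions, not code from this commit.

# Illustrative sketch only: a whitelist check an inference engine might run
# before building its cache manager. The helper name is hypothetical.
_supported_models = [
    "LlamaForCausalLM",
    "BloomForCausalLM",
]

def check_model_supported(model) -> None:
    # Hugging Face model classes expose the architecture via the class name,
    # e.g. transformers.BloomForCausalLM -> "BloomForCausalLM".
    name = model.__class__.__name__
    if name not in _supported_models:
        raise NotImplementedError(f"Model {name} is not supported by the engine")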
@@ -155,12 +156,20 @@ class CaiInferEngine:
 
     def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_output_len: int) -> None:
         max_total_token_num = max_batch_size * (max_input_len + max_output_len)
-        head_dim = model.config.hidden_size // model.config.num_attention_heads
-        head_num = model.config.num_attention_heads
-        num_hidden_layers = (
-            model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers
-        )
-        layer_num = num_hidden_layers // self.pp_size
+        if model.config.model_type == "llama":
+            head_dim = model.config.hidden_size // model.config.num_attention_heads
+            head_num = model.config.num_attention_heads // self.tp_size
+            num_hidden_layers = (
+                model.config.num_hidden_layers
+                if hasattr(model.config, "num_hidden_layers")
+                else model.config.num_layers
+            )
+            layer_num = num_hidden_layers // self.pp_size
+        elif model.config.model_type == "bloom":
+            head_dim = model.config.hidden_size // model.config.n_head
+            head_num = model.config.n_head // self.tp_size
+            num_hidden_layers = model.config.n_layer
+            layer_num = num_hidden_layers // self.pp_size
 
         cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
         return cache_manager
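The new branches size the KV cache from architecture-specific config fields: Llama configs expose num_attention_heads and num_hidden_layers, while Bloom configs expose n_head and n_layer. Below is a minimal standalone sketch of the same per-rank arithmetic for a Bloom config; the parallel degrees and batch limits are illustrative values, not taken from this commit.

# Standalone sketch of the per-rank cache sizing shown in the hunk above.
# tp_size / pp_size and the batch limits are illustrative assumptions.
from transformers import BloomConfig

config = BloomConfig(hidden_size=1024, n_head=16, n_layer=24)  # bloom-560m-like shape
tp_size, pp_size = 2, 2
max_batch_size, max_input_len, max_output_len = 4, 128, 64

max_total_token_num = max_batch_size * (max_input_len + max_output_len)  # total cache slots
head_dim = config.hidden_size // config.n_head    # 1024 // 16 = 64
head_num = config.n_head // tp_size               # heads held by one tensor-parallel rank
layer_num = config.n_layer // pp_size             # layers held by one pipeline stage
# MemoryManager(max_total_token_num, dtype, head_num, head_dim, layer_num) then
# allocates key/value buffers of that shape per layer on each rank.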