Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 19:40:28 +00:00
[Inference] Adapt Baichuan2-13B TP (#5659)
* adapt to baichuan2 13B
* add baichuan2 13B TP
* update baichuan tp logic
* rm unused code
* Fix TP logic
* fix alibi slopes tp logic
* rm nn.Module
* Polished the code.
* change BAICHUAN_MODEL_NAME_OR_PATH
* Modified the logic for loading Baichuan weights.
* fix typos
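Context for the "fix alibi slopes tp logic" item: Baichuan2-13B uses ALiBi position bias instead of rotary embeddings, so every attention head owns one fixed slope. Once the heads are sharded across tensor-parallel ranks, each rank must keep exactly the slopes of its local heads. Below is a minimal sketch of that partitioning, assuming the standard ALiBi slope formula and a contiguous per-rank head split; the helper names are illustrative, not the PR's actual code.

import math

import torch


def get_alibi_slopes(num_heads: int) -> torch.Tensor:
    # Standard ALiBi slopes: 2^(-8(i+1)/n) for a power-of-two head count n,
    # with the usual interleaving fallback for other counts
    # (Baichuan2-13B has 40 heads, which is not a power of two).
    closest_pow2 = 2 ** math.floor(math.log2(num_heads))
    base = 2.0 ** (-8.0 / closest_pow2)
    slopes = [base ** (i + 1) for i in range(closest_pow2)]
    if closest_pow2 != num_heads:
        extra_base = 2.0 ** (-4.0 / closest_pow2)
        num_extra = min(closest_pow2, num_heads - closest_pow2)
        slopes += [extra_base ** (2 * i + 1) for i in range(num_extra)]
    return torch.tensor(slopes)


def local_alibi_slopes(num_heads: int, tp_size: int, tp_rank: int) -> torch.Tensor:
    # Assumption: each TP rank holds a contiguous block of heads, matching a
    # column-parallel split of the attention projections, so it keeps the
    # matching contiguous block of slopes.
    heads_per_rank = num_heads // tp_size
    start = tp_rank * heads_per_rank
    return get_alibi_slopes(num_heads)[start : start + heads_per_rank]


# Baichuan2-13B: 40 heads over 2 ranks -> 20 slopes per rank.
print(local_alibi_slopes(40, 2, 0).shape)  # torch.Size([20])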
@@ -1,6 +1,7 @@
 import torch.nn as nn
 from torch.nn import Parameter
 
+from colossalai.inference.modeling.layers.baichuan_tp_linear import (
+    BaichuanLMHeadLinear1D_Col,
+    BaichuanWpackLinear1D_Col,
+)
 from colossalai.inference.modeling.models.nopadding_baichuan import (
     NopadBaichuanAttention,
     NopadBaichuanMLP,
@@ -12,6 +13,7 @@ from colossalai.inference.modeling.models.nopadding_llama import (
     llama_model_forward,
 )
 from colossalai.inference.utils import init_to_get_rotary
+from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
 
@@ -23,39 +25,72 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy):
     def module_policy(self):
         policy = super().module_policy()
 
-        decoder_attribute_replacement = {
-            "lm_head.weight": Parameter(nn.functional.normalize(self.model.lm_head.weight), requires_grad=False),
-        }
-        policy["BaichuanForCausalLM"] = ModulePolicyDescription(
-            attribute_replacement=decoder_attribute_replacement,
-        )
+        if self.shard_config.enable_tensor_parallelism:
+            decoder_attribute_replacement = {
+                "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+                "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+            }
+            if getattr(self.model.config, "num_key_value_heads", False):
+                decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
+                    self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
+                )
+        else:
+            decoder_attribute_replacement = None
 
-        # used for relpacing Baichuan 7B/13B decoder layer
-        for layer_name in ["DecoderLayer", "BaichuanLayer"]:
-            policy[layer_name] = ModulePolicyDescription(
+        # used for Baichuan 7B and 13B for baichuan DecoderLayer
+        for DecoderLayer in ["DecoderLayer", "BaichuanLayer"]:
+            policy[DecoderLayer] = ModulePolicyDescription(
+                attribute_replacement=decoder_attribute_replacement,
                 sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="mlp.gate_proj",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.up_proj",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.down_proj",
+                        target_module=Linear1D_Row,
+                    ),
                     SubModuleReplacementDescription(
                         suffix="mlp",
                         target_module=NopadBaichuanMLP,
                     ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.W_pack",
+                        target_module=BaichuanWpackLinear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.o_proj",
+                        target_module=Linear1D_Row,
+                    ),
                     SubModuleReplacementDescription(
                         suffix="self_attn",
                         target_module=NopadBaichuanAttention,
                     ),
-                ]
+                ],
             )
 
             self.append_or_create_method_replacement(
-                description={"forward": llama_decoder_layer_forward}, policy=policy, target_key=layer_name
+                description={"forward": llama_decoder_layer_forward}, policy=policy, target_key=DecoderLayer
            )
 
+        policy["BaichuanForCausalLM"] = ModulePolicyDescription(
+            sub_module_replacement=[
+                SubModuleReplacementDescription(
+                    suffix="lm_head", target_module=BaichuanLMHeadLinear1D_Col, kwargs={"gather_output": True}
+                )
+            ],
+        )
+
         self.append_or_create_method_replacement(
             description={"forward": llama_causal_lm_forward}, policy=policy, target_key="BaichuanForCausalLM"
         )
         self.append_or_create_method_replacement(
             description={"forward": llama_model_forward}, policy=policy, target_key="BaichuanModel"
         )
 
         self.append_or_create_method_replacement(
             description={"forward": baichuan_rmsnorm_forward}, policy=policy, target_key="RMSNorm"
         )
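A note on the new attribute_replacement block under enable_tensor_parallelism: it simply divides the attention geometry by the TP world size so that each rank's NopadBaichuanAttention sees only its local shard. With the published Baichuan2-13B configuration (hidden_size 5120, 40 attention heads, assumed here) and two ranks, the rewritten attributes work out as follows.

# Per-rank attention geometry under tensor parallelism, mirroring the
# decoder_attribute_replacement dict added in this commit.
# Assumed Baichuan2-13B config values: hidden_size=5120, 40 heads.
hidden_size, num_attention_heads, tp_size = 5120, 40, 2

decoder_attribute_replacement = {
    "self_attn.hidden_size": hidden_size // tp_size,        # 2560
    "self_attn.num_heads": num_attention_heads // tp_size,  # 20
}
# Baichuan2 uses plain multi-head attention with no separate
# num_key_value_heads field, so the getattr(...) branch is a no-op here.
print(decoder_attribute_replacement)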
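Why W_pack gets its own column-parallel class instead of a plain Linear1D_Col: Baichuan fuses the Q, K, and V projections into a single W_pack weight, so a naive split of the fused output dimension would hand rank 0 nothing but Q rows. The shard has to be taken per projection and re-packed. Below is a sketch of that idea under those assumptions; the real BaichuanWpackLinear1D_Col in baichuan_tp_linear may implement it differently, and shard_w_pack is a hypothetical helper.

import torch


def shard_w_pack(w_pack: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    # w_pack has shape [3 * hidden, hidden] (nn.Linear stores weights as
    # [out_features, in_features]). Split Q, K, and V separately along the
    # output dimension, take this rank's slice of each, then re-concatenate,
    # so every rank ends up with a full set of local Q/K/V heads.
    hidden = w_pack.shape[0] // 3
    q, k, v = w_pack.split(hidden, dim=0)
    shard = hidden // tp_size
    lo, hi = tp_rank * shard, (tp_rank + 1) * shard
    return torch.cat([q[lo:hi], k[lo:hi], v[lo:hi]], dim=0)


# Toy check: hidden=8, tp_size=2 -> each rank holds a [12, 8] shard.
w = torch.randn(24, 8)
print(shard_w_pack(w, 2, 0).shape)  # torch.Size([12, 8])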
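Finally, the removed lm_head.weight attribute replacement shows what BaichuanLMHeadLinear1D_Col (registered with gather_output=True so each rank still produces full-vocabulary logits) now has to reproduce: Baichuan2's NormHead, i.e. L2-normalizing each vocabulary row of lm_head.weight before the weight is sharded. A hedged sketch of that pre-shard step; normalize_lm_head is illustrative, not the module's actual code.

import torch
import torch.nn as nn


def normalize_lm_head(lm_head: nn.Linear) -> None:
    # Baichuan2 NormHead: L2-normalize each vocabulary row of the lm_head
    # weight in place, matching the nn.functional.normalize call that this
    # commit removes from the policy's attribute_replacement.
    with torch.no_grad():
        lm_head.weight.data = nn.functional.normalize(lm_head.weight)


lm_head = nn.Linear(128, 1000, bias=False)
normalize_lm_head(lm_head)
# Every vocabulary row now has unit L2 norm.
print(lm_head.weight.norm(dim=-1).allclose(torch.ones(1000)))  # True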