[Inference] refactor baichuan (#5791)

* refactor baichuan

* remove unused code and add TODO for lazyinit
Runyu Lu
2024-06-11 10:52:01 +08:00
committed by GitHub
parent 77a219a082
commit c0948aff97
3 changed files with 24 additions and 110 deletions


@@ -1,8 +1,5 @@
 from colossalai.inference.config import RPC_PARAM
-from colossalai.inference.modeling.layers.baichuan_tp_linear import (
-    BaichuanLMHeadLinear1D_Col,
-    BaichuanWpackLinear1D_Col,
-)
+from colossalai.inference.modeling.layers.baichuan_tp_linear import BaichuanLMHeadLinear1D_Col
 from colossalai.inference.modeling.models.nopadding_baichuan import (
     NopadBaichuanAttention,
     NopadBaichuanMLP,
@@ -14,7 +11,7 @@ from colossalai.inference.modeling.models.nopadding_llama import (
     llama_model_forward,
 )
 from colossalai.inference.utils import init_to_get_rotary
-from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import FusedLinear1D_Col, Linear1D_Col, Linear1D_Row
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
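
For context on the import switch: Baichuan fuses its Q, K and V projections into a single W_pack linear layer, so a tensor-parallel column split has to treat the weight as three stacked sub-matrices rather than one, which is what the generic FusedLinear1D_Col expresses via n_fused=3. A minimal sketch of the idea in plain PyTorch follows; the hidden size and the two-rank split are illustrative assumptions, not ColossalAI code.

import torch

hidden, tp_size = 8, 2
# Baichuan's W_pack packs Q, K and V: weight shape (3 * hidden, hidden).
w_pack = torch.randn(3 * hidden, hidden)

# A naive column-parallel split would hand rank 0 all of Q plus part of K,
# mixing projections across ranks.
naive_shards = torch.chunk(w_pack, tp_size, dim=0)

# The fused split (what n_fused=3 conveys): split each of the three
# sub-matrices separately, then give every rank its own Q/K/V slices.
q, k, v = torch.chunk(w_pack, 3, dim=0)
fused_shards = [
    torch.cat([t.chunk(tp_size, dim=0)[rank] for t in (q, k, v)], dim=0)
    for rank in range(tp_size)
]

for rank, shard in enumerate(fused_shards):
    print(rank, shard.shape)  # each rank holds (3 * hidden // tp_size, hidden)
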
@@ -60,8 +57,7 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
                     target_module=NopadBaichuanMLP,
                 ),
                 SubModuleReplacementDescription(
-                    suffix="self_attn.W_pack",
-                    target_module=BaichuanWpackLinear1D_Col,
+                    suffix="self_attn.W_pack", target_module=FusedLinear1D_Col, kwargs={"n_fused": 3}
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.o_proj",