mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 13:00:52 +00:00
[Inference]refactor baichuan (#5791)
* refactor baichuan * remove unused code and add TODO for lazyinit
This commit is contained in:
@@ -1,8 +1,5 @@
|
||||
from colossalai.inference.config import RPC_PARAM
|
||||
from colossalai.inference.modeling.layers.baichuan_tp_linear import (
|
||||
BaichuanLMHeadLinear1D_Col,
|
||||
BaichuanWpackLinear1D_Col,
|
||||
)
|
||||
from colossalai.inference.modeling.layers.baichuan_tp_linear import BaichuanLMHeadLinear1D_Col
|
||||
from colossalai.inference.modeling.models.nopadding_baichuan import (
|
||||
NopadBaichuanAttention,
|
||||
NopadBaichuanMLP,
|
||||
@@ -14,7 +11,7 @@ from colossalai.inference.modeling.models.nopadding_llama import (
|
||||
llama_model_forward,
|
||||
)
|
||||
from colossalai.inference.utils import init_to_get_rotary
|
||||
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
|
||||
from colossalai.shardformer.layer import FusedLinear1D_Col, Linear1D_Col, Linear1D_Row
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
|
||||
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
|
||||
|
||||
@@ -60,8 +57,7 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
|
||||
target_module=NopadBaichuanMLP,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.W_pack",
|
||||
target_module=BaichuanWpackLinear1D_Col,
|
||||
suffix="self_attn.W_pack", target_module=FusedLinear1D_Col, kwargs={"n_fused": 3}
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
|
Reference in New Issue
Block a user