Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 19:13:01 +00:00)
[Inference] Replace Attention layer and MLP layer with shardformer to optimize the weight transpose operation, add fused_qkv and fused linear_add (#5340)
* add fused qkv
* replace attn and mlp by shardformer
* fix bugs in mlp
* add docstrings
* fix test_inference_engine.py
* add optimize unbind
* add fused_addmm
* rm squeeze(1)
* refactor codes
* fix ci bugs
* rename ShardFormerLlamaMLP and ShardFormerLlamaAttention
* Removed the dependency on LlamaFlashAttention2
* rollback test_inference_engine.py
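The "add fused qkv" and "add optimize unbind" items refer to packing the q/k/v projection weights into a single matrix so each attention call issues one large matmul and then splits the result without extra copies. The snippet below is a minimal sketch of that idea in plain PyTorch with toy shapes; it is not the NopadLlamaAttention implementation, and all names and sizes in it are illustrative.

import torch

# Toy size for illustration only; a real Llama layer uses a much larger hidden_size.
hidden_size = 64
wq = torch.randn(hidden_size, hidden_size)
wk = torch.randn(hidden_size, hidden_size)
wv = torch.randn(hidden_size, hidden_size)

# Pack the three projection weights once at load time, already transposed to
# the [in, out] layout that torch.mm expects.
fused_qkv_weight = torch.stack([wq.t(), wk.t(), wv.t()], dim=1)  # [in, 3, out]

hidden = torch.randn(16, hidden_size)  # [num_tokens, hidden_size]

# One matmul produces q, k and v together; unbind then yields three views, no copy.
qkv = torch.mm(hidden, fused_qkv_weight.view(hidden_size, -1))  # [num_tokens, 3 * hidden_size]
q, k, v = qkv.view(-1, 3, hidden_size).unbind(dim=1)

# Matches the three separate projections up to float32 rounding.
assert torch.allclose(q, hidden @ wq.t(), atol=1e-4, rtol=1e-4)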
@@ -1,25 +1,18 @@
 from functools import partial
 
 import torch
-from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
-    LlamaDecoderLayer,
-    LlamaFlashAttention2,
-    LlamaForCausalLM,
-    LlamaMLP,
-    LlamaModel,
-    LlamaRMSNorm,
-    LlamaSdpaAttention,
-)
+from torch.nn import Parameter
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
 
 from colossalai.inference.modeling.models.nopadding_llama import (
-    llama_attn_forward,
+    NopadLlamaAttention,
+    NopadLlamaMLP,
     llama_causal_lm_forward,
     llama_decoder_layer_forward,
     llama_model_forward,
-    nopad_mlp,
 )
 from colossalai.inference.utils import init_to_get_rotary
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
 
 # import colossalai
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
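The import changes above drop the per-class forward patches (llama_attn_forward, nopad_mlp) in favour of whole-module substitutes (NopadLlamaAttention, NopadLlamaMLP) that shardformer swaps in by attribute suffix. The snippet below is a simplified, hypothetical illustration of such a suffix-based swap in plain PyTorch; it is not shardformer's actual mechanism, and replace_by_suffix and WrappedMLP are invented names.

import torch.nn as nn

class WrappedMLP(nn.Module):
    """Hypothetical drop-in module standing in for NopadLlamaMLP in this sketch."""

    def __init__(self, original: nn.Module):
        super().__init__()
        self.original = original  # a real replacement would repack the weights here

    def forward(self, x):
        return self.original(x)

def replace_by_suffix(model: nn.Module, suffix: str, target_cls) -> None:
    # Walk the module tree and swap every child whose attribute name matches
    # the suffix, which is the effect a SubModuleReplacementDescription entry
    # asks the policy engine to produce for "mlp" and "self_attn".
    for name, child in model.named_children():
        if name == suffix:
            setattr(model, name, target_cls(child))
        else:
            replace_by_suffix(child, suffix, target_cls)

Calling replace_by_suffix(decoder_layer, "mlp", WrappedMLP) would leave the rest of the layer untouched while routing the MLP forward through the replacement, which is the behaviour the policy below relies on.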
@@ -50,6 +43,27 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
     def module_policy(self):
         policy = super().module_policy()
 
+        decoder_attribute_replacement = {
+            "lm_head.weight": Parameter(self.model.lm_head.weight.transpose(0, 1), requires_grad=False),
+        }
+        policy[LlamaForCausalLM] = ModulePolicyDescription(
+            attribute_replacement=decoder_attribute_replacement,
+        )
+
+        policy[LlamaDecoderLayer] = ModulePolicyDescription(
+            sub_module_replacement=[
+                SubModuleReplacementDescription(
+                    suffix="mlp",
+                    target_module=NopadLlamaMLP,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="self_attn",
+                    target_module=NopadLlamaAttention,
+                ),
+            ]
+        )
+
         self.shard_config._infer()
 
         infer_forward = llama_causal_lm_forward
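The attribute_replacement above stores lm_head.weight already transposed, so the inference forward can multiply hidden states against it directly instead of transposing the [vocab_size, hidden_size] nn.Linear weight on every decoding step; this is the "weight transpose" optimization named in the title. A minimal sketch of the equivalence, with toy shapes and illustrative names (not the actual forward code):

import torch
import torch.nn.functional as F
from torch.nn import Parameter

# Toy sizes; a real Llama lm_head weight is [vocab_size, hidden_size].
hidden_size, vocab_size = 64, 1000
lm_head_weight = torch.randn(vocab_size, hidden_size)  # nn.Linear layout: [out, in]
hidden_states = torch.randn(8, hidden_size)            # [num_tokens, hidden_size]

# Standard path: F.linear transposes the weight on every call.
logits_ref = F.linear(hidden_states, lm_head_weight)

# Policy-style path: transpose once at load time, then use a plain mm
# in the forward pass with no per-step transpose.
pre_transposed = Parameter(lm_head_weight.transpose(0, 1), requires_grad=False)  # [in, out]
logits = torch.mm(hidden_states, pre_transposed)

assert torch.allclose(logits, logits_ref, atol=1e-4, rtol=1e-4)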
@@ -68,28 +82,6 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
             description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
         )
 
-        infer_forward = nopad_mlp
-        method_replacement = {"forward": partial(infer_forward)}
-        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaMLP)
-
-        infer_forward = llama_attn_forward
-        method_replacement = {"forward": partial(infer_forward)}
-        self.append_or_create_method_replacement(
-            description=method_replacement, policy=policy, target_key=LlamaAttention
-        )
-
-        infer_forward = llama_attn_forward
-        method_replacement = {"forward": partial(infer_forward)}
-        self.append_or_create_method_replacement(
-            description=method_replacement, policy=policy, target_key=LlamaFlashAttention2
-        )
-
-        infer_forward = llama_attn_forward
-        method_replacement = {"forward": partial(infer_forward)}
-        self.append_or_create_method_replacement(
-            description=method_replacement, policy=policy, target_key=LlamaSdpaAttention
-        )
-
         infer_forward = None
         if HAS_TRITON_RMSNORM:
             infer_forward = get_triton_rmsnorm_forward()
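The forward patches removed above are superseded by the replacement modules, which, per the commit message, also use "fused_addmm" / "fused linear_add", i.e. folding an elementwise addition (such as a residual) into the projection matmul. A small sketch of that kind of fusion with torch.addmm, assuming the weight is already stored in [in, out] layout as in the policy above; shapes and names are illustrative:

import torch

# Toy shapes: [num_tokens, hidden_size] activations and a pre-transposed weight.
num_tokens, hidden_size = 8, 64
residual = torch.randn(num_tokens, hidden_size)
hidden = torch.randn(num_tokens, hidden_size)
w_o = torch.randn(hidden_size, hidden_size)  # projection weight in [in, out] layout

# Unfused: a matmul followed by a separate elementwise add.
out_ref = residual + torch.mm(hidden, w_o)

# Fused: torch.addmm folds the addition into the same call, the kind of
# linear + add fusion the commit message appears to refer to.
out = torch.addmm(residual, hidden, w_o)

assert torch.allclose(out, out_ref, atol=1e-4, rtol=1e-4)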