diff --git a/colossalai/inference/engine/modeling/chatglm2.py b/colossalai/inference/engine/modeling/chatglm2.py
index 56e777bb2..2eac235d5 100644
--- a/colossalai/inference/engine/modeling/chatglm2.py
+++ b/colossalai/inference/engine/modeling/chatglm2.py
@@ -408,12 +408,20 @@ class ChatGLM2InferenceForwards:
         query_layer = query_layer.reshape(
             -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
         )
-        key_layer = key_layer.reshape(
-            -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
-        )
-        value_layer = value_layer.reshape(
-            -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
-        )
+        if self.multi_query_attention:
+            key_layer = key_layer.reshape(
+                -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
+            )
+            value_layer = value_layer.reshape(
+                -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
+            )
+        else:
+            key_layer = key_layer.reshape(
+                -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
+            )
+            value_layer = value_layer.reshape(
+                -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
+            )
 
         if infer_state.is_context_stage:
             # first token generation:
diff --git a/colossalai/inference/engine/policies/chatglm2.py b/colossalai/inference/engine/policies/chatglm2.py
index 3e1d94f47..d040f93c4 100644
--- a/colossalai/inference/engine/policies/chatglm2.py
+++ b/colossalai/inference/engine/policies/chatglm2.py
@@ -1,5 +1,7 @@
+from functools import partial
 from typing import List
 
+import torch
 import torch.nn as nn
 
 from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
@@ -7,23 +9,40 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
     ChatGLMModel,
     GLMBlock,
     GLMTransformer,
+    RMSNorm,
     SelfAttention,
 )
 
 # import colossalai
-from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy
+from colossalai.shardformer.policies.chatglm2 import ChatGLMForConditionalGenerationPolicy
 
 from ..modeling._utils import init_to_get_rotary
 from ..modeling.chatglm2 import ChatGLM2InferenceForwards
 
 try:
+    from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
+
     HAS_TRITON_RMSNORM = True
 except:
-    print("you should install triton from https://github.com/openai/triton")
+    print("Did not find rms-norm triton kernel")
+    print(
+        "You can use the following command to install: pip install git+https://github.com/ModelTC/lightllm.git@ece7b43f8a6dfa74027adc77c2c176cff28c76c8"
+    )
     HAS_TRITON_RMSNORM = False
 
 
-class ChatGLM2InferPolicy(ChatGLMModelPolicy):
+def get_triton_rmsnorm_forward():
+    if HAS_TRITON_RMSNORM:
+
+        def _triton_rmsnorm_forward(self: RMSNorm, hidden_states: torch.Tensor):
+            return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.eps)
+
+        return _triton_rmsnorm_forward
+    else:
+        raise RuntimeError("Did not find rms-norm triton kernel")
+
+
+class ChatGLM2InferPolicy(ChatGLMForConditionalGenerationPolicy):
     def __init__(self) -> None:
         super().__init__()
 
@@ -56,6 +75,16 @@ class ChatGLM2InferPolicy(ChatGLMModelPolicy):
         )
 
         # for rmsnorm and others, we need to check the shape
+        infer_forward = None
+
+        if HAS_TRITON_RMSNORM:
+            infer_forward = get_triton_rmsnorm_forward()
+        if infer_forward is not None:
+            method_replacement = {"forward": partial(infer_forward)}
+            self.append_or_create_method_replacement(
+                description=method_replacement, policy=policy, target_key=RMSNorm
+            )
+
         self.set_pipeline_forward(
             model_cls=ChatGLMForConditionalGeneration,
             new_forward=ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward,
diff --git a/colossalai/inference/engine/policies/llama.py b/colossalai/inference/engine/policies/llama.py
index 397060258..36a71618f 100644
--- a/colossalai/inference/engine/policies/llama.py
+++ b/colossalai/inference/engine/policies/llama.py
@@ -21,6 +21,7 @@ from ..modeling.llama import LlamaInferenceForwards
 
 try:
     from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
+
     HAS_TRITON_RMSNORM = True
 except:
     print("you should install triton from https://github.com/openai/triton")
@@ -30,6 +31,7 @@ if HAS_TRITON_RMSNORM:
     def get_triton_rmsnorm_forward():
         def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
             return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
+
         return _triton_rmsnorm_forward
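
For reference, a standalone sketch (not part of the patch) of why the key/value reshape in chatglm2.py now branches on multi_query_attention: with multi-query attention the key/value projections carry num_multi_query_groups_per_partition heads rather than num_attention_heads_per_partition, so the flattened (-1, heads, head_dim) reshape needs a different head count per tensor. The dimension values below are illustrative assumptions, not ChatGLM2's real configuration.

import torch

# Illustrative (hypothetical) sizes, not ChatGLM2's actual config values.
num_attention_heads = 32    # stands in for num_attention_heads_per_partition
num_multi_query_groups = 2  # stands in for num_multi_query_groups_per_partition
head_dim = 128              # stands in for hidden_size_per_attention_head
seq_len = 4

def split_heads(x: torch.Tensor, multi_query_attention: bool) -> torch.Tensor:
    # Mirror the patched logic: pick the head count based on whether MQA is enabled.
    heads = num_multi_query_groups if multi_query_attention else num_attention_heads
    return x.reshape(-1, heads, head_dim)

key_mqa = torch.randn(seq_len, num_multi_query_groups * head_dim)
key_mha = torch.randn(seq_len, num_attention_heads * head_dim)

print(split_heads(key_mqa, multi_query_attention=True).shape)   # torch.Size([4, 2, 128])
print(split_heads(key_mha, multi_query_attention=False).shape)  # torch.Size([4, 32, 128])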