[inference] Refactor inference architecture (#5057)

* [inference] support only TP (#4998)

* support only tp

* enable tp

* add support for bloom (#5008)

* [refactor] refactor gptq and smoothquant llama (#5012)

* refactor gptq and smoothquant llama

* fix import error

* fix linear import torch-int

* fix smoothquant llama import error

* fix import accelerate error

* fix bug

* fix import smooth cuda

* fix smoothcuda

* [Inference Refactor] Merge chatglm2 with pp and tp (#5023)

merge chatglm with pp and tp

* [Refactor] remove useless inference code (#5022)

* remove useless code

* fix quant model

* fix test import bug

* mv original inference legacy

* fix chatglm2

* [Refactor] refactor policy search and quant type controlling in inference (#5035)

* [Refactor] refactor policy search and quant type controling in inference

* [inference] update readme (#5051)

* update readme

* update readme

* fix architecture

* fix table

* fix table

* [inference] udpate example (#5053)

* udpate example

* fix run.sh

* fix rebase bug

* fix some errors

* update readme

* add some features

* update interface

* update readme

* update benchmark

* add requirements-infer

---------

Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
This commit is contained in:
Xu Kai
2023-11-19 21:05:05 +08:00
committed by GitHub
parent bc09b95f50
commit fd6482ad8c
115 changed files with 6027 additions and 1431 deletions

View File

@@ -0,0 +1,77 @@
from functools import partial
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
ChatGLMForConditionalGeneration,
ChatGLMModel,
GLMBlock,
GLMTransformer,
SelfAttention,
)
# import colossalai
from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy
from ..modeling._utils import init_to_get_rotary
from ..modeling.chatglm2 import ChatGLM2InferenceForwards
try:
HAS_TRITON_RMSNORM = True
except:
print("you should install triton from https://github.com/openai/triton")
HAS_TRITON_RMSNORM = False
class ChatGLM2InferPolicy(ChatGLMModelPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
self.shard_config._infer()
model_infer_forward = ChatGLM2InferenceForwards.chatglm_model_forward
method_replacement = {"forward": model_infer_forward}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=ChatGLMModel)
encoder_infer_forward = ChatGLM2InferenceForwards.chatglm_encoder_forward
method_replacement = {"forward": encoder_infer_forward}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=GLMTransformer
)
encoder_layer_infer_forward = ChatGLM2InferenceForwards.chatglm_glmblock_forward
method_replacement = {"forward": encoder_layer_infer_forward}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=GLMBlock)
attn_infer_forward = ChatGLM2InferenceForwards.chatglm_flash_attn_kvcache_forward
method_replacement = {"forward": attn_infer_forward}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=SelfAttention
)
if self.shard_config.enable_tensor_parallelism:
policy[GLMBlock].attribute_replacement["self_attention.num_multi_query_groups_per_partition"] = (
self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size
)
# for rmsnorm and others, we need to check the shape
return policy
def postprocess(self):
init_to_get_rotary(self.model)
return self.model
class ChatGLM2ForConditionalGenerationInferPolicy(ChatGLM2InferPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
model_infer_forward = ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward
method_replacement = {"forward": partial(model_infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=ChatGLMForConditionalGeneration
)
return policy
def postprocess(self):
return super().postprocess()