# ColossalAI/colossalai/inference/tensor_parallel/policies/bloom.py

from functools import partial

import torch
from torch.nn import LayerNorm

import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy

from ..modeling.bloom import BloomInferenceForwards

try:
    from colossalai.kernel.triton import layer_norm

    HAS_TRITON_NORM = True
except ImportError:
    print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton")
    HAS_TRITON_NORM = False
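
# If triton is missing, this flag disables the LayerNorm method replacement in
# BloomModelInferPolicy.module_policy below, so the stock PyTorch forward is kept.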


def get_triton_layernorm_forward():
    """Return a drop-in LayerNorm.forward backed by the triton layer_norm kernel, or None if triton is unavailable."""
    if HAS_TRITON_NORM:

        def _triton_layernorm_forward(self: LayerNorm, hidden_states: torch.Tensor):
            return layer_norm(hidden_states, self.weight.data, self.bias, self.eps)

        return _triton_layernorm_forward
    else:
        return None
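
# A minimal sketch of how the replacement behaves once bound to a module
# (hypothetical `ln` and `x`; assumes triton is installed and a CUDA device is
# available). Shardformer binds the function to each LayerNorm instance, so
# `self` resolves to the module being patched:
#
#     ln = LayerNorm(1024).cuda().half()
#     x = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
#     ln.forward = get_triton_layernorm_forward().__get__(ln, LayerNorm)
#     out = ln(x)  # now dispatches to the triton layer_norm kernel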


class BloomModelInferPolicy(BloomForCausalLMPolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel

        policy = super().module_policy()
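        # When GPTQ inference is enabled, re-shard each BloomBlock for tensor
        # parallelism: the per-rank attention attributes are divided by the TP
        # world size, and the dense layers are swapped for quantized parallel
        # linears. The fused QKV projection is column-parallel (split_num=3,
        # one split each for Q, K, and V), and each column-parallel layer is
        # paired with a row-parallel output projection so the intermediate
        # activations stay sharded between the two matmuls.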
        if self.shard_config.inference_gptq:
            from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear

            policy[BloomBlock] = ModulePolicyDescription(
                attribute_replacement={
                    "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attention.query_key_value",
                        target_module=ColCaiQuantLinear,
                        kwargs={'split_num': 3},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attention.dense",
                        target_module=RowCaiQuantLinear,
                        kwargs={'split_num': 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attention.attention_dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.dense_h_to_4h",
                        target_module=ColCaiQuantLinear,
                        kwargs={'split_num': 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.dense_4h_to_h",
                        target_module=RowCaiQuantLinear,
                        kwargs={'split_num': 1},
                    ),
                ],
            )

        # NOTE: set inference mode on the shard config
        self.shard_config._infer()
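
        # Swap in ColossalAI's inference-optimized forwards for the CausalLM
        # wrapper, the base model, each transformer block, and the attention
        # module, replacing the stock HuggingFace implementations.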
        method_replacement = {
            "forward": BloomInferenceForwards.bloom_for_causal_lm_forward,
            "prepare_inputs_for_generation": BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation,
        }
        self.append_or_create_method_replacement(
            description=method_replacement, policy=policy, target_key=BloomForCausalLM
        )

        method_replacement = {"forward": BloomInferenceForwards.bloom_model_forward}
        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomModel)

        method_replacement = {"forward": BloomInferenceForwards.bloom_block_forward}
        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomBlock)

        method_replacement = {"forward": BloomInferenceForwards.bloom_attention_forward}
        self.append_or_create_method_replacement(
            description=method_replacement, policy=policy, target_key=BloomAttention
        )

        if HAS_TRITON_NORM:
            infer_method = get_triton_layernorm_forward()
            method_replacement = {"forward": partial(infer_method)}
            self.append_or_create_method_replacement(
                description=method_replacement, policy=policy, target_key=LayerNorm
            )

        return policy
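

# --------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module). It assumes a
# GPTQ-quantized BLOOM checkpoint already loaded as `model`, and a shardformer
# version in which ShardConfig exposes the `inference_gptq` flag read above;
# verify both against your installed ColossalAI before relying on this.
#
#     from colossalai.shardformer import ShardConfig, ShardFormer
#
#     shard_config = ShardConfig(inference_gptq=True)
#     shard_former = ShardFormer(shard_config=shard_config)
#     model, _ = shard_former.optimize(model, BloomModelInferPolicy())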