Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-28 04:55:25 +00:00)
[inference] Refactor inference architecture (#5057)
* [inference] support only TP (#4998)
* support only tp
* enable tp
* add support for bloom (#5008)
* [refactor] refactor gptq and smoothquant llama (#5012)
* refactor gptq and smoothquant llama
* fix import error
* fix linear import torch-int
* fix smoothquant llama import error
* fix import accelerate error
* fix bug
* fix import smooth cuda
* fix smoothcuda
* [Inference Refactor] Merge chatglm2 with pp and tp (#5023)
* merge chatglm with pp and tp
* [Refactor] remove useless inference code (#5022)
* remove useless code
* fix quant model
* fix test import bug
* move original inference code to legacy
* fix chatglm2
* [Refactor] refactor policy search and quant type controlling in inference (#5035)
* [inference] update readme (#5051)
* update readme
* update readme
* fix architecture
* fix table
* fix table
* [inference] update example (#5053)
* update example
* fix run.sh
* fix rebase bug
* fix some errors
* update readme
* add some features
* update interface
* update readme
* update benchmark
* add requirements-infer

---------

Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
colossalai/legacy/inference/quant/gptq/cai_gptq/gptq_op.py (new file, 58 lines)
import torch

from colossalai.kernel.triton import gptq_fused_linear_triton


class CaiGPTQLinearOp(torch.nn.Module):
    """Linear op that runs a GPTQ-quantized matmul through the fused Triton kernel."""

    def __init__(self, gptq_group_size, gptq_quant_bits):
        super().__init__()
        self.group_size = gptq_group_size
        self.bits = gptq_quant_bits
        # Largest representable quantized value, e.g. 2**4 - 1 = 15 for 4-bit weights.
        self.maxq = 2**self.bits - 1
        # Placeholder tensor passed to the kernel when bias/residual are absent.
        self.empty_tensor = torch.zeros(4, device=torch.cuda.current_device())

    def forward(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scales: torch.Tensor,
        weight_zeros: torch.Tensor,
        g_idx: torch.Tensor = None,
        act_type=0,
        bias: torch.Tensor = None,
        residual: torch.Tensor = None,
        qkv_fused=False,
    ):
        # The fused kernel takes explicit flags rather than optional tensors,
        # so substitute the placeholder when bias/residual are not provided.
        add_bias = True
        if bias is None:
            bias = self.empty_tensor
            add_bias = False

        add_residual = True
        if residual is None:
            residual = self.empty_tensor
            add_residual = False

        # Flatten (batch, seq, hidden) into a 2D input for the matmul.
        x = input.view(-1, input.shape[-1])

        out = gptq_fused_linear_triton(
            x,
            weight,
            weight_scales,
            weight_zeros,
            bias,
            residual,
            self.bits,
            self.maxq,
            self.group_size,
            qkv_fused,
            add_bias,
            add_residual,
            act_type=act_type,
            g_idx=g_idx,
        )
        # A fused QKV projection returns three stacked outputs (Q, K, V).
        if qkv_fused:
            out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1])
        else:
            out = out.view(input.shape[0], input.shape[1], weight.shape[-1])

        return out
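For context, a minimal usage sketch of the op above. The packed-weight layout expected by gptq_fused_linear_triton is not shown in this diff, so the shapes below (4-bit weights packed eight per int32, per-group scales and zero points) are illustrative assumptions, not the kernel's actual contract.

    import torch

    # Hypothetical sizes for illustration; actual packing is defined by the
    # Triton kernel and the GPTQ checkpoint, not by this sketch.
    hidden, out_features, group_size, bits = 4096, 4096, 128, 4

    op = CaiGPTQLinearOp(gptq_group_size=group_size, gptq_quant_bits=bits)

    x = torch.randn(2, 16, hidden, dtype=torch.float16, device="cuda")  # (batch, seq, hidden)
    # Assumed layout: eight 4-bit weights packed per int32 along the input dim.
    qweight = torch.zeros(hidden // 8, out_features, dtype=torch.int32, device="cuda")
    scales = torch.rand(hidden // group_size, out_features, dtype=torch.float16, device="cuda")
    zeros = torch.zeros(hidden // group_size, out_features // 8, dtype=torch.int32, device="cuda")

    out = op(x, qweight, scales, zeros)  # -> (2, 16, out_features)

With qkv_fused=True the same call would instead return a (3, batch, seq, out_features) tensor holding the Q, K, and V projections from one fused kernel launch.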