import torch

from colossalai.kernel.triton import gptq_fused_linear_triton


class CaiGPTQLinearOp(torch.nn.Module):
    """Linear operator that dispatches to the fused GPTQ dequantize + matmul Triton kernel."""

    def __init__(self, gptq_group_size, gptq_quant_bits):
        super(CaiGPTQLinearOp, self).__init__()
        self.group_size = gptq_group_size
        self.bits = gptq_quant_bits
        # Largest representable quantized integer, e.g. 15 for 4-bit GPTQ.
        self.maxq = 2**self.bits - 1
        # Small placeholder tensor handed to the kernel when bias/residual are absent.
        self.empty_tensor = torch.zeros(4, device=torch.cuda.current_device())

    def forward(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scales: torch.Tensor,
        weight_zeros: torch.Tensor,
        g_idx: torch.Tensor = None,
        act_type=0,
        bias: torch.Tensor = None,
        residual: torch.Tensor = None,
        qkv_fused=False,
    ):
        # Substitute the placeholder when no bias is supplied and tell the kernel
        # to skip the bias addition.
        add_bias = True
        if bias is None:
            bias = self.empty_tensor
            add_bias = False

        # Same treatment for the optional residual input.
        add_residual = True
        if residual is None:
            residual = self.empty_tensor
            add_residual = False

        # Flatten (batch, seq, hidden) into a 2D matrix for the Triton kernel.
        x = input.view(-1, input.shape[-1])

        out = gptq_fused_linear_triton(
            x,
            weight,
            weight_scales,
            weight_zeros,
            bias,
            residual,
            self.bits,
            self.maxq,
            self.group_size,
            qkv_fused,
            add_bias,
            add_residual,
            act_type=act_type,
            g_idx=g_idx,
        )

        # Restore the (batch, seq, out_features) shape; a fused QKV kernel returns
        # the three projections stacked along a leading dimension.
        if qkv_fused:
            out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1])
        else:
            out = out.view(input.shape[0], input.shape[1], weight.shape[-1])

        return out
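

# --- Usage sketch (not part of the original file) ---------------------------
# A minimal, hypothetical example of how CaiGPTQLinearOp might be invoked.
# The tensor layouts below follow the common GPTQ/AutoGPTQ packing convention
# (4-bit values packed eight-per-int32 along the input dimension); the exact
# layout expected by gptq_fused_linear_triton may differ, so treat the shapes
# as assumptions rather than the kernel's documented contract.
if __name__ == "__main__":
    batch, seq, in_features, out_features = 2, 16, 4096, 4096
    group_size, bits = 128, 4

    op = CaiGPTQLinearOp(gptq_group_size=group_size, gptq_quant_bits=bits)

    x = torch.randn(batch, seq, in_features, dtype=torch.float16, device="cuda")
    # Packed quantized weight: eight 4-bit values per int32 element (assumed layout).
    qweight = torch.randint(
        0, 2**31 - 1, (in_features // 8, out_features), dtype=torch.int32, device="cuda"
    )
    # Per-group dequantization scales and packed zero points (assumed layout).
    scales = torch.rand(in_features // group_size, out_features, dtype=torch.float16, device="cuda")
    qzeros = torch.randint(
        0, 2**31 - 1, (in_features // group_size, out_features // 8), dtype=torch.int32, device="cuda"
    )
    # Group index of each input channel.
    g_idx = torch.arange(in_features, dtype=torch.int32, device="cuda") // group_size

    out = op(x, qweight, scales, qzeros, g_idx=g_idx)
    print(out.shape)  # expected: (batch, seq, out_features)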