ColossalAI/colossalai/legacy/inference/quant/gptq/cai_gptq/gptq_op.py
Xu Kai fd6482ad8c
[inference] Refactor inference architecture (#5057)
* [inference] support only TP (#4998)

* support only tp

* enable tp

* add support for bloom (#5008)

* [refactor] refactor gptq and smoothquant llama (#5012)

* refactor gptq and smoothquant llama

* fix import error

* fix linear import torch-int

* fix smoothquant llama import error

* fix import accelerate error

* fix bug

* fix import smooth cuda

* fix smoothcuda

* [Inference Refactor] Merge chatglm2 with pp and tp (#5023)

merge chatglm with pp and tp

* [Refactor] remove useless inference code (#5022)

* remove useless code

* fix quant model

* fix test import bug

* mv original inference legacy

* fix chatglm2

* [Refactor] refactor policy search and quant type controlling in inference (#5035)

* [Refactor] refactor policy search and quant type controlling in inference

* [inference] update readme (#5051)

* update readme

* update readme

* fix architecture

* fix table

* fix table

* [inference] update example (#5053)

* update example

* fix run.sh

* fix rebase bug

* fix some errors

* update readme

* add some features

* update interface

* update readme

* update benchmark

* add requirements-infer

---------

Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
2023-11-19 21:05:05 +08:00

59 lines
1.6 KiB
Python

import torch

from colossalai.kernel.triton import gptq_fused_linear_triton


class CaiGPTQLinearOp(torch.nn.Module):
    def __init__(self, gptq_group_size, gptq_quant_bits):
        super(CaiGPTQLinearOp, self).__init__()
        self.group_size = gptq_group_size
        self.bits = gptq_quant_bits
        # Largest integer representable with `bits` bits (15 for 4-bit weights).
        self.maxq = 2**self.bits - 1
        # Placeholder handed to the kernel when no bias or residual is supplied.
        self.empty_tensor = torch.zeros(4, device=torch.cuda.current_device())

    def forward(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scales: torch.Tensor,
        weight_zeros: torch.Tensor,
        g_idx: torch.Tensor = None,
        act_type=0,
        bias: torch.Tensor = None,
        residual: torch.Tensor = None,
        qkv_fused=False,
    ):
        # The kernel always receives a bias and a residual tensor; the
        # add_bias / add_residual flags record whether they are real inputs
        # or the placeholder.
        add_bias = True
        if bias is None:
            bias = self.empty_tensor
            add_bias = False

        add_residual = True
        if residual is None:
            residual = self.empty_tensor
            add_residual = False

        # Flatten all leading dimensions into a 2D (tokens, hidden) matrix.
        x = input.view(-1, input.shape[-1])

        out = gptq_fused_linear_triton(
            x,
            weight,
            weight_scales,
            weight_zeros,
            bias,
            residual,
            self.bits,
            self.maxq,
            self.group_size,
            qkv_fused,
            add_bias,
            add_residual,
            act_type=act_type,
            g_idx=g_idx,
        )

        # Restore the leading dimensions. Both branches index input.shape[0]
        # and input.shape[1], so a 3D input is assumed here.
        if qkv_fused:
            # The fused QKV path yields three stacked outputs.
            out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1])
        else:
            out = out.view(input.shape[0], input.shape[1], weight.shape[-1])
        return out
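
The shape bookkeeping in `forward` can be checked without a GPU or the Triton kernel. The sketch below is not part of ColossalAI: the helper names `max_quant_value`, `flattened_shape`, and `output_shape` are hypothetical. It assumes a 3D input of shape `(batch, seq_len, hidden)` and that `weight.shape[-1]` is the number of output features, which is what the two `view` calls above imply.

```python
def max_quant_value(bits):
    """Largest integer representable with `bits` bits, mirroring self.maxq."""
    return 2**bits - 1


def flattened_shape(input_shape):
    """Shape of `input.view(-1, input.shape[-1])` for a 3D input."""
    batch, seq_len, hidden = input_shape
    return (batch * seq_len, hidden)


def output_shape(input_shape, out_features, qkv_fused=False):
    """Shape of the tensor returned by forward (hypothetical helper).

    `out_features` stands in for `weight.shape[-1]`.
    """
    batch, seq_len, _ = input_shape
    if qkv_fused:
        # Fused QKV: three projections stacked along a new leading axis.
        return (3, batch, seq_len, out_features)
    return (batch, seq_len, out_features)


if __name__ == "__main__":
    print(max_quant_value(4))
    print(flattened_shape((2, 5, 64)))
    print(output_shape((2, 5, 64), 128, qkv_fused=True))
```

One consequence of this layout is that a 2D input would fail at the final `view`, since `input.shape[1]` would then be the hidden size rather than the sequence length.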