mirror of https://github.com/hpcaitech/ColossalAI.git
* [gptq] add gptq kernel (#4416)
  * add gptq
  * refactor code
  * fix tests
  * replace auto-gptq
  * rename inference/quant
  * refactor test
  * add auto-gptq as an option
  * reset requirements
  * change assert and check auto-gptq
  * add import warnings
  * change test flash attn version
  * remove example
  * change requirements of flash_attn
  * modify tests
  * [skip ci] change requirements-test
* [gptq] faster gptq cuda kernel (#4494)
  * [skip ci] add cuda kernels
  * add license
  * [skip ci] fix max_input_len
  * format files & change test size
  * [skip ci]
* [gptq] add gptq tensor parallel (#4538)
  * add gptq tensor parallel
  * add gptq tp
  * delete print
  * add test gptq check
  * add test auto gptq check
* [gptq] combine gptq and kv cache manager (#4706)
  * combine gptq and kv cache manager
  * add init bits
  * delete useless code
  * add model path
  * delete useless print and update test
  * delete useless import
  * move option gptq to shard config
  * change replace linear to shardformer
  * update bloom policy
  * delete useless code
  * fix import bug and delete useless code
  * change colossalai/gptq to colossalai/quant/gptq
  * update import linear for tests
  * delete useless code and mv gptq_kernel to kernel directory
  * fix triton kernel
  * add triton import
59 lines
1.6 KiB
Python
import torch

from colossalai.kernel.triton import gptq_fused_linear_triton


class CaiGPTQLinearOp(torch.nn.Module):
    """Linear op that runs GPTQ-quantized matmuls through the fused Triton kernel."""

    def __init__(self, gptq_group_size, gptq_quant_bits):
        super(CaiGPTQLinearOp, self).__init__()
        self.group_size = gptq_group_size
        self.bits = gptq_quant_bits
        # Largest value representable with `bits` bits, used by the kernel for dequantization.
        self.maxq = 2**self.bits - 1
        # Placeholder tensor passed to the kernel when bias/residual are absent.
        self.empty_tensor = torch.zeros(4, device=torch.cuda.current_device())

    def forward(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scales: torch.Tensor,
        weight_zeros: torch.Tensor,
        g_idx: torch.Tensor = None,
        act_type=0,
        bias: torch.Tensor = None,
        residual: torch.Tensor = None,
        qkv_fused=False,
    ):
        # When no bias is provided, pass the placeholder tensor and tell the
        # kernel (via add_bias) not to apply it.
        add_bias = True
        if bias is None:
            bias = self.empty_tensor
            add_bias = False

        # Same treatment for the optional residual input.
        add_residual = True
        if residual is None:
            residual = self.empty_tensor
            add_residual = False

        # Flatten the leading (batch, seq) dimensions into a single row dimension.
        x = input.view(-1, input.shape[-1])

        out = gptq_fused_linear_triton(
            x,
            weight,
            weight_scales,
            weight_zeros,
            bias,
            residual,
            self.bits,
            self.maxq,
            self.group_size,
            qkv_fused,
            add_bias,
            add_residual,
            act_type=act_type,
            g_idx=g_idx,
        )

        # Restore the (batch, seq, out_features) shape; with fused QKV the three
        # projections are stacked along a leading dimension of size 3.
        if qkv_fused:
            out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1])
        else:
            out = out.view(input.shape[0], input.shape[1], weight.shape[-1])

        return out
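Below is a minimal usage sketch, not part of the file above. The group size, packed int32 weight layout, tensor shapes, and g_idx construction follow the common GPTQ packing convention and are illustrative assumptions; only CaiGPTQLinearOp and gptq_fused_linear_triton come from the module itself.

# Usage sketch: all shapes, dtypes, and the packing below are assumptions for illustration.
import torch

op = CaiGPTQLinearOp(gptq_group_size=128, gptq_quant_bits=4)

batch, seq_len, in_features, out_features = 2, 16, 4096, 4096
x = torch.randn(batch, seq_len, in_features, dtype=torch.float16, device="cuda")

# Assumed GPTQ layout: eight 4-bit weights packed per int32 along the input dimension,
# with one scale/zero-point per (group, output feature).
qweight = torch.zeros(in_features // 8, out_features, dtype=torch.int32, device="cuda")
scales = torch.ones(in_features // 128, out_features, dtype=torch.float16, device="cuda")
qzeros = torch.zeros(in_features // 128, out_features // 8, dtype=torch.int32, device="cuda")
# Group index of each input feature (feature i belongs to group i // group_size).
g_idx = torch.arange(in_features, dtype=torch.int32, device="cuda") // 128

out = op(x, qweight, scales, qzeros, g_idx=g_idx)  # -> (batch, seq_len, out_features)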