Mirror of https://github.com/hpcaitech/ColossalAI.git
* [inference] support only TP (#4998)
  * support only tp
  * enable tp
  * add support for bloom (#5008)
* [refactor] refactor gptq and smoothquant llama (#5012)
  * refactor gptq and smoothquant llama
  * fix import error
  * fix linear import torch-int
  * fix smoothquant llama import error
  * fix import accelerate error
  * fix bug
  * fix import smooth cuda
  * fix smoothcuda
* [Inference Refactor] Merge chatglm2 with pp and tp (#5023)
  * merge chatglm with pp and tp
* [Refactor] remove useless inference code (#5022)
  * remove useless code
  * fix quant model
  * fix test import bug
  * mv original inference legacy
  * fix chatglm2
* [Refactor] refactor policy search and quant type controlling in inference (#5035)
  * refactor policy search and quant type controlling in inference
* [inference] update readme (#5051)
  * update readme
  * update readme
  * fix architecture
  * fix table
  * fix table
* [inference] update example (#5053)
  * update example
  * fix run.sh
  * fix rebase bug
  * fix some errors
  * update readme
  * add some features
  * update interface
  * update readme
  * update benchmark
  * add requirements-infer

Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
62 lines
2.4 KiB
Python
import warnings

import torch

class GPTQManager:
    """Tracks GPTQ quantization settings and owns the temporary CUDA buffers used by the GPTQ kernels."""

    def __init__(self, quant_config, max_input_len: int = 1):
        self.max_dq_buffer_size = 1
        self.max_inner_outer_dim = 1
        self.bits = quant_config.bits
        self.use_act_order = quant_config.desc_act
        # Keep the requested maximum input length instead of hard-coding 1, so the
        # act-order temp_state buffer below can be sized for real prompt lengths.
        self.max_input_len = max_input_len
        self.gptq_temp_state_buffer = None
        self.gptq_temp_dq_buffer = None
        self.quant_config = quant_config

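    # Note: `quant_config` is duck-typed; __init__ only reads `quant_config.bits` and
    # `quant_config.desc_act`. auto_gptq's BaseQuantizeConfig carries both fields, and a
    # minimal stand-in such as the illustrative class below would also satisfy it:
    #
    #     class DummyGPTQConfig:  # hypothetical, for illustration only
    #         bits = 4
    #         desc_act = False
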
    def post_init_gptq_buffer(self, model: torch.nn.Module) -> None:
        """Scan `model` for CaiQuantLinear layers and allocate the scratch buffers the GPTQ CUDA kernels need."""
        from .cai_gptq import CaiQuantLinear

        HAS_GPTQ_CUDA = False
        try:
            from colossalai.kernel.op_builder.gptq import GPTQBuilder

            gptq_cuda = GPTQBuilder().load()
            HAS_GPTQ_CUDA = True
        except ImportError:
            warnings.warn("CUDA gptq is not installed")
            HAS_GPTQ_CUDA = False

        for name, submodule in model.named_modules():
            if isinstance(submodule, CaiQuantLinear):
                # At 4 bits, each packed int32 element of qweight dequantizes to 8 fp16
                # values, so numel() * 8 bounds the dequantization scratch space.
                self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8)

                if self.use_act_order:
                    self.max_inner_outer_dim = max(
                        self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures
                    )
                self.bits = submodule.bits

        # Buffer setup only applies when the CUDA extension is available and the weights are 4-bit.
        if not (HAS_GPTQ_CUDA and self.bits == 4):
            return

        max_input_len = 1
        if self.use_act_order:
            max_input_len = self.max_input_len
        # The temp_state buffer is required to reorder X in the act-order case.
        # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
        self.gptq_temp_state_buffer = torch.zeros(
            (max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
        )
        self.gptq_temp_dq_buffer = torch.zeros(
            (1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()
        )

        gptq_cuda.prepare_buffers(
            torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer
        )
        # Using the defaults from the exllama repo here.
        matmul_recons_thd = 8
        matmul_fused_remap = False
        matmul_no_half2 = False
        gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

        torch.cuda.empty_cache()
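
if __name__ == "__main__":
    # Minimal usage sketch: it only assumes the config object exposes `bits` and `desc_act`,
    # which is all __init__ reads. A real pipeline would pass a proper GPTQ quantization
    # config and a model whose linear layers were already converted to CaiQuantLinear
    # before calling post_init_gptq_buffer(model).
    class _DemoQuantConfig:  # hypothetical stand-in, for illustration only
        bits = 4
        desc_act = False

    manager = GPTQManager(_DemoQuantConfig(), max_input_len=2048)
    print(manager.bits, manager.use_act_order, manager.max_input_len)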