[inference] Refactor inference architecture (#5057)
* [inference] support only TP (#4998)
  * support only tp
  * enable tp
* add support for bloom (#5008)
* [refactor] refactor gptq and smoothquant llama (#5012)
  * refactor gptq and smoothquant llama
  * fix import error
  * fix linear import torch-int
  * fix smoothquant llama import error
  * fix import accelerate error
  * fix bug
  * fix import smooth cuda
  * fix smoothcuda
* [Inference Refactor] Merge chatglm2 with pp and tp (#5023)
  * merge chatglm with pp and tp
* [Refactor] remove useless inference code (#5022)
  * remove useless code
  * fix quant model
  * fix test import bug
  * mv original inference legacy
  * fix chatglm2
* [Refactor] refactor policy search and quant type controlling in inference (#5035)
* [inference] update readme (#5051)
  * update readme
  * fix architecture
  * fix table
* [inference] update example (#5053)
  * update example
  * fix run.sh
  * fix rebase bug
  * fix some errors
  * update readme
  * add some features
  * update interface
  * update benchmark
  * add requirements-infer

Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
colossalai/legacy/inference/quant/smoothquant/models/linear.py (new file, 179 lines)
@@ -0,0 +1,179 @@
# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py

import torch
from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32
from torch_int.functional.quantization import quantize_per_tensor_absmax

try:
    # build (or load the cached) fused SmoothQuant CUDA kernels
    from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder

    smoothquant_cuda = SmoothquantBuilder().load()
    HAS_SMOOTHQUANT_CUDA = True
except ImportError:
    HAS_SMOOTHQUANT_CUDA = False
    raise ImportError("CUDA smoothquant linear is not installed")
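# For orientation (a sketch, assuming torch-int's usual symmetric absmax
# scheme; not defined in this file): quantize_per_tensor_absmax does roughly
#
#     scale = t.abs().max() / 127
#     t_q = (t / scale).round().to(torch.int8)
#     return t_q, scale
#
# so each from_float() below folds this weight scale into its "a" multiplier.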
# int8 activations and weights, fp32 bias, fp32 output, with SiLU fused into
# the CUDA kernel's epilogue
class W8A8BFP32O32LinearSiLU(torch.nn.Module):
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # placeholder int8 weight; real values are copied in by from_float()
        self.register_buffer(
            "weight",
            torch.randint(
                -127,
                127,
                (self.out_features, self.in_features),
                dtype=torch.int8,
                requires_grad=False,
            ),
        )
        self.register_buffer(
            "bias",
            torch.zeros((1, self.out_features), dtype=torch.float, requires_grad=False),
        )
        # "a" is the combined dequantization scale: input_scale * weight_scale
        self.register_buffer("a", torch.tensor(alpha))

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.weight = self.weight.to(*args, **kwargs)
        self.bias = self.bias.to(*args, **kwargs)
        return self

    @torch.no_grad()
    def forward(self, x):
        x_shape = x.shape
        x = x.view(-1, x_shape[-1])
        y = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1.0)
        y = y.view(*x_shape[:-1], -1)
        return y

    @staticmethod
    def from_float(module: torch.nn.Linear, input_scale):
        int8_module = W8A8BFP32O32LinearSiLU(module.in_features, module.out_features)
        int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
        alpha = input_scale * weight_scale
        int8_module.weight = int8_weight
        if module.bias is not None:
            int8_module.bias.data.copy_(module.bias.to(torch.float))
        int8_module.a = alpha
        return int8_module
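# Usage sketch (assuming the SmoothQuant CUDA extension is built; the 0.02
# activation scale is a made-up value that would normally come from
# calibration):
#
#     fp_proj = torch.nn.Linear(4096, 11008).cuda()
#     int8_proj = W8A8BFP32O32LinearSiLU.from_float(fp_proj, input_scale=0.02)
#     x_q = torch.randint(-127, 127, (16, 4096), dtype=torch.int8, device="cuda")
#     y = int8_proj(x_q)  # fp32 output with SiLU already applied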
# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
class W8A8B8O8Linear(torch.nn.Module):
    # For qkv_proj
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.register_buffer(
            "weight",
            torch.randint(
                -127,
                127,
                (self.out_features, self.in_features),
                dtype=torch.int8,
                requires_grad=False,
            ),
        )
        self.register_buffer(
            "bias",
            torch.zeros((1, self.out_features), dtype=torch.int8, requires_grad=False),
        )
        # "a" rescales the int32 GEMM result and "b" the int8 bias, both onto
        # the int8 output scale
        self.register_buffer("a", torch.tensor(alpha))
        self.register_buffer("b", torch.tensor(beta))

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.weight = self.weight.to(*args, **kwargs)
        self.bias = self.bias.to(*args, **kwargs)
        return self

    @torch.no_grad()
    def forward(self, x):
        x_shape = x.shape
        x = x.view(-1, x_shape[-1])
        y = linear_a8_w8_b8_o8(x, self.weight, self.bias, self.a.item(), self.b.item())
        y = y.view(*x_shape[:-1], -1)
        return y

    @staticmethod
    def from_float(module: torch.nn.Linear, input_scale, output_scale):
        int8_module = W8A8B8O8Linear(module.in_features, module.out_features)
        int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
        alpha = input_scale * weight_scale / output_scale
        int8_module.weight = int8_weight
        int8_module.a = alpha

        if module.bias is not None:
            int8_bias, bias_scale = quantize_per_tensor_absmax(module.bias)
            int8_module.bias = int8_bias
            beta = bias_scale / output_scale
            int8_module.b = beta

        return int8_module
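# Scale bookkeeping (a sketch of why alpha and beta take these values): with
# x ~ x_q * input_scale, W ~ W_q * weight_scale, bias ~ b_q * bias_scale, and
# the requirement y ~ y_q * output_scale, the int8 GEMM must compute
#
#     y_q = (x_q @ W_q.T) * (input_scale * weight_scale / output_scale)
#           + b_q * (bias_scale / output_scale)
#
# i.e. alpha = input_scale * weight_scale / output_scale and
# beta = bias_scale / output_scale, which is exactly what from_float() sets.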
# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
class W8A8BFP32OFP32Linear(torch.nn.Module):
    # For fc2 and out_proj
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.register_buffer(
            "weight",
            torch.randint(
                -127,
                127,
                (self.out_features, self.in_features),
                dtype=torch.int8,
                requires_grad=False,
            ),
        )
        self.register_buffer(
            "bias",
            torch.zeros(self.out_features, dtype=torch.float32, requires_grad=False),
        )
        self.register_buffer("a", torch.tensor(alpha))

    def _apply(self, fn):
        # prevent the bias from being converted to half
        super()._apply(fn)
        self.bias = self.bias.to(torch.float32)
        return self

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.weight = self.weight.to(*args, **kwargs)
        self.bias = self.bias.to(*args, **kwargs)
        # the fused kernel expects an fp32 bias even when the model runs in fp16
        self.bias = self.bias.to(torch.float32)
        return self

    @torch.no_grad()
    def forward(self, x):
        x_shape = x.shape
        x = x.view(-1, x_shape[-1])
        y = linear_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1)
        y = y.view(*x_shape[:-1], -1)
        return y

    @staticmethod
    def from_float(module: torch.nn.Linear, input_scale):
        int8_module = W8A8BFP32OFP32Linear(module.in_features, module.out_features)
        int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
        alpha = input_scale * weight_scale
        int8_module.weight = int8_weight
        int8_module.a = alpha
        # keep the raw scales around for inspection / re-export
        int8_module.input_scale = input_scale
        int8_module.weight_scale = weight_scale

        if module.bias is not None:
            int8_module.bias = module.bias.to(torch.float32)

        return int8_module
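# Usage sketch for the fp32-output variant (the 0.015 scale is again a
# hypothetical calibration value):
#
#     fp_out = torch.nn.Linear(4096, 4096).cuda()
#     int8_out = W8A8BFP32OFP32Linear.from_float(fp_out, input_scale=0.015)
#     x_q = torch.randint(-127, 127, (16, 4096), dtype=torch.int8, device="cuda")
#     y = int8_out(x_q)  # plain fp32 linear output, no fused activation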