Mirror of https://github.com/hpcaitech/ColossalAI.git
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -2,6 +2,6 @@ from .llama_gptq import load_quant as llama_load_quant
 from .utils import low_resource_init
 
 __all__ = [
-    'llama_load_quant',
-    'low_resource_init',
+    "llama_load_quant",
+    "low_resource_init",
 ]
@@ -1,5 +1,5 @@
 from .loader import load_quant
 
 __all__ = [
-    'load_quant',
+    "load_quant",
 ]
@@ -11,14 +11,15 @@ def load_quant(model: nn.Module, checkpoint: str, wbits: int, groupsize: int):
 
     # ignore lm head
     layers = find_layers(model)
-    for name in ['lm_head']:
+    for name in ["lm_head"]:
         if name in layers:
             del layers[name]
 
     make_quant(model, layers, wbits, groupsize)
 
-    if checkpoint.endswith('.safetensors'):
+    if checkpoint.endswith(".safetensors"):
         from safetensors.torch import load_file as safe_load
+
         model.load_state_dict(safe_load(checkpoint))
     else:
         model.load_state_dict(torch.load(checkpoint))
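For orientation, `load_quant` (re-exported above as `llama_load_quant`) prunes `lm_head` from the layer map, swaps the remaining `nn.Linear`/`nn.Conv2d` modules for `QuantLinear`, and then loads a packed checkpoint from either a `.safetensors` file or a regular torch pickle. A minimal usage sketch with placeholder paths, assuming a Hugging Face LLaMA skeleton and that `load_quant` and `low_resource_init` are importable from this package (import lines omitted because the package path is not shown in the diff):

```python
from transformers import LlamaConfig, LlamaForCausalLM  # any model whose Linear layers should be quantized

config = LlamaConfig.from_pretrained("path/to/llama-7b")  # placeholder path
with low_resource_init():                                 # build the skeleton cheaply in half precision
    model = LlamaForCausalLM(config)

# Placeholder checkpoint; wbits must be one of 2, 3, 4, 8 (checked in QuantLinear.__init__ below).
load_quant(model, "path/to/llama-7b-4bit.safetensors", wbits=4, groupsize=128)
model = model.cuda().eval()
```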
@@ -1,13 +1,12 @@
 # copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py
-
 import torch
 import torch.nn as nn
 
 
-def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
+def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""):
     if type(module) in layers:
         return {name: module}
     res = {}
     for name1, child in module.named_children():
-        res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
+        res.update(find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1))
     return res
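As a quick illustration of `find_layers` (unchanged here apart from the quote style): it walks the module tree and returns a dict mapping dotted attribute paths to every `nn.Linear`/`nn.Conv2d` it finds. A small self-contained sketch on a toy model, assuming `find_layers` from the file above is importable:

```python
import torch.nn as nn

toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Sequential(nn.Linear(8, 2)))
print(find_layers(toy))
# {'0': Linear(in_features=4, out_features=8, bias=True),
#  '2.0': Linear(in_features=8, out_features=2, bias=True)}
```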
@@ -13,14 +13,13 @@ def quantize(x, scale, zero, maxq):
 
 
 class Quantizer(nn.Module):
-
     def __init__(self, shape=1):
         super(Quantizer, self).__init__()
-        self.register_buffer('maxq', torch.tensor(0))
-        self.register_buffer('scale', torch.zeros(shape))
-        self.register_buffer('zero', torch.zeros(shape))
+        self.register_buffer("maxq", torch.tensor(0))
+        self.register_buffer("scale", torch.zeros(shape))
+        self.register_buffer("zero", torch.zeros(shape))
 
-    def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8):
+    def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=0.8):
         self.maxq = torch.tensor(2**bits - 1)
         self.perchannel = perchannel
         self.sym = sym
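For reference, `maxq = 2**bits - 1` is the top of the integer grid used by the module-level `quantize` helper named in the hunk header. A tiny worked round-trip, hedged as an illustration based on the upstream GPTQ-for-LLaMa definition of that helper rather than the exact code in this file:

```python
import torch

def quantize(x, scale, zero, maxq):
    # Asymmetric affine quantization followed by immediate dequantization
    # (upstream GPTQ-for-LLaMa formulation; shown here for illustration).
    q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
    return scale * (q - zero)

bits = 4
maxq = 2**bits - 1                               # 15, matching Quantizer.configure above
x = torch.tensor([-0.30, 0.27, 2.00])
scale, zero = torch.tensor(0.1), torch.tensor(3.0)
print(quantize(x, scale, zero, maxq))            # ≈ tensor([-0.3000, 0.3000, 1.2000]); the last value is clipped by maxq
```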
@@ -68,7 +67,7 @@ class Quantizer(nn.Module):
             self.zero = torch.round(-xmin / self.scale)
 
         if self.mse:
-            best = torch.full([x.shape[0]], float('inf'), device=dev)
+            best = torch.full([x.shape[0]], float("inf"), device=dev)
             for i in range(int(self.maxshrink * self.grid)):
                 p = 1 - i / self.grid
                 xmin1 = p * xmin
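The `maxshrink=0.8` and `grid=100` arguments drive this loop: it re-derives scale/zero for progressively shrunken ranges `p * [xmin, xmax]` with `p = 1 - i / grid`, scores each candidate by a `norm`-power reconstruction error (2.4 by default) and keeps the best per row. A condensed, runnable sketch of that search under those assumptions, following the upstream GPTQ implementation rather than copying this file verbatim:

```python
import torch

def quantize(x, scale, zero, maxq):
    return scale * (torch.clamp(torch.round(x / scale) + zero, 0, maxq) - zero)

torch.manual_seed(0)
x = torch.randn(4, 64)                          # 4 rows, scored per row as in the loop above
xmin, xmax = x.min(dim=1).values, x.max(dim=1).values
maxq, grid, maxshrink, norm = 15, 100, 0.8, 2.4

best = torch.full([x.shape[0]], float("inf"))
best_scale = (xmax - xmin) / maxq
best_zero = torch.round(-xmin / best_scale)
for i in range(int(maxshrink * grid)):          # 80 shrink candidates with the defaults
    p = 1 - i / grid
    xmin1, xmax1 = p * xmin, p * xmax
    scale1 = (xmax1 - xmin1) / maxq
    zero1 = torch.round(-xmin1 / scale1)
    q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), maxq)
    err = (q - x).abs().pow(norm).sum(dim=1)    # per-row Lp-style reconstruction error
    better = err < best
    best = torch.where(better, err, best)
    best_scale = torch.where(better, scale1, best_scale)
    best_zero = torch.where(better, zero1, best_zero)
print(best_scale, best_zero)
```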
@@ -123,13 +122,12 @@ class Quantizer(nn.Module):
 try:
     import quant_cuda
 except:
-    print('CUDA extension not installed.')
+    print("CUDA extension not installed.")
 
 # Assumes layer is perfectly divisible into 256 * 256 blocks
 
 
 class QuantLinear(nn.Module):
-
     def __init__(self, bits, groupsize, infeatures, outfeatures):
         super().__init__()
         if bits not in [2, 3, 4, 8]:
@@ -142,11 +140,11 @@ class QuantLinear(nn.Module):
         groupsize = groupsize if groupsize != -1 else infeatures
         self.groupsize = groupsize
         self.register_buffer(
-            'qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)),
-                                  dtype=torch.int))
-        self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
-        self.register_buffer('bias', torch.zeros(outfeatures))
-        self.register_buffer('qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int))
+            "qzeros", torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)), dtype=torch.int)
+        )
+        self.register_buffer("scales", torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
+        self.register_buffer("bias", torch.zeros(outfeatures))
+        self.register_buffer("qweight", torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int))
         self._initialized_quant_state = False
 
     def pack(self, linear, scales, zeros):
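To make the buffer shapes above concrete: for a hypothetical 4-bit layer with `infeatures = outfeatures = 4096` and `groupsize = 128`, each 32-bit integer in `qweight` holds 32 / 4 = 8 quantized weights (its row count `infeatures // 256 * (bits * 8)` equals `infeatures * bits / 32`), and scales/zeros are stored once per group of 128 input features:

```python
import math

bits, groupsize, infeatures, outfeatures = 4, 128, 4096, 4096  # example sizes only
groups = math.ceil(infeatures / groupsize)

print((infeatures // 256 * (bits * 8), outfeatures))  # qweight: (512, 4096), dtype=torch.int
print((groups, outfeatures // 256 * (bits * 8)))      # qzeros:  (32, 512),   dtype=torch.int
print((groups, outfeatures))                          # scales:  (32, 4096)
print((outfeatures,))                                 # bias:    (4096,)
print(32 // bits)                                     # 8 packed weights per 32-bit word
```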
@@ -161,8 +159,10 @@ class QuantLinear(nn.Module):
         for idx in range(self.infeatures):
             g_idx = idx // self.groupsize
             intweight.append(
-                torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:,
-                                                                                                                   None])
+                torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[
+                    :, None
+                ]
+            )
         intweight = torch.cat(intweight, dim=1)
         intweight = intweight.t().contiguous()
         intweight = intweight.numpy().astype(np.uint32)
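This loop converts each column of floating-point weights to unsigned integers via `round((w + scale * zero) / scale)`, i.e. `round(w / scale + zero)`, before the bit-packing step (in the upstream pack implementation this mirrors, `scale_zeros` is `scales * zeros`). A one-line numeric check with made-up per-group parameters:

```python
import torch

scale, zero = 0.1, 8.0                       # example per-group quantization parameters
w = torch.tensor([-0.42, 0.00, 0.23])
q = torch.round((w + scale * zero) / scale)  # same as round(w / scale + zero)
print(q)                                     # tensor([ 4.,  8., 10.])  -> stored as ints, then bit-packed
print(scale * (q - zero))                    # tensor([-0.4000, 0.0000, 0.2000])  approximate reconstruction
```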
@@ -271,13 +271,13 @@ class QuantLinear(nn.Module):
         return y.reshape(outshape)
 
 
-def make_quant(module, names, bits, groupsize, name=''):
+def make_quant(module, names, bits, groupsize, name=""):
     if isinstance(module, QuantLinear):
         return
     for attr in dir(module):
         tmp = getattr(module, attr)
-        name1 = name + '.' + attr if name != '' else attr
+        name1 = name + "." + attr if name != "" else attr
         if name1 in names:
             setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features))
     for name1, child in module.named_children():
-        make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
+        make_quant(child, names, bits, groupsize, name + "." + name1 if name != "" else name1)
@@ -9,8 +9,7 @@ def _noop(*args, **kwargs):
 
 @contextmanager
 def low_resource_init():
-    """This context manager disables weight initialization and sets the default float dtype to half.
-    """
+    """This context manager disables weight initialization and sets the default float dtype to half."""
     old_kaiming_uniform_ = torch.nn.init.kaiming_uniform_
     old_uniform_ = torch.nn.init.uniform_
     old_normal_ = torch.nn.init.normal_
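Per its docstring, `low_resource_init` temporarily no-ops the `torch.nn.init` routines saved above (replacing them with `_noop`) and switches the default float dtype to half, so a large model skeleton can be built cheaply before its real quantized weights are loaded. A minimal usage sketch, assuming the context manager is importable from this utils module:

```python
import torch
import torch.nn as nn

with low_resource_init():
    # Inside the context: kaiming_uniform_/uniform_/normal_ do nothing and the
    # default dtype is torch.float16, so construction allocates but never initializes.
    skeleton = nn.Linear(4096, 4096)

print(skeleton.weight.dtype)      # torch.float16
print(torch.get_default_dtype())  # expected to be torch.float32 again, restored on exit
```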