Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-10-29 04:40:36 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -1,6 +1,7 @@
 import bitsandbytes as bnb
-import torch.nn as nn
 import torch
+import torch.nn as nn
+
 
 class Linear8bit(nn.Linear):
     def __init__(
@@ -12,11 +13,9 @@ class Linear8bit(nn.Linear):
         memory_efficient_backward=False,
         threshold=6.0,
         weight_data=None,
-        bias_data=None
+        bias_data=None,
     ):
-        super(Linear8bit, self).__init__(
-            input_features, output_features, bias
-        )
+        super(Linear8bit, self).__init__(input_features, output_features, bias)
         self.state = bnb.MatmulLtState()
         self.bias = bias_data
         self.state.threshold = threshold
@@ -24,13 +23,12 @@ class Linear8bit(nn.Linear):
         self.state.memory_efficient_backward = memory_efficient_backward
         if threshold > 0.0 and not has_fp16_weights:
             self.state.use_pool = True
 
         self.register_parameter("SCB", nn.Parameter(torch.empty(0), requires_grad=False))
         self.weight = weight_data
         self.quant()
 
-
-    def quant(self):
+    def quant(self):
         weight = self.weight.data.contiguous().half().cuda()
         CB, _, SCB, _, _ = bnb.functional.double_quant(weight)
         delattr(self, "weight")
@@ -41,32 +39,34 @@ class Linear8bit(nn.Linear):
 
     def forward(self, x):
         self.state.is_training = self.training
 
         if self.bias is not None and self.bias.dtype != torch.float16:
             self.bias.data = self.bias.data.half()
 
         self.state.CB = self.weight.data
         self.state.SCB = self.SCB.data
 
         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
         del self.state.CxB
         return out
 
 
 def replace_module(model):
     for name, module in model.named_children():
         if len(list(module.children())) > 0:
             replace_module(module)
 
-        if isinstance(module, nn.Linear) and 'out_proj' not in name:
-            model._modules[name] = Linear8bit(input_features=module.in_features,
-                                              output_features=module.out_features,
-                                              threshold=6.0,
-                                              weight_data=module.weight,
-                                              bias_data=module.bias)
+        if isinstance(module, nn.Linear) and "out_proj" not in name:
+            model._modules[name] = Linear8bit(
+                input_features=module.in_features,
+                output_features=module.out_features,
+                threshold=6.0,
+                weight_data=module.weight,
+                bias_data=module.bias,
+            )
     return model
 
 
 def getModelSize(model):
     param_size = 0
     param_sum = 0
@@ -79,5 +79,5 @@ def getModelSize(model):
         buffer_size += buffer.nelement() * buffer.element_size()
         buffer_sum += buffer.nelement()
     all_size = (param_size + buffer_size) / 1024 / 1024
-    print('Model Size: {:.3f}MB'.format(all_size))
+    print("Model Size: {:.3f}MB".format(all_size))
     return (param_size, param_sum, buffer_size, buffer_sum, all_size)
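For context on the `quant()` step in the diff: `bnb.functional.double_quant` absmax-quantizes the fp16 weight row-wise to int8 (`CB`) with one scale per output row (`SCB`). Below is a minimal pure-PyTorch sketch of that row-wise part only; it is illustrative, the helper name is mine, and it omits the column-wise statistics and the COO outlier tensor that bitsandbytes also returns.

import torch

def absmax_quantize_rowwise(weight: torch.Tensor):
    # SCB: one fp32 scale per output row; CB: the int8 weights quant() stores.
    scb = weight.abs().amax(dim=1).float().clamp(min=1e-8)
    cb = torch.round(weight.float() * (127.0 / scb.unsqueeze(1))).to(torch.int8)
    return cb, scb

w = torch.randn(4, 8, dtype=torch.float16)
cb, scb = absmax_quantize_rowwise(w)
w_hat = cb.float() * scb.unsqueeze(1) / 127.0  # dequantize to inspect the error
print((w.float() - w_hat).abs().max())         # small per-row rounding error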
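And a minimal usage sketch tying `replace_module` and `getModelSize` together. Assumptions: a CUDA device (quant() moves weights to GPU), a bitsandbytes version that still exposes `bnb.functional.double_quant`, and an OPT-style model whose attention output layers are named `out_proj` (which the replacement deliberately skips); the checkpoint name is illustrative.

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b", torch_dtype=torch.float16)
_, _, _, _, fp16_mb = getModelSize(model)  # size with fp16 nn.Linear layers

model = replace_module(model)              # swap eligible nn.Linear layers for Linear8bit
_, _, _, _, int8_mb = getModelSize(model)  # linear weights now counted as int8

print(f"fp16: {fp16_mb:.1f} MB -> int8: {int8_mb:.1f} MB")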