# ColossalAI/colossalai/inference/quant/smoothquant/models/parallel_linear.py
"""Tensor-parallel wrappers for the SmoothQuant int8 linear layers.

The Row* classes shard a linear along its output dimension, so each rank produces its own
slice of the output; the Col* classes shard the input dimension and all-reduce the partial
outputs in forward.
"""
from typing import List, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.lazy import LazyInitContext
from colossalai.shardformer.layer import ParallelModule

from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear


def split_row_copy(smooth_linear, para_linear, tp_size=1, tp_rank=0, split_num=1):
    """Copy this rank's shard of ``smooth_linear`` into ``para_linear``, splitting the
    weight (and bias, if any) along the output dimension.

    ``split_num`` > 1 handles fused projections: each of the ``split_num`` blocks is
    sharded independently, so the fused block layout is preserved on every rank.
    """
    qweights = smooth_linear.weight.split(smooth_linear.out_features // split_num, dim=0)
    if smooth_linear.bias is not None:
        # the quantized linears store their bias as (1, out_features), so split the last dim
        bias = smooth_linear.bias.split(smooth_linear.out_features // split_num, dim=-1)

    smooth_split_out_features = para_linear.out_features // split_num
    for i in range(split_num):
        para_linear.weight[i * smooth_split_out_features : (i + 1) * smooth_split_out_features, :] = qweights[i][
            tp_rank * smooth_split_out_features : (tp_rank + 1) * smooth_split_out_features, :
        ]
        if para_linear.bias is not None:
            para_linear.bias[:, i * smooth_split_out_features : (i + 1) * smooth_split_out_features] = bias[i][
                :, tp_rank * smooth_split_out_features : (tp_rank + 1) * smooth_split_out_features
            ]


def split_column_copy(smooth_linear, para_linear, tp_rank=0, split_num=1):
    """Copy this rank's shard of ``smooth_linear`` into ``para_linear``, splitting the
    weight along the input dimension. The bias, if any, is copied whole.
    """
    qweights = smooth_linear.weight.split(smooth_linear.in_features // split_num, dim=-1)

    smooth_split_in_features = para_linear.in_features // split_num
    for i in range(split_num):
        para_linear.weight[:, i * smooth_split_in_features : (i + 1) * smooth_split_in_features] = qweights[i][
            :, tp_rank * smooth_split_in_features : (tp_rank + 1) * smooth_split_in_features
        ]

    if smooth_linear.bias is not None:
        para_linear.bias.copy_(smooth_linear.bias)
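

# Usage note: `split_num` exists for fused projections. For example, a packed QKV weight of
# shape [3 * hidden, hidden] can be sharded with `split_row_copy(qkv, shard, tp_rank=r,
# split_num=3)`, which shards each of the three blocks independently so every rank keeps the
# fused Q/K/V layout. A small runnable sketch of both helpers is appended at the end of this
# file.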


class RowW8A8B8O8Linear(W8A8B8O8Linear, ParallelModule):
    """Output-dimension (weight-row) sharded variant of W8A8B8O8Linear."""

    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__(in_features, out_features, alpha, beta)
        self.process_group = None
        self.tp_size = 1
        self.tp_rank = 0

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        out_features = module.out_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        # too few output features to shard: keep the original module
        if out_features < tp_size:
            return module

        if out_features % tp_size != 0:
            raise ValueError(
                f"The size of out_features: {out_features} is not an integer multiple of the tensor parallel size: {tp_size}!"
            )

        linear_1d = RowW8A8B8O8Linear(module.in_features, module.out_features // tp_size)
        linear_1d.tp_size = tp_size
        linear_1d.tp_rank = tp_rank
        linear_1d.process_group = process_group
        linear_1d.a = module.a.clone().detach()
        linear_1d.b = module.b.clone().detach()

        split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)

        return linear_1d
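

# Note: the Row* variants shard the output dimension and need no communication in forward,
# while the Col* variants shard the input dimension, compute a partial result per rank and
# all-reduce it. In a Megatron-style tensor-parallel block the two are typically paired
# (Row* feeding Col*), but the exact mapping is decided by the calling shard policy.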


class ColW8A8B8O8Linear(W8A8B8O8Linear, ParallelModule):
    """Input-dimension (weight-column) sharded variant of W8A8B8O8Linear; the partial
    outputs are summed with an all-reduce in forward."""

    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__(in_features, out_features, alpha, beta)
        self.process_group = None
        self.tp_size = 1
        self.tp_rank = 0

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        in_features = module.in_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        # too few input features to shard: keep the original module
        if in_features < tp_size:
            return module

        if in_features % tp_size != 0:
            raise ValueError(
                f"The size of in_features: {in_features} is not an integer multiple of the tensor parallel size: {tp_size}!"
            )

        linear_1d = ColW8A8B8O8Linear(module.in_features // tp_size, module.out_features)
        linear_1d.tp_size = tp_size
        linear_1d.tp_rank = tp_rank
        linear_1d.process_group = process_group
        linear_1d.a = module.a.clone().detach()
        linear_1d.b = module.b.clone().detach()

        split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
        if linear_1d.bias is not None:
            # each rank keeps bias // tp_size so the all-reduce in forward reconstructs
            # (up to int8 rounding) the full bias exactly once
            linear_1d.bias = linear_1d.bias // tp_size

        return linear_1d

    @torch.no_grad()
    def forward(self, x):
        output = super().forward(x)
        if self.tp_size > 1:
            dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
        return output


class RowW8A8BFP32O32LinearSiLU(W8A8BFP32O32LinearSiLU, ParallelModule):
    """Output-dimension (weight-row) sharded variant of W8A8BFP32O32LinearSiLU."""

    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__(in_features, out_features, alpha, beta)
        self.process_group = None
        self.tp_size = 1
        self.tp_rank = 0

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        out_features = module.out_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        # too few output features to shard: keep the original module
        if out_features < tp_size:
            return module

        if out_features % tp_size != 0:
            raise ValueError(
                f"The size of out_features: {out_features} is not an integer multiple of the tensor parallel size: {tp_size}!"
            )

        linear_1d = RowW8A8BFP32O32LinearSiLU(module.in_features, module.out_features // tp_size)
        linear_1d.tp_size = tp_size
        linear_1d.tp_rank = tp_rank
        linear_1d.process_group = process_group
        linear_1d.a = module.a.clone().detach()

        split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)

        return linear_1d


class RowW8A8BFP32OFP32Linear(W8A8BFP32OFP32Linear, ParallelModule):
    """Output-dimension (weight-row) sharded variant of W8A8BFP32OFP32Linear."""

    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__(in_features, out_features, alpha, beta)
        self.process_group = None
        self.tp_size = 1
        self.tp_rank = 0

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        out_features = module.out_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        # too few output features to shard: keep the original module
        if out_features < tp_size:
            return module

        if out_features % tp_size != 0:
            raise ValueError(
                f"The size of out_features: {out_features} is not an integer multiple of the tensor parallel size: {tp_size}!"
            )

        linear_1d = RowW8A8BFP32OFP32Linear(module.in_features, module.out_features // tp_size)
        linear_1d.tp_size = tp_size
        linear_1d.tp_rank = tp_rank
        linear_1d.process_group = process_group
        linear_1d.a = module.a.clone().detach()

        split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)

        return linear_1d


class ColW8A8BFP32OFP32Linear(W8A8BFP32OFP32Linear, ParallelModule):
    """Input-dimension (weight-column) sharded variant of W8A8BFP32OFP32Linear; the partial
    outputs are summed with an all-reduce in forward."""

    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__(in_features, out_features, alpha, beta)
        self.process_group = None
        self.tp_size = 1
        self.tp_rank = 0

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        in_features = module.in_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        # too few input features to shard: keep the original module
        if in_features < tp_size:
            return module

        if in_features % tp_size != 0:
            raise ValueError(
                f"The size of in_features: {in_features} is not an integer multiple of the tensor parallel size: {tp_size}!"
            )

        linear_1d = ColW8A8BFP32OFP32Linear(module.in_features // tp_size, module.out_features)
        linear_1d.tp_size = tp_size
        linear_1d.tp_rank = tp_rank
        linear_1d.process_group = process_group
        linear_1d.a = module.a.clone().detach()

        split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
        if linear_1d.bias is not None:
            # each rank keeps bias / tp_size so the all-reduce in forward adds the bias exactly once
            linear_1d.bias = linear_1d.bias / tp_size

        return linear_1d

    @torch.no_grad()
    def forward(self, x):
        output = super().forward(x)
        if self.tp_size > 1:
            dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
        return output
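

# Minimal runnable sketch of the two split helpers (illustrative only). The SimpleNamespace
# objects below are hypothetical stand-ins for the quantized linear layers, which keep their
# int8 weights in buffers of shape (out_features, in_features). Run it as a module so the
# relative import above resolves, e.g.
#   python -m colossalai.inference.quant.smoothquant.models.parallel_linear
if __name__ == "__main__":
    from types import SimpleNamespace

    full = SimpleNamespace(
        weight=torch.arange(12 * 8, dtype=torch.int8).reshape(12, 8),
        bias=None,
        in_features=8,
        out_features=12,
    )

    # Row split: rank 1 of a 2-way tensor-parallel group keeps output rows 6..11.
    row_shard = SimpleNamespace(
        weight=torch.empty(6, 8, dtype=torch.int8), bias=None, in_features=8, out_features=6
    )
    split_row_copy(full, row_shard, tp_rank=1, split_num=1)
    assert torch.equal(row_shard.weight, full.weight[6:12])

    # Column split: rank 0 of a 2-way tensor-parallel group keeps input columns 0..3.
    col_shard = SimpleNamespace(
        weight=torch.empty(12, 4, dtype=torch.int8), bias=None, in_features=4, out_features=12
    )
    split_column_copy(full, col_shard, tp_rank=0, split_num=1)
    assert torch.equal(col_shard.weight, full.weight[:, :4])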