diff --git a/LICENSE b/LICENSE index 06629068f..59d456c5b 100644 --- a/LICENSE +++ b/LICENSE @@ -428,3 +428,52 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + ---------------- LICENSE FOR AutoGPTQ ---------------- + + From AutoGPTQ: + + MIT License + + Copyright (c) 2023 潘其威(William) + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + ---------------- LICENSE FOR exllama ---------------- + + From exllama: + + MIT License + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. 
diff --git a/colossalai/inference/quant/gptq/__init__.py b/colossalai/inference/quant/gptq/__init__.py new file mode 100644 index 000000000..c035f3979 --- /dev/null +++ b/colossalai/inference/quant/gptq/__init__.py @@ -0,0 +1,4 @@ +from .cai_gptq import HAS_AUTO_GPTQ + +if HAS_AUTO_GPTQ: + from .cai_gptq import CaiGPTQLinearOp, CaiQuantLinear diff --git a/colossalai/inference/quant/gptq/cai_gptq/__init__.py b/colossalai/inference/quant/gptq/cai_gptq/__init__.py new file mode 100644 index 000000000..de57f2d8c --- /dev/null +++ b/colossalai/inference/quant/gptq/cai_gptq/__init__.py @@ -0,0 +1,13 @@ +import warnings + +HAS_AUTO_GPTQ = False +try: + import auto_gptq + HAS_AUTO_GPTQ = True +except ImportError: + warnings.warn('please install auto-gptq from https://github.com/PanQiWei/AutoGPTQ') + HAS_AUTO_GPTQ = False + +if HAS_AUTO_GPTQ: + from .cai_quant_linear import CaiQuantLinear, ColCaiQuantLinear, RowCaiQuantLinear + from .gptq_op import CaiGPTQLinearOp diff --git a/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py b/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py new file mode 100644 index 000000000..ca12c34ed --- /dev/null +++ b/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py @@ -0,0 +1,354 @@ +# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ + +import math +import warnings +from typing import List, Union + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import ProcessGroup + +from colossalai.lazy import LazyInitContext +from colossalai.shardformer.layer import ParallelModule + +from .gptq_op import CaiGPTQLinearOp + +HAS_GPTQ_CUDA = False +try: + from colossalai.kernel.op_builder.gptq import GPTQBuilder + gptq_cuda = GPTQBuilder().load() + HAS_GPTQ_CUDA = True +except ImportError: + warnings.warn('CUDA gptq is not installed') + HAS_GPTQ_CUDA = False + + +class CaiQuantLinear(nn.Module): + + def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): + super().__init__() + if bits not in [2, 4, 8]: + raise NotImplementedError("Only 2,4,8 bits are supported.") + self.infeatures = infeatures + self.outfeatures = outfeatures + self.bits = bits + self.maxq = 2**self.bits - 1 + self.groupsize = groupsize if groupsize != -1 else infeatures + + self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) + self.register_buffer( + 'qzeros', + torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)) + self.register_buffer('scales', + torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)) + if row_split: + self.register_buffer( + 'g_idx', + torch.tensor([(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], + dtype=torch.int32)) + else: + self.register_buffer('g_idx', + torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)) + + if bias: + self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16)) + else: + self.bias = None + + self.gptq_linear = CaiGPTQLinearOp(groupsize, bits) + + self.q4 = None + self.empty_tensor = torch.empty((1, 1), device="meta") + self.tp_size = tp_size + self.tp_rank = tp_rank + self.row_split = row_split + + def pack(self, linear, scales, zeros, g_idx=None): + + g_idx = g_idx.clone() if g_idx is not None else torch.tensor( + [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32) + 
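+        # Quantize and pack `linear`'s fp16 weight in place: every value is rounded to an
+        # integer with the scale/zero of its input channel's group, then 32 // self.bits
+        # integers are packed into each int32 of qweight (qzeros is packed the same way
+        # from zeros - 1).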
+ scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + half_scales = scales.clone().half() + # print("scale shape ", scales.shape, scale_zeros.shape, linear.weight.shape) + self.scales = scales.clone().half() + if linear.bias is not None: + self.bias = linear.bias.clone().half() + + wn = 8 + pbits = 32 + ptype = torch.int32 + unsign_type = np.uint32 + sign_type = np.int32 + + intweight = [] + for idx in range(self.infeatures): + intweight.append( + torch.round( + (linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:, + None]) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(unsign_type) + qweight = np.zeros((intweight.shape[0] // pbits * self.bits, intweight.shape[1]), dtype=unsign_type) + + i = 0 + row = 0 + + while row < qweight.shape[0]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (pbits // self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += pbits // self.bits + row += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + qweight = qweight.astype(sign_type) + qweight1 = torch.from_numpy(qweight) + qweight1 = qweight1.contiguous() #.to("cuda") + self.qweight.data.copy_(qweight1) + + qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type) + zeros -= 1 + zeros = zeros.numpy().astype(unsign_type) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (pbits // self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += pbits // self.bits + col += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + qzeros = qzeros.astype(sign_type) + qzeros = torch.from_numpy(qzeros) + qzeros = qzeros + self.qzeros.data.copy_(qzeros) + + if torch.equal(self.g_idx.to(g_idx.device), g_idx): + self.g_idx = None + else: + self.g_idx = g_idx + + def init_q4(self): + assert self.qweight.device.type == "cuda" + self.q4_width = self.qweight.shape[1] + if self.g_idx is not None: + if self.row_split and torch.equal( + self.g_idx, + torch.tensor( + [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)], + dtype=torch.int32, + device=self.g_idx.device)): + self.g_idx = None + elif torch.equal( + self.g_idx, + torch.tensor([i // self.groupsize for i in range(self.infeatures)], + dtype=torch.int32, + device=self.g_idx.device)): + self.g_idx = None + + if self.g_idx is not None: + g_idx = self.g_idx.to("cpu") + else: + g_idx = self.empty_tensor + + self.q4 = gptq_cuda.make_q4(self.qweight, self.qzeros, self.scales, g_idx, torch.cuda.current_device()) + torch.cuda.synchronize() + + def forward(self, x): + outshape = x.shape[:-1] + (self.outfeatures,) + + if HAS_GPTQ_CUDA and self.bits == 4: + + if self.q4 is None: + self.init_q4() + + x = x.view(-1, x.shape[-1]) + output = torch.empty((x.shape[0], self.outfeatures), dtype=torch.float16, device=x.device) + gptq_cuda.q4_matmul(x.half(), self.q4, output) + if self.bias is not None and (not self.row_split or self.tp_size == 1): + output.add_(self.bias) + else: + if self.bias is not None and (not self.row_split or self.tp_size == 1): + bias = self.bias + else: + bias = None + output = self.gptq_linear( + x, + self.qweight, + self.scales, + self.qzeros, + g_idx=self.g_idx, + bias=bias, + ) + return output.view(outshape) + + +def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1): + + 
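+    # Copy a column-parallel shard of an AutoGPTQ linear into `cai_linear`: qweight, qzeros,
+    # scales and bias are sliced along the output dimension for this tp_rank. split_num > 1
+    # treats the weight as that many fused sub-blocks (e.g. a fused query_key_value projection)
+    # and shards each block separately; g_idx is shared across ranks and copied unchanged.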
qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1) + qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1) + scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1) + g_idx = gptq_linear.g_idx + if gptq_linear.bias is not None: + bias = gptq_linear.bias.split(gptq_linear.out_features // split_num, dim=-1) + + cai_split_out_features = cai_linear.outfeatures // split_num + zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num + + for i in range(split_num): + cai_linear.qweight[:, i * cai_split_out_features:(i + 1) * + cai_split_out_features] = qweights[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) * + cai_split_out_features] + cai_linear.qzeros[:, i * zero_split_block:(i + 1) * + zero_split_block] = qzeros[i][:, tp_rank * zero_split_block:(tp_rank + 1) * zero_split_block] + cai_linear.scales[:, i * cai_split_out_features:(i + 1) * + cai_split_out_features] = scales[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) * + cai_split_out_features] + if cai_linear.bias is not None: + cai_linear.bias[i * cai_split_out_features:(i + 1) * + cai_split_out_features] = bias[i][tp_rank * cai_split_out_features:(tp_rank + 1) * + cai_split_out_features] + + cai_linear.g_idx.copy_(g_idx) + + +def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1): + + qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0) + qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0) + scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0) + g_idxs = gptq_linear.g_idx.split(gptq_linear.in_features // split_num, dim=0) + + cai_split_in_features = cai_linear.infeatures // (32 // cai_linear.bits) // split_num + zero_split_block = cai_linear.infeatures // cai_linear.groupsize // split_num + idx_split_features = cai_linear.infeatures // split_num + + for i in range(split_num): + cai_linear.qweight[i * cai_split_in_features:(i + 1) * + cai_split_in_features, :] = qweights[i][tp_rank * cai_split_in_features:(tp_rank + 1) * + cai_split_in_features, :] + cai_linear.qzeros[i * zero_split_block:(i + 1) * + zero_split_block, :] = qzeros[i][tp_rank * zero_split_block:(tp_rank + 1) * + zero_split_block, :] + cai_linear.scales[i * zero_split_block:(i + 1) * + zero_split_block, :] = scales[i][tp_rank * zero_split_block:(tp_rank + 1) * + zero_split_block, :] + cai_linear.g_idx[i * idx_split_features:(i + 1) * + idx_split_features] = g_idxs[i][tp_rank * idx_split_features:(tp_rank + 1) * + idx_split_features] + if cai_linear.bias is not None: + cai_linear.bias.copy_(gptq_linear.bias) + + +class RowCaiQuantLinear(CaiQuantLinear, ParallelModule): + + def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): + + super().__init__(bits, + groupsize, + infeatures, + outfeatures, + bias, + tp_size=tp_size, + tp_rank=tp_rank, + row_split=row_split) + self.process_group = None + + @staticmethod + def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + LazyInitContext.materialize(module) + # get the attributes + in_features = module.in_features + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' 
+ process_group = process_group[0] + + tp_size = dist.get_world_size(process_group) + tp_rank = dist.get_rank(process_group) + + if in_features < tp_size: + return module + + if in_features % tp_size != 0: + raise ValueError( + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!") + linear_1d = RowCaiQuantLinear(module.bits, + module.group_size, + module.in_features // tp_size, + module.out_features, + module.bias is not None, + tp_size=tp_size, + tp_rank=tp_rank, + row_split=True) + linear_1d.process_group = process_group + + split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) + return linear_1d + + def forward(self, x): + output = super().forward(x) + if self.tp_size > 1: + dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group) + if self.bias is not None: + output.add_(self.bias) + return output + + +class ColCaiQuantLinear(CaiQuantLinear, ParallelModule): + + def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): + + super().__init__(bits, + groupsize, + infeatures, + outfeatures, + bias, + tp_size=tp_size, + tp_rank=tp_rank, + row_split=row_split) + self.process_group = None + + @staticmethod + def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + LazyInitContext.materialize(module) + # get the attributes + in_features = module.in_features + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + tp_size = dist.get_world_size(process_group) + tp_rank = dist.get_rank(process_group) + + if in_features < tp_size: + return module + + if in_features % tp_size != 0: + raise ValueError( + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!") + linear_1d = ColCaiQuantLinear(module.bits, + module.group_size, + module.in_features, + module.out_features // tp_size, + module.bias is not None, + tp_size=tp_size, + tp_rank=tp_rank) + linear_1d.process_group = process_group + + split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) + return linear_1d diff --git a/colossalai/inference/quant/gptq/cai_gptq/gptq_op.py b/colossalai/inference/quant/gptq/cai_gptq/gptq_op.py new file mode 100644 index 000000000..a8902eb35 --- /dev/null +++ b/colossalai/inference/quant/gptq/cai_gptq/gptq_op.py @@ -0,0 +1,58 @@ +import torch + +from colossalai.kernel.triton import gptq_fused_linear_triton + + +class CaiGPTQLinearOp(torch.nn.Module): + def __init__(self, gptq_group_size, gptq_quant_bits): + super(CaiGPTQLinearOp, self).__init__() + self.group_size = gptq_group_size + self.bits = gptq_quant_bits + self.maxq = 2**self.bits - 1 + self.empty_tensor = torch.zeros(4, device=torch.cuda.current_device()) + + def forward( + self, + input: torch.Tensor, + weight: torch.Tensor, + weight_scales: torch.Tensor, + weight_zeros: torch.Tensor, + g_idx: torch.Tensor = None, + act_type=0, + bias: torch.Tensor = None, + residual: torch.Tensor = None, + qkv_fused=False, + ): + add_bias = True + if bias is None: + bias = self.empty_tensor + add_bias = False + + add_residual = True + if residual is None: + residual = self.empty_tensor + add_residual = False + x = input.view(-1, input.shape[-1]) + + out = gptq_fused_linear_triton( + x, + weight, + weight_scales, + weight_zeros, + bias, + residual, + 
self.bits, + self.maxq, + self.group_size, + qkv_fused, + add_bias, + add_residual, + act_type=act_type, + g_idx=g_idx, + ) + if qkv_fused: + out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1]) + else: + out = out.view(input.shape[0], input.shape[1], weight.shape[-1]) + + return out diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 1335f13d6..29b5d6117 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -1,6 +1,7 @@ from typing import Any, Callable, List, Optional, Union import torch +import torch.distributed as dist import torch.nn as nn from transformers import BloomForCausalLM, LlamaForCausalLM from transformers.generation import GenerationConfig @@ -68,6 +69,13 @@ class TPInferEngine: self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config self.cache_manager = None + self.max_dq_buffer_size = 1 + self.max_inner_outer_dim = 1 + self.gptq_temp_state_buffer = None + self.gptq_temp_dq_buffer = None + self.bits = -1 + self.use_act_order = False + self.shard_config = shard_config self.model = None # optimize the original model by sharding with ShardFormer @@ -81,6 +89,50 @@ class TPInferEngine: self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num ) + def _post_init_gptq_buffer(self, model: nn.Module) -> None: + from colossalai.inference.quant.gptq.cai_gptq import CaiQuantLinear + HAS_GPTQ_CUDA = False + try: + from colossalai.kernel.op_builder.gptq import GPTQBuilder + gptq_cuda = GPTQBuilder().load() + HAS_GPTQ_CUDA = True + except ImportError: + warnings.warn('CUDA gptq is not installed') + HAS_GPTQ_CUDA = False + + for name, submodule in model.named_modules(): + if isinstance(submodule, CaiQuantLinear): + self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8) + + if self.use_act_order: + self.max_inner_outer_dim = max(self.max_inner_outer_dim, submodule.infeatures, + submodule.outfeatures) + self.bits = submodule.bits + if not (HAS_GPTQ_CUDA and self.bits == 4): + return + + max_input_len = 1 + if self.use_act_order: + max_input_len = self.max_input_len + # The temp_state buffer is required to reorder X in the act-order case. + # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. + self.gptq_temp_state_buffer = torch.zeros((max_input_len, self.max_inner_outer_dim), + dtype=torch.float16, + device=torch.cuda.current_device()) + self.gptq_temp_dq_buffer = torch.zeros((1, self.max_dq_buffer_size), + dtype=torch.float16, + device=torch.cuda.current_device()) + + gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, + self.gptq_temp_dq_buffer) + # Using the default from exllama repo here. + matmul_recons_thd = 8 + matmul_fused_remap = False + matmul_no_half2 = False + gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) + + torch.cuda.empty_cache() + def _optimize_model(self, model: nn.Module) -> None: """ Optimize the original model by sharding with ShardFormer. @@ -129,6 +181,10 @@ class TPInferEngine: assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference." 
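+        # With inference_gptq enabled, the inference policy below replaces Linear modules with
+        # Col/RowCaiQuantLinear, and the exllama-style GPTQ scratch buffers are prepared right
+        # after sharding (see _post_init_gptq_buffer).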
policy = get_autopolicy(model, inference_only=True) self.model, _ = shardformer.optimize(model, policy) + + if self.shard_config.inference_gptq: + self._post_init_gptq_buffer(model) + self.model = self.model.cuda() @property diff --git a/colossalai/inference/tensor_parallel/policies/bloom.py b/colossalai/inference/tensor_parallel/policies/bloom.py index 2d18a3922..3d6df2097 100644 --- a/colossalai/inference/tensor_parallel/policies/bloom.py +++ b/colossalai/inference/tensor_parallel/policies/bloom.py @@ -3,6 +3,9 @@ from functools import partial import torch from torch.nn import LayerNorm +import colossalai.shardformer.layer as col_nn +from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy from ..modeling.bloom import BloomInferenceForwards @@ -35,6 +38,35 @@ class BloomModelInferPolicy(BloomForCausalLMPolicy): from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel policy = super().module_policy() + if self.shard_config.inference_gptq: + from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear + policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={ + "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=ColCaiQuantLinear, + kwargs={'split_num': 3}), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=RowCaiQuantLinear, + kwargs={'split_num': 1}), + SubModuleReplacementDescription( + suffix="self_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", + target_module=ColCaiQuantLinear, + kwargs={'split_num': 1}), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", + target_module=RowCaiQuantLinear, + kwargs={'split_num': 1}), + ]) # NOTE set inference mode to shard config self.shard_config._infer() diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py index 9bbb547db..eaaadadd1 100644 --- a/colossalai/inference/tensor_parallel/policies/llama.py +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -3,6 +3,8 @@ from functools import partial import torch from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm +from colossalai.shardformer.layer import VocabParallelEmbedding1D +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription # import colossalai from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy @@ -34,6 +36,55 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy): def module_policy(self): policy = super().module_policy() + + if self.shard_config.inference_gptq: + from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear + + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // 
self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + } + policy[LlamaDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=ColCaiQuantLinear, + kwargs={'split_num': 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=ColCaiQuantLinear, + kwargs={'split_num': 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=ColCaiQuantLinear, + kwargs={'split_num': 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=RowCaiQuantLinear, + kwargs={'split_num': 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=ColCaiQuantLinear, + kwargs={'split_num': 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=ColCaiQuantLinear, + kwargs={'split_num': 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=RowCaiQuantLinear, + kwargs={'split_num': 1}, + ) + ], + ) + self.shard_config._infer() infer_forward = LlamaInferenceForwards.llama_model_forward diff --git a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu new file mode 100644 index 000000000..2b1b366b1 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu @@ -0,0 +1,63 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include "column_remap.cuh" +#include "util.cuh" + +const int SHUF_BLOCKSIZE_X = 256; +const int SHUF_BLOCKSIZE_Y = 16; + +__global__ void column_remap_kernel +( + const half* __restrict__ x, + half* __restrict__ x_new, + const int x_width, + const int x_height, + const uint32_t* x_map +) +{ + int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; + int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y; + if (x_column >= x_width) return; + //if (x_row >= x_height) return; + + int x_stride = x_width; + int x_idx = x_row * x_stride + x_column; + + int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height); + int x_idx_end = x_row_end * x_stride + x_column; + + int s_column = x_map[x_column]; + int s_idx = x_row * x_stride + s_column; + + while (x_idx < x_idx_end) + { + x_new[x_idx] = x[s_idx]; + x_idx += x_stride; + s_idx += x_stride; + } +} + +// Remap columns in x to correspond to sequential group index before matmul +// +// perform x -> seq_x such that seq_x @ seq_w == x @ w + +void column_remap_cuda +( + const half* x, + half* x_new, + const int x_height, + const int x_width, + const uint32_t* x_map +) +{ + dim3 threads(SHUF_BLOCKSIZE_X, 1, 1); + + dim3 blocks + ( + (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X, + (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y, + 1 + ); + + column_remap_kernel<<>>(x, x_new, x_width, x_height, x_map); +} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh new file mode 100644 index 000000000..6571c17d6 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh @@ -0,0 +1,19 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _column_remap_cuh +#define _column_remap_cuh + +#include +#include +#include + +void column_remap_cuda +( + const half* x, + half* x_new, + const int x_height, + const int x_width, + const uint32_t* 
x_map +); + +#endif \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh b/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh new file mode 100644 index 000000000..c5258813e --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh @@ -0,0 +1,58 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _cuda_compat_cuh +#define _cuda_compat_cuh + +// atomicAdd for half types, to support CC < 7.x + +__device__ __forceinline__ void atomicAdd_half(half* address, half val) +{ + unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do + { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } + while (assumed != old); +} + +// atomicAdd for half2 types + +__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) +{ + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do + { + assumed = old; + half2 old_val = *((half2*)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); + } + while (assumed != old); +} + +// + +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) +#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) + +__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } + +#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) +__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } +#endif + +#endif +#endif + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu new file mode 100644 index 000000000..4416027c8 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu @@ -0,0 +1,75 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#define _cuda_buffers_cu +#include "cuda_buffers.cuh" + +CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL}; +// __constant__ half2 q4_table[16][256]; +// half2 q4_table_host[16][256]; +// bool q4_table_init = false; + +CudaBuffers::CudaBuffers +( + int _device, + int _temp_state_size, + half* _temp_state, + half* _temp_dq +) : + device(_device), + temp_state_size(_temp_state_size), + temp_state(_temp_state), + temp_dq(_temp_dq) +{ + cudaSetDevice(_device); + + cudaStreamCreate(&alt_stream_1); + cudaStreamCreate(&alt_stream_2); + cudaStreamCreate(&alt_stream_3); + cudaEventCreate(&alt_stream_1_done); + cudaEventCreate(&alt_stream_2_done); + cudaEventCreate(&alt_stream_3_done); +} + +CudaBuffers::~CudaBuffers() +{ + cudaStreamDestroy(alt_stream_1); + cudaStreamDestroy(alt_stream_2); + cudaStreamDestroy(alt_stream_3); + cudaEventDestroy(alt_stream_1_done); + cudaEventDestroy(alt_stream_2_done); + cudaEventDestroy(alt_stream_3_done); +} + +CudaBuffers* get_buffers(const int device_index) +{ + return g_buffers[device_index]; +} + +void prepare_buffers_cuda +( + int _device, + int _temp_state_size, + half* _temp_state, + half* _temp_dq +) +{ + CudaBuffers* buffers = new CudaBuffers + ( + _device, + _temp_state_size, + _temp_state, + _temp_dq + ); + + g_buffers[_device] = buffers; +} + +void 
cleanup_buffers_cuda() +{ + for (int i = 0; i < CUDA_MAX_DEVICES; i++) + { + if (!g_buffers[i]) continue; + delete g_buffers[i]; + g_buffers[i] = NULL; + } +} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh new file mode 100644 index 000000000..0bf2057c6 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh @@ -0,0 +1,55 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _cuda_buffers_cuh +#define _cuda_buffers_cuh + +#include +#include +#include +#include + +const int CUDA_MAX_DEVICES = 16; + +// #ifndef _cuda_buffers_cu +// extern __constant__ half2 q4_table[16][256]; +// #endif + +class CudaBuffers +{ +public: + int device; + + half* temp_state; // [max_hidden_rows * intermediate_size] + int temp_state_size; + half* temp_dq; // size of largest quant tensor * 8 + + cudaStream_t alt_stream_1; + cudaStream_t alt_stream_2; + cudaStream_t alt_stream_3; + cudaEvent_t alt_stream_1_done; + cudaEvent_t alt_stream_2_done; + cudaEvent_t alt_stream_3_done; + + CudaBuffers + ( + int _device, + int _temp_state_size, + half* _temp_state, + half* _temp_dq + ); + ~CudaBuffers(); +}; + +CudaBuffers* get_buffers(const int device_index); + +void prepare_buffers_cuda +( + int _device, + int _temp_state_size, + half* _temp_state, + half* _temp_dq +); + +void cleanup_buffers_cuda(); + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh b/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh new file mode 100644 index 000000000..5cd2e8553 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh @@ -0,0 +1,49 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _hip_compat_cuh +#define _hip_compat_cuh + +// Workaround for a bug in hipamd, backported from upstream. +__device__ __forceinline__ __half __compat_hrcp(__half x) { + return __half_raw{ + static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))}; +} + +__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) { + return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)), + static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))}; +} + +#define hrcp __compat_hrcp +#define h2rcp __compat_h2rcp + +// Workaround for hipify_python using rocblas instead of hipblas. 
+__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, + hipblasOperation_t transA, + hipblasOperation_t transB, + int m, + int n, + int k, + const half* alpha, + const half* AP, + int lda, + const half* BP, + int ldb, + const half* beta, + half* CP, + int ldc) { + return hipblasHgemm(handle, transA, transB, m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(AP), lda, + reinterpret_cast(BP), ldb, + reinterpret_cast(beta), + reinterpret_cast(CP), ldc); +} + +#define rocblas_handle hipblasHandle_t +#define rocblas_operation_none HIPBLAS_OP_N +#define rocblas_get_stream hipblasGetStream +#define rocblas_set_stream hipblasSetStream +#define rocblas_hgemm __compat_hipblasHgemm + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp b/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp new file mode 100644 index 000000000..bcc0e4390 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp @@ -0,0 +1,254 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include +#include +#include +#include +#include +#include +#include +#include "util.cuh" +#include "tuning.h" +#include "cuda_buffers.cuh" +#include "q4_matrix.cuh" +#include "q4_matmul.cuh" +#include "column_remap.cuh" + +// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a +// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of +// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console. + +void check_cuda(cudaError_t ret) +{ + switch (ret) + { + case cudaSuccess: + break; + + case cudaUnspecified: + printf(" **** Unspecified error\n"); + TORCH_CHECK(false, "CUDA error"); + break; + + default: + printf(" **** CUDA error\n"); \ + printf(" **** %s\n", cudaGetErrorString(ret)); \ + TORCH_CHECK(false, "CUDA error"); \ + break; + } +} + +// Some decluttering macros + +#define STRINGIFY_(__x) #__x +#define STRINGIFY(__x) STRINGIFY_(__x) +#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod)) +#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small") + +#define TORCH_CHECK_DEVICE_INDEX(__index) \ +do { \ + TORCH_CHECK(__index >= 0, "no device index"); \ + TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \ +} while(0) + +#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \ +do { \ + TORCH_CHECK_DTYPE(__w, kInt); \ + TORCH_CHECK_DTYPE(__w_scales, kHalf); \ + TORCH_CHECK_DTYPE(__w_zeros, kInt); \ + TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \ + 
TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \ + TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \ + TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \ +} while(0) + +int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) +{ + int groupsize = w.size(0) * 8 / w_zeros.size(0); + TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]") + return groupsize; +} + + +// Tuning parameters + +ExLlamaTuning tuningParams; + +void set_tuning_params +( + int matmul_recons_thd, + bool matmul_fused_remap, + bool matmul_no_half2 +) +{ + tuningParams.matmul_recons_thd = matmul_recons_thd; + tuningParams.matmul_fused_remap = matmul_fused_remap; + tuningParams.matmul_no_half2 = matmul_no_half2; +} + + +// Release all unmanaged objects allocated by the extension + +void cleanup() +{ + cleanup_buffers_cuda(); + g_q4_free_matrices(); +} + + +// Prepare buffers for forward pass + +void prepare_buffers +( + torch::Device device, + torch::Tensor temp_state, + torch::Tensor temp_dq +) +{ + int device_index = device.index(); + TORCH_CHECK_DEVICE_INDEX(device_index); + const at::cuda::OptionalCUDAGuard device_guard(device); + + prepare_buffers_cuda + ( + device_index, + // buffer size used for sanity checks + temp_state.numel(), + (half*) temp_state.data_ptr(), + (half*) temp_dq.data_ptr() + ); +} + + +// Create Q4Matrix, return handle + +uintptr_t make_q4 +( + torch::Tensor qweight, + torch::Tensor qzeros, + torch::Tensor scales, + torch::Tensor g_idx, + int device +) +{ + TORCH_CHECK_DTYPE(qweight, kInt); + TORCH_CHECK_DTYPE(qzeros, kInt); + TORCH_CHECK_DTYPE(scales, kHalf); + TORCH_CHECK_DTYPE_OPT(g_idx, kInt); + TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8); + TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1); + TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1); + + int width = qweight.size(1); + int height = qweight.size(0) * 8; + int groups = qzeros.size(0); + + Q4Matrix* m = new Q4Matrix + ( + height, + width, + groups, + + (uint32_t*) qweight.data_ptr(), + (uint32_t*) qzeros.data_ptr(), + (half*) scales.data_ptr(), + g_idx.device().is_meta() ? 
NULL : (uint32_t*) g_idx.data_ptr(), + + device + ); + + g_q4_keep_matrix(m); + return reinterpret_cast (m); +} + + +// Matmul half @ quant -> half + +void q4_matmul +( + torch::Tensor x, + uintptr_t w, + torch::Tensor out +) +{ + Q4Matrix* wm = reinterpret_cast (w); + + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(out, kHalf); + TORCH_CHECK_SHAPES(x, 0, out, 0, 1); + TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes") + + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + int x_height = x.size(0); + + if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd) + { + q4_matmul_cuda + ( + &tuningParams, + (half*) x.data_ptr(), + x_height, + wm, + (half*) out.data_ptr() + ); + } + else + { + q4_matmul_recons_cuda + ( + &tuningParams, + (half*) x.data_ptr(), + x_height, + wm, + (half*) out.data_ptr(), + at::cuda::getCurrentCUDABlasHandle() + ); + } +} + + +// Remap columns in half tensor + +void column_remap +( + torch::Tensor x, + torch::Tensor x_new, + torch::Tensor x_map +) +{ + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(x_new, kHalf); + TORCH_CHECK_DTYPE(x_map, kInt); + TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1); + + int height = x.size(0); + int width = x.size(1); + + TORCH_CHECK_BUFFER_SIZE(x_new, height * width); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + column_remap_cuda + ( + (half*) x.data_ptr(), + (half*) x_new.data_ptr(), + height, + width, + (uint32_t*) x_map.data_ptr() + ); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("set_tuning_params", &set_tuning_params, "set_tuning_params"); + m.def("prepare_buffers", &prepare_buffers, "prepare_buffers"); + m.def("cleanup", &cleanup, "cleanup"); + m.def("make_q4", &make_q4, "make_q4"); + m.def("q4_matmul", &q4_matmul, "q4_matmul"); +} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh new file mode 100644 index 000000000..2fd5ab0b3 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh @@ -0,0 +1,294 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _matrix_cuh +#define _matrix_cuh + +#include +#include + +class MatrixView_half +{ +public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } +}; + +class MatrixView_half_rw +{ +public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } + __device__ 
__forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } +}; + +class MatrixView_q4_row +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } +}; + +class MatrixView_q4_column +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (row & 0x07) * 4; + return (data[row / 8 * width + column] >> shift) & 0x0f; + } + + __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } + __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } +}; + +// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu + +// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale + +__device__ __forceinline__ half2 dot_product_8 +( + const half2 acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half2 v_scale_2, + const uint32_t v_zero, // + 1 (!!) 
+ const int count +) +{ + const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column); + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half2 result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half2 v_01 = __halves2half2(v_0, v_1); + half2 v_23 = __halves2half2(v_2, v_3); + half2 v_45 = __halves2half2(v_4, v_5); + half2 v_67 = __halves2half2(v_6, v_7); + +// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently) +// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff]; +// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff]; +// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ]; + + half2 tmp = __hmul2(*h_ptr++, v_01); + tmp = __hfma2(*h_ptr++, v_23, tmp); + tmp = __hfma2(*h_ptr++, v_45, tmp); + tmp = __hfma2(*h_ptr++, v_67, tmp); + result = __hfma2(v_scale_2, tmp, result); + } + + return result; +} + +__device__ __forceinline__ half dot_product_8_h +( + const half acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half v_scale, + const uint32_t v_zero, // + 1 (!!) + const int count +) +{ + const half* h_ptr = h_.item_ptr(h_row, h_column); + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half tmp = __hmul(*h_ptr++, v_0); + tmp = __hfma(*h_ptr++, v_1, tmp); + tmp = __hfma(*h_ptr++, v_2, tmp); + tmp = __hfma(*h_ptr++, v_3, tmp); + tmp = __hfma(*h_ptr++, v_4, tmp); + tmp = __hfma(*h_ptr++, v_5, tmp); + tmp = __hfma(*h_ptr++, v_6, tmp); + tmp = __hfma(*h_ptr++, v_7, tmp); + result = __hfma(v_scale, tmp, result); + } + + return result; +} + +// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map + +__device__ __forceinline__ half2 dot_product_8_x_map +( + const half2 acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half2 v_scale_2, + const uint32_t v_zero, // + 1 (!!) 
+ const int count, + const uint32_t* x_map +) +{ + const half* h_ptr = h_.item_ptr(h_row, 0); + const uint32_t* x_map_ptr = x_map + h_column; + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half2 result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half2 v_01 = __halves2half2(v_0, v_1); + half2 v_23 = __halves2half2(v_2, v_3); + half2 v_45 = __halves2half2(v_4, v_5); + half2 v_67 = __halves2half2(v_6, v_7); + + half h_0 = h_ptr[*x_map_ptr++]; + half h_1 = h_ptr[*x_map_ptr++]; + half h_2 = h_ptr[*x_map_ptr++]; + half h_3 = h_ptr[*x_map_ptr++]; + half h_4 = h_ptr[*x_map_ptr++]; + half h_5 = h_ptr[*x_map_ptr++]; + half h_6 = h_ptr[*x_map_ptr++]; + half h_7 = h_ptr[*x_map_ptr++]; + + half2 h_01 = __halves2half2(h_0, h_1); + half2 h_23 = __halves2half2(h_2, h_3); + half2 h_45 = __halves2half2(h_4, h_5); + half2 h_67 = __halves2half2(h_6, h_7); + + half2 tmp = __hmul2(h_01, v_01); + tmp = __hfma2(h_23, v_23, tmp); + tmp = __hfma2(h_45, v_45, tmp); + tmp = __hfma2(h_67, v_67, tmp); + result = __hfma2(v_scale_2, tmp, result); + } + + return result; +} + +__device__ __forceinline__ half dot_product_8_x_map_h +( + const half acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half v_scale, + const uint32_t v_zero, // + 1 (!!) 
+ const int count, + const uint32_t* x_map +) +{ + const half* h_ptr = h_.item_ptr(h_row, 0); + const uint32_t* x_map_ptr = x_map + h_column; + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half tmp = __hmul(h_ptr[*x_map_ptr++], v_0); + tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp); + result = __hfma(v_scale, tmp, result); + } + + return result; +} + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu new file mode 100644 index 000000000..f47daeb0e --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu @@ -0,0 +1,260 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include "q4_matmul.cuh" +#include "column_remap.cuh" +#include "util.cuh" +#include "matrix.cuh" +#include "cu_compat.cuh" +#include "cuda_buffers.cuh" +#if defined(USE_ROCM) +#include "hip_compat.cuh" +#endif + +const int THREADS_X = 32; // Block size and thread count along columns in w and out +const int THREADS_Y = 1; // Block size and thread count along rows in x and out + +typedef void (*fp_q4_matmul_kernel) +( + const half*, + const uint32_t*, + half*, + const half*, + const uint32_t*, + const int, + const int, + const int, + const int, + const int, + const uint32_t*, + bool +); + +template +__global__ void q4_matmul_kernel +( + const half* __restrict__ x, + const uint32_t* __restrict__ w, + half* __restrict__ out, + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int height, + const int dim, + const int width, + const int groupsize, + const int block_size_z, + const uint32_t* __restrict__ x_map, + bool no_zero +) +{ + // Start of block + + int x_column = block_size_z * blockIdx.z; + int x_column_end = min(dim, block_size_z * (blockIdx.z + 1)); + + int w_column = THREADS_X * blockIdx.x + threadIdx.x; + int x_row = THREADS_Y * blockIdx.y + threadIdx.y; + + int iterations = (x_column_end - x_column) / 8; + + // Views + + MatrixView_half x_(x, height, dim); + MatrixView_half w_scales_(w_scales, dim / groupsize, width); + MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width); + MatrixView_q4_column w_(w, dim, width); + MatrixView_half_rw out_(out, height, width); + + // Zero output + + if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0) + { + *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0; + __syncthreads(); + } + + // Loop over part of x row (and w column) + + half2 acc = {}; + half acc_h = {}; + + if constexpr (use_groupsize) + { + // For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group 
boundary, so this + // could be slightly faster + + for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize) + { + if constexpr (use_half2) + { + half2 w_scale = w_scales_.item_half2half2(group, w_column); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); + else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); + } + else + { + half w_scale = w_scales_.item(group, w_column); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); + else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); + } + } + } + else + { + // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache + + for (int k = x_column; k < x_column + iterations * 8; k += 8) + { + if constexpr (use_half2) + { + int group = k / groupsize; + half2 w_scale = w_scales_.item_half2half2(group, w_column); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); + else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); + } + else + { + int group = k / groupsize; + half w_scale = w_scales_.item(group, w_column); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); + else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); + } + } + } + + // Add to block result + + if constexpr (use_half2) + { + half result = __hadd(__low2half(acc), __high2half(acc)); + atomicAdd(out_.item_ptr(x_row, w_column), result); + } + else + { + atomicAdd(out_.item_ptr(x_row, w_column), acc_h); + } +} + +fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map) +{ + // + if (tuningParams->matmul_no_half2) { + if (block_size_z % groupsize == 0) { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } else { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } + } else { + if (block_size_z % groupsize == 0) + { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } else { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } + } +}; + +// Compute y = x @ w + +void q4_matmul_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + const Q4Matrix* w, + half* out, + bool no_zero, + cudaStream_t alt_stream +) +{ + int height = x_height; + int dim = w->height; + int width = w->width; + + cudaSetDevice(w->device); + + uint32_t* x_map = w->cuda_x_map; + const half* x_mapped = x; + if (x_map && !tuningParams->matmul_fused_remap && !alt_stream) + { + CudaBuffers* buffers = get_buffers(w->device); + column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); + x_mapped = buffers->temp_state; + x_map = NULL; + } + + int block_size_z; + if (w->width == 4096) block_size_z = 384; // 7B + else if (w->width == 11008) block_size_z = 256; + else if (w->width == 5120) block_size_z = 384; // 13B + else if (w->width == 13824) block_size_z = 256; + else if (w->width == 6656) 
block_size_z = 256; // 33B + else if (w->width == 17920) block_size_z = 128; + else block_size_z = 256; + + //if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half)); + + dim3 threads(THREADS_X, THREADS_Y, 1); + + dim3 blocks + ( + (width + threads.x - 1) / threads.x, + (height + threads.y - 1) / threads.y, + (dim + block_size_z - 1) / block_size_z + ); + + fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map); + + kernel<<<blocks, threads, 0, alt_stream>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); +} + +void q4_matmul_recons_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + Q4Matrix* w, + half* out, + const cublasHandle_t handle, + bool no_zero +) +{ + int height = x_height; + int dim = w->height; + int width = w->width; + + cudaSetDevice(w->device); + CudaBuffers* buffers = get_buffers(w->device); + + const half* x_mapped = x; + if (w->cuda_x_map) + { + TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small"); + column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); + x_mapped = buffers->temp_state; + } + + w->reconstruct(buffers->temp_dq); + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700 + const float alpha = 1.0f; + const float beta = no_zero ? 1.0f : 0.0f; + cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width, + x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width); +#else + const half alpha = __float2half(1.0f); + const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f); + cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width); +#endif +} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh new file mode 100644 index 000000000..09f3e1a63 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh @@ -0,0 +1,43 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _q4_matmul_cuh +#define _q4_matmul_cuh + +#include <cuda_runtime.h> +#include <cuda_fp16.h> +#include <cstdint> +#include <cstdio> +#include <ATen/cuda/CUDAContext.h> + +#include "q4_matrix.cuh" +#include "tuning.h" + +// Workaround for hipify_python using rocblas instead of hipblas.
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh
new file mode 100644
index 000000000..09f3e1a63
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh
@@ -0,0 +1,43 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#ifndef _q4_matmul_cuh
+#define _q4_matmul_cuh
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+#include <cstdio>
+#include <ATen/cuda/CUDAContext.h>
+
+#include "q4_matrix.cuh"
+#include "tuning.h"
+
+// Workaround for hipify_python using rocblas instead of hipblas.
+#if defined(USE_ROCM)
+#include <hipblas/hipblas.h>
+#define rocblas_handle hipblasHandle_t
+#endif
+
+void q4_matmul_cuda
+(
+    ExLlamaTuning* tuningParams,
+    const half* x,
+    const int x_height,
+    const Q4Matrix* w,
+    half* out,
+    bool no_zero = false,
+    cudaStream_t alt_stream = NULL
+);
+
+void q4_matmul_recons_cuda
+(
+    ExLlamaTuning* tuningParams,
+    const half* x,
+    const int x_height,
+    Q4Matrix* w,
+    half* out,
+    const cublasHandle_t handle,
+    bool no_zero = false
+);
+
+#endif
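Both entry points accept weights that were quantized with act-order (a non-trivial g_idx). In that case Q4Matrix keeps a cuda_x_map, and the activations are permuted with column_remap_cuda so that column i of the remapped input lines up with row i of the re-sequenced weight (see Q4Matrix::make_sequential below). column_remap.cu itself is not part of this hunk; assuming it behaves as in upstream exllama, its effect is a gather along the hidden dimension, roughly:

    import torch

    def column_remap_reference(x, x_map):
        # x:     (rows, dim) activations
        # x_map: (dim,) int32; x_map[i] is the source column that must land at position i
        #        so that it matches row i of the group-sorted quantized weight
        return x[:, x_map.long()]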
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu
new file mode 100644
index 000000000..9c61143f5
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu
@@ -0,0 +1,225 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#include "q4_matrix.cuh"
+#include <vector>
+#include "util.cuh"
+#include "matrix.cuh"
+
+using namespace std;
+
+const int UNSHUF_BLOCKSIZE_X = 64;
+
+const int RECONS_THREADS_X = 64;  // Block size and thread count along columns in out, each thread converts 1 column
+const int RECONS_THREADS_Y = 1;   // Block size and thread count along rows in x and out, each thread converts 8 rows
+
+vector<Q4Matrix*> g_q4_matrices;
+
+void g_q4_keep_matrix(Q4Matrix* m)
+{
+    g_q4_matrices.push_back(m);
+}
+
+void g_q4_free_matrices()
+{
+    for (const auto& m : g_q4_matrices) delete m;
+    g_q4_matrices.clear();
+}
+
+Q4Matrix::Q4Matrix
+(
+    const int _height,
+    const int _width,
+    const int _groups,
+
+    uint32_t* _qweight,
+    uint32_t* _qzeros,
+    half* _scales,
+    uint32_t* _g_idx,
+
+    const int _device
+) :
+    height(_height),
+    width(_width),
+    groups(_groups),
+    device(_device)
+{
+    cudaSetDevice(device);
+
+    cuda_qweight = _qweight;
+    cuda_qzeros = _qzeros;
+    cuda_scales = _scales;
+
+    groupsize = height / groups;
+
+    if (_g_idx) make_sequential(_g_idx);
+}
+
+Q4Matrix::~Q4Matrix()
+{
+}
+
+// Make sequential
+
+__global__ void make_sequential_kernel
+(
+    const uint32_t* __restrict__ w,
+    uint32_t* __restrict__ w_new,
+    const uint32_t* __restrict__ x_map,
+    const int w_height,
+    const int w_width
+)
+{
+    const uint64_t* w2 = (uint64_t*) w;
+    uint64_t* w_new2 = (uint64_t*) w_new;
+    int w2_stride = w_width >> 1;
+
+    int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
+    if (w2_column >= w2_stride) return;
+
+    int w_new2_row = blockIdx.y;
+
+    int x_map_idx = w_new2_row << 3;
+
+    uint64_t dst = 0;
+
+    #pragma unroll
+    for (int i = 0; i < 8; i++)
+    {
+        int source_row = x_map[x_map_idx++];
+
+        int w2_row = source_row >> 3;
+        int w2_subrow = source_row & 0x07;
+        int w2_row_shift = w2_subrow << 2;
+        int wnew2_row_shift = i << 2;
+
+        uint64_t src = w2[w2_row * w2_stride + w2_column];
+        src >>= w2_row_shift;
+        src &= 0x0000000f0000000f;
+        src <<= wnew2_row_shift;
+        dst |= src;
+    }
+
+    w_new2[w_new2_row * w2_stride + w2_column] = dst;
+}
+
+void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
+{
+    uint32_t* cuda_new_qweight = NULL;
+    cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
+    cudaMalloc(&cuda_x_map, height * sizeof(uint32_t));  // TODO: Should probably be allocated in PyTorch
+
+    uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
+    uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
+    uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
+
+    // Group histogram
+
+    for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
+
+    // Group map
+
+    for (int i = 0, acc = 0; i < groups; i++)
+    {
+        short tmp = cpu_g_idx_map[i];
+        cpu_g_idx_map[i] = acc;
+        acc += tmp;
+    }
+
+    // X map (inverse)
+
+    for (int row = 0; row < height; row++)
+    {
+        uint32_t target_group = cpu_g_idx[row];
+        uint32_t target_row = cpu_g_idx_map[target_group];
+        cpu_g_idx_map[target_group]++;
+        cpu_x_map_inv[row] = target_row;
+    }
+
+    // X map
+
+    for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
+
+    // Move to CUDA
+
+    cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice);
+
+    // Rearrange rows in w
+
+    dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
+    dim3 blocks
+    (
+        (width + UNSHUF_BLOCKSIZE_X * 2 - 1) / (UNSHUF_BLOCKSIZE_X * 2),
+        height / 8,
+        1
+    );
+
+    make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
+
+    // Replace qweights
+
+    cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
+
+    // Cleanup
+
+    cudaDeviceSynchronize();
+    cudaFree(cuda_new_qweight);
+    free(cpu_g_idx_map);
+    free(cpu_x_map);
+    free(cpu_x_map_inv);
+}
+
+__global__ void reconstruct_kernel
+(
+    const uint32_t* __restrict__ w,
+    half* __restrict__ out,  // (y)
+    const half* __restrict__ w_scales,
+    const uint32_t* __restrict__ w_zeros,
+    const int height,
+    const int width,
+    const int groupsize
+)
+{
+    // Start of block
+
+    int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x;
+    int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8;
+    if (column >= width) return;
+
+    // Views
+
+    MatrixView_q4_column w_(w, height, width);
+    MatrixView_half_rw out_(out, height, width);
+    MatrixView_half w_scales_(w_scales, height / groupsize, width);
+    MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width);
+
+    // Groupsize version
+
+    int group = row / groupsize;
+
+    half w_scale = w_scales_.item(group, column);
+    uint32_t w_zero = w_zeros_.item(group, column) + 1;
+
+    uint32_t w_read = w_.item_uint32_t(row, column);
+    half* out_ptr = out_.item_ptr(row, column);
+
+    #pragma unroll
+    for (int s = 0; s < 32; s += 4)
+    {
+        half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
+        *out_ptr = w_item; out_ptr += out_.width;
+    }
+}
+
+void Q4Matrix::reconstruct(half* out)
+{
+    dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1);
+
+    dim3 blocks
+    (
+        (width + threads.x - 1) / threads.x,
+        (height / 8 + threads.y - 1) / threads.y,
+        1
+    );
+
+    reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
+}
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh
new file mode 100644
index 000000000..50cb72a41
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh
@@ -0,0 +1,53 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#ifndef _q4_matrix_cuh
+#define _q4_matrix_cuh
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+
+class Q4Matrix
+{
+public:
+
+    int device;
+
+    int height;
+    int width;
+    int groups;
+    int groupsize;
+
+    uint32_t* cuda_qweight = NULL;
+    uint32_t* cuda_qzeros = NULL;
+    half* cuda_scales = NULL;
+    uint32_t* cuda_x_map = NULL;
+
+    Q4Matrix
+    (
+        const int _height,
+        const int _width,
+        const int _groups,
+
+        uint32_t* _qweight,
+        uint32_t* _qzeros,
+        half* _scales,
+        uint32_t* _g_idx,
+
+        const int _device
+    );
+
+    ~Q4Matrix();
+
+    void reconstruct(half* out);
+
+private:
+
+    void make_sequential(const uint32_t* cpu_g_idx);
+
+};
+
+void g_q4_keep_matrix(Q4Matrix* m);
+void g_q4_free_matrices();
+
+#endif
\ No newline at end of file
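The x_map / x_map_inv bookkeeping in Q4Matrix::make_sequential above amounts to a stable counting sort of weight rows by group index: a histogram, an exclusive prefix sum (the "Group map" step), and a scatter. For reference only, the same logic in Python:

    import numpy as np

    def build_x_map(g_idx, groups):
        # g_idx[row] = quantization group of weight row `row` (act-order models)
        counts = np.bincount(g_idx, minlength=groups)
        first_slot = np.concatenate(([0], np.cumsum(counts)[:-1]))   # "Group map"

        x_map_inv = np.empty(len(g_idx), dtype=np.int64)
        next_slot = first_slot.copy()
        for row, g in enumerate(g_idx):                              # "X map (inverse)"
            x_map_inv[row] = next_slot[g]
            next_slot[g] += 1

        x_map = np.empty_like(x_map_inv)
        x_map[x_map_inv] = np.arange(len(g_idx))                     # "X map"
        return x_map, x_map_inv

    # rows with groups [1, 0, 1, 0] are re-ordered to source rows [1, 3, 0, 2]
    print(build_x_map(np.array([1, 0, 1, 0]), groups=2)[0])          # [1 3 0 2]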
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/tuning.h b/colossalai/kernel/cuda_native/csrc/gptq/tuning.h
new file mode 100644
index 000000000..770ca46aa
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/tuning.h
@@ -0,0 +1,13 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#ifndef _tuning_h
+#define _tuning_h
+
+struct ExLlamaTuning
+{
+    int matmul_recons_thd;
+    bool matmul_fused_remap;
+    bool matmul_no_half2;
+};
+
+#endif
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/util.cuh b/colossalai/kernel/cuda_native/csrc/gptq/util.cuh
new file mode 100644
index 000000000..7b3975732
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/util.cuh
@@ -0,0 +1,33 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#ifndef _util_cuh
+#define _util_cuh
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+#include <cstdio>
+
+#if defined(USE_ROCM)
+#define cudaUnspecified hipErrorUnknown
+#else
+#define cudaUnspecified cudaErrorApiFailureBase
+#endif
+
+// React to failure on return code != cudaSuccess
+
+#define _cuda_check(fn) \
+do { \
+    {_cuda_err = fn;} \
+    if (_cuda_err != cudaSuccess) goto _cuda_fail; \
+} while(false)
+
+// React to failure on return code == 0
+
+#define _alloc_check(fn) \
+do { \
+    if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \
+    else _cuda_err = cudaSuccess; \
+} while(false)
+
+#endif
diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py
index bc68a07e6..87ea9cf65 100644
--- a/colossalai/kernel/triton/__init__.py
+++ b/colossalai/kernel/triton/__init__.py
@@ -6,6 +6,7 @@ try:
     from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
     from .copy_kv_cache_dest import copy_kv_cache_to_dest
     from .fused_layernorm import layer_norm
+    from .gptq_triton import gptq_fused_linear_triton
     from .rms_norm import rmsnorm_forward
     from .rotary_embedding_kernel import rotary_embedding_fwd
     from .softmax import softmax
@@ -20,6 +21,7 @@ try:
         "copy_kv_cache_to_dest",
         "rotary_embedding_fwd",
         "token_attention_fwd",
+        "gptq_fused_linear_triton",
     ]
 except ImportError:
diff --git a/colossalai/kernel/triton/gptq_triton.py b/colossalai/kernel/triton/gptq_triton.py
new file mode 100644
index 000000000..cf4ef183a
--- /dev/null
+++ b/colossalai/kernel/triton/gptq_triton.py
@@ -0,0 +1,541 @@
+# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ
+
+import torch
+import triton
+import triton.language as tl
+from auto_gptq.nn_modules.triton_utils import custom_autotune
+
+_kAlpha = 0.7978845608028654  # sqrt(2.0 / pi), required by gelu() below
+
+
+@triton.jit
+def tanh(x):
+    # Tanh is just a scaled sigmoid
+    return 2 * tl.sigmoid(2 * x) - 1
+
+
+@triton.jit
+def cosh(x):
+    exp_x = tl.exp(x)
+    return (exp_x + 1.0 / exp_x) * 0.5
+
+
+# a Triton implementation of the most used activations
+# See for instance http://arxiv.org/abs/1606.08415 for an overview
+
+
+# ReLU
+@triton.jit
+def relu(x):
+    """
+    ReLU_ activation function
+
+    .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
+    """
+    return tl.where(x >= 0, x, 0.0)
+
+
+@triton.jit
+def squared_relu(x):
+    """
+    Squared ReLU activation, as proposed in the Primer_ paper.
+
+    .. _Primer: https://arxiv.org/abs/2109.08668
+    """
+    x_sq = x * x
+    return tl.where(x > 0.0, x_sq, 0.0)
+
+
+@triton.jit
+def star_relu(x):
+    """
+    Star ReLU activation, as proposed in the "MetaFormer Baselines for Vision"_ paper.
+
+    ..
_ "MetaFormer Baselines for Vision": https://arxiv.org/pdf/2210.13452.pdf + """ + x_sq = x * x + return 0.8944 * tl.where(x > 0.0, x_sq, 0.0) - 0.4472 + + +# Leaky ReLU +@triton.jit +def leaky_relu(x): + """ + LeakyReLU_ activation + + .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html + """ + return tl.where(x >= 0.0, x, 0.01 * x) + + +@triton.jit +def gelu(x): + """ + GeLU_ activation - Gaussian error linear unit + + .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf + """ + return 0.5 * x * (1 + tanh(_kAlpha * (x + 0.044715 * x * x * x))) + + +@triton.jit +def smelu(x): + """ + SmeLU_ activation - Smooth ReLU with beta=2.0 + + .. _SmeLU: https://arxiv.org/pdf/2202.06499.pdf + """ + beta = 2.0 + + relu = tl.where(x >= beta, x, 0.0) + return tl.where(tl.abs(x) <= beta, (x + beta) * (x + beta) / (4.0 * beta), relu) + + +@triton.jit +def silu(x): + return x * tl.sigmoid(x) + + +@custom_autotune.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=8 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8 + ), + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4 + ), + ], + key=["M", "N", "K"], + nearest_power_of_two=True, + prune_configs_by={ + "early_config_prune": custom_autotune.matmul248_kernel_config_pruner, + "perf_model": None, + "top_k": None, + }, +) +@triton.jit +def cai_gptq_matmul_248_kernel( + a_ptr, + b_ptr, + c_ptr, + scales_ptr, + zeros_ptr, + bias_ptr, + residual_ptr, + M, + N, + K, + bits, + maxq, + gptq_group_size, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_scales, + stride_zeros, + QKV_FUSED: tl.constexpr, + ADD_BIAS: tl.constexpr, + ADD_RESIDUAL: tl.constexpr, + ACT_TYPE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + Compute the matrix multiplication C = A x B. 
+ A is of shape (M, K) float16 + B is of shape (K//8, N) int32 + C is of shape (M, N) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + NK = K + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(NK, BLOCK_SIZE_K) + qkv_offset = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + # offs_bk = offs_k + qkv_offset * NK + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + + a_mask = offs_am[:, None] < M + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = ( + b_ptr + + qkv_offset * N * NK // infearure_per_bits + + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + # g_ptrs = g_ptr + offs_k + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + qkv_offset * NK * N // gptq_group_size + offs_bn[None, :] + zeros_ptrs = ( + zeros_ptr + + qkv_offset * NK * N // gptq_group_size // infearure_per_bits + + (offs_bn[None, :] // infearure_per_bits) + ) + + shifter = (offs_k % infearure_per_bits) * bits + zeros_shifter = (offs_bn % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + g_idx_base = tl.arange(0, BLOCK_SIZE_K) + g_idx_base = g_idx_base // gptq_group_size + g_idx = g_idx_base + # tl.device_print("gidx, ", g_idx) + + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = zeros + 1 + + for k in range(0, num_pid_k): + # g_idx = tl.load(g_ptrs) + # if (k + 1) * BLOCK_SIZE_K > currend_group_end: + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = zeros + 1 + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros).to(tl.float16) * scales # Scale and shift + accumulator += tl.dot(a, b) + + a_ptrs += BLOCK_SIZE_K + b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + g_idx = g_idx_base + ((k + 1) * BLOCK_SIZE_K) // gptq_group_size + # if (k + 2) * BLOCK_SIZE_K > currend_group_end: + + c_ptrs = c_ptr + qkv_offset * M * N + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + + if ADD_BIAS: + bias_mask = offs_bn < N + offs_bn += qkv_offset * N + 
bias_ptrs = bias_ptr + stride_cn * offs_bn + bias = tl.load(bias_ptrs, mask=bias_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + accumulator += bias[None, :] + + if ACT_TYPE == 1: + accumulator = relu(accumulator) + elif ACT_TYPE == 2: + accumulator = gelu(accumulator) + elif ACT_TYPE == 3: + accumulator = silu(accumulator) + + if ADD_RESIDUAL: + residual_ptrs = residual_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + res = tl.load(residual_ptrs, mask=c_mask, other=0.0) + accumulator += res + + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@custom_autotune.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=8 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8 + ), + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4 + ), + ], + key=["M", "N", "K"], + nearest_power_of_two=True, + prune_configs_by={ + "early_config_prune": custom_autotune.matmul248_kernel_config_pruner, + "perf_model": None, + "top_k": None, + }, +) +@triton.jit +def cai_gptq_idx_matmul_248_kernel( + a_ptr, + b_ptr, + c_ptr, + scales_ptr, + zeros_ptr, + idx_ptr, + bias_ptr, + residual_ptr, + M, + N, + K, + bits, + maxq, + gptq_group_size, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_scales, + stride_zeros, + QKV_FUSED: tl.constexpr, + ADD_BIAS: tl.constexpr, + ADD_RESIDUAL: tl.constexpr, + ACT_TYPE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + Compute the matrix multiplication C = A x B. 
+ A is of shape (M, K) float16 + B is of shape (K//8, N) int32 + C is of shape (M, N) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + NK = K + + # if QKV_FUSED: + # NK = K//3 + # else: + # NK = K + # NK = K + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(NK, BLOCK_SIZE_K) + qkv_offset = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + # offs_bk = offs_k + qkv_offset * NK + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + + a_mask = offs_am[:, None] < M + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = ( + b_ptr + + qkv_offset * N * NK // infearure_per_bits + + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + # g_ptrs = g_ptr + offs_k + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + qkv_offset * NK * N // gptq_group_size + offs_bn[None, :] + zeros_ptrs = ( + zeros_ptr + + qkv_offset * NK * N // gptq_group_size // infearure_per_bits + + (offs_bn[None, :] // infearure_per_bits) + ) + + shifter = (offs_k % infearure_per_bits) * bits + zeros_shifter = (offs_bn % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + g_ptrs = idx_ptr + offs_k + g_idx = tl.load(g_ptrs) + # tl.device_print("gidx, ", g_idx) + zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) + + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + for k in range(0, num_pid_k): + g_idx = tl.load(g_ptrs) + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = zeros + 1 + + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros).to(tl.float16) * scales # Scale and shift + accumulator += tl.dot(a, b) + + a_ptrs += BLOCK_SIZE_K + b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + g_ptrs += BLOCK_SIZE_K + + c_ptrs = c_ptr + qkv_offset * M * N + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + + if ADD_BIAS: + bias_mask = offs_bn < N + offs_bn += qkv_offset * N + bias_ptrs = bias_ptr + stride_cn * offs_bn + bias = tl.load(bias_ptrs, mask=bias_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + accumulator += bias[None, :] + + if ACT_TYPE == 1: + accumulator = relu(accumulator) + elif ACT_TYPE == 
2: + accumulator = gelu(accumulator) + elif ACT_TYPE == 3: + accumulator = silu(accumulator) + + if ADD_RESIDUAL: + residual_ptrs = residual_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + res = tl.load(residual_ptrs, mask=c_mask, other=0.0) + accumulator += res + + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def gptq_fused_linear_triton( + input, + qweight, + scales, + qzeros, + bias, + residual, + bits, + maxq, + gptq_group_size, + qkv_fused, + add_bias, + add_residual, + g_idx=None, + act_type=0, +): + # print("gptq fused ", qkv_fused, add_bias, add_residual) + assert input.is_cuda, "input is not in cuda" + assert qweight.is_cuda, "qweight is not in cuda" + assert scales.is_cuda, "scales is not in cuda" + assert qzeros.is_cuda, "qzeros is not in cuda" + + with torch.cuda.device(input.device): + if qkv_fused: + grid = lambda META: ( + triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) + * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]) + * 3, + ) + output = torch.empty((input.shape[0] * 3, qweight.shape[1]), device=input.device, dtype=torch.float16) + else: + grid = lambda META: ( + triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), + ) + output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16) + # print("dtype, ", qweight.dtype, output.dtype, scales.dtype, qzeros.dtype, bias.dtype, residual.dtype) + if g_idx is None: + cai_gptq_matmul_248_kernel[grid]( + input, + qweight, + output, + scales, + qzeros, + bias, + residual, + input.shape[0], + qweight.shape[1], + input.shape[1], + bits, + maxq, + gptq_group_size, + input.stride(0), + input.stride(1), + qweight.stride(0), + qweight.stride(1), + output.stride(0), + output.stride(1), + scales.stride(0), + qzeros.stride(0), + QKV_FUSED=qkv_fused, + ADD_BIAS=add_bias, + ADD_RESIDUAL=add_residual, + ACT_TYPE=act_type, + ) + else: + cai_gptq_idx_matmul_248_kernel[grid]( + input, + qweight, + output, + scales, + qzeros, + g_idx, + bias, + residual, + input.shape[0], + qweight.shape[1], + input.shape[1], + bits, + maxq, + gptq_group_size, + input.stride(0), + input.stride(1), + qweight.stride(0), + qweight.stride(1), + output.stride(0), + output.stride(1), + scales.stride(0), + qzeros.stride(0), + QKV_FUSED=qkv_fused, + ADD_BIAS=add_bias, + ADD_RESIDUAL=add_residual, + ACT_TYPE=act_type, + ) + if qkv_fused: + return output.view(3, input.shape[0], qweight.shape[1]) + else: + return output diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 693528813..a285874d2 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -32,10 +32,13 @@ class ShardConfig: enable_fused_normalization: bool = False enable_flash_attention: bool = False enable_jit_fused: bool = False - enable_sequence_parallelism: bool = False - enable_sequence_overlap: bool = False enable_all_optimization: bool = False inference_only: bool = False + inference_gptq: bool = False + enable_sequence_parallelism: bool = False + enable_sequence_overlap: bool = False + # pipeline_parallel_size: int + # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] @property diff --git a/examples/inference/gptq_bloom.py b/examples/inference/gptq_bloom.py new file mode 100644 index 000000000..43e118cc0 --- /dev/null +++ b/examples/inference/gptq_bloom.py @@ -0,0 +1,123 @@ +import argparse +import logging +import os +import time + +import torch 
+from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig +from auto_gptq.nn_modules.qlinear import GeneralQuantLinear +from transformers import AutoTokenizer, BloomForCausalLM, BloomTokenizerFast, LlamaForCausalLM, LlamaTokenizer + +import colossalai +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' + + +def print_perf_stats(latency_set, config, bs, warmup=3): + # trim warmup queries + latency_set = list(latency_set) + latency_set = latency_set[warmup:] + count = len(latency_set) + + if count > 0: + latency_set.sort() + avg = sum(latency_set) / count + num_layers = getattr(config, "num_layers", config.num_hidden_layers) + num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 + num_bytes = 2 # float16 + + print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) + print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) + print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12)) + print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs)) + + +def bench_bloom(args): + + pretrained_model_dir = args.path + quantized_model_dir = args.quantized_path + max_batch_size = args.batch_size + max_input_len = args.input_len + max_output_len = args.output_len + + tokenizer = BloomTokenizerFast.from_pretrained(pretrained_model_dir) + tokenizer.pad_token = tokenizer.eos_token + + # load quantized model to the first GPU + model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, + device=torch.cuda.current_device(), + inject_fused_attention=False) + + model = model.half() + + model_config = model.config + shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) + infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) + generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + + input_tokens = { + "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'), + "attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda') + } + + # init TPInferEngine and shard the original model + # To benchmark torch original, comment out the line of optimizing model + shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, + inference_only=True, + inference_gptq=True) + infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) + + # prepare data for generation + generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + input_tokens = { + "input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)), + "attention_mask": torch.ones((max_batch_size, max_input_len)) + } + for t in input_tokens: + if torch.is_tensor(input_tokens[t]): + input_tokens[t] = input_tokens[t].to(torch.cuda.current_device()) + # print(f" input_tokens[{t}].shape: {input_tokens[t].shape}") + + iters = 10 + times = [] + for i in range(iters): + torch.cuda.synchronize() + start = time.time() + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + torch.cuda.synchronize() + end = time.time() + out_len = outputs.shape[1] + print(f" iter {i}: out len {str(out_len)}, generation time {str(end - start)} s") + times.append((end 
- start) / (out_len - max_input_len)) + + print_perf_stats(times, model_config, max_batch_size) + + +def check_bloom(rank, world_size, port, args): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + bench_bloom(args) + + +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bloom(args): + spawn(check_bloom, args.tp_size, args=args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-p', '--path', type=str, help='Model path', required=True) + parser.add_argument('-q', '--quantized_path', type=str, help='Model path', required=True) + parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size') + parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size') + parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length') + parser.add_argument('--output_len', type=int, default=128, help='Maximum output length') + + args = parser.parse_args() + + test_bloom(args) diff --git a/examples/inference/gptq_llama.py b/examples/inference/gptq_llama.py new file mode 100644 index 000000000..818ae0035 --- /dev/null +++ b/examples/inference/gptq_llama.py @@ -0,0 +1,135 @@ +import argparse +import logging +import os +import time + +import torch +from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig +from auto_gptq.nn_modules.qlinear import GeneralQuantLinear +from torch import distributed as dist +from torch.profiler import ProfilerActivity, profile, record_function +from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer, TextGenerationPipeline + +import colossalai +from colossalai.gptq import CaiQuantLinear +from colossalai.gptq.gptq_tp import replace_autogptq_linear +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' + + +def init_to_get_rotary(self, base=10000): + self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads + if not hasattr(self.config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 + if hasattr(self.config, "max_sequence_length"): + max_seq_len = self.config.max_sequence_length + elif hasattr(self.config, "max_position_embeddings"): + max_seq_len = self.config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / + self.config.head_dim_)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() + self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + return + + +def print_perf_stats(latency_set, config, bs, warmup=3): + # trim warmup queries + latency_set = list(latency_set) + latency_set = latency_set[warmup:] + count = len(latency_set) + + if count > 0: + latency_set.sort() + avg = sum(latency_set) / count + num_layers = getattr(config, "num_layers", config.num_hidden_layers) + num_parameters = num_layers * 
config.hidden_size * config.hidden_size * 12 + num_bytes = 2 + + print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) + print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) + print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12)) + print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs)) + + +def run_llama_test(args): + pretrained_model_dir = args.path + quantized_model_dir = args.quantized_path + max_batch_size = args.batch_size + max_input_len = args.input_len + max_output_len = args.output_len + + tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_dir, use_fast=True) + tokenizer.pad_token_id = tokenizer.eos_token_id + + # load quantized model to the first GPU + model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, + device=torch.cuda.current_device(), + inject_fused_attention=False) + + init_to_get_rotary(model.model.model, base=10000) + + model_config = model.config + shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, + inference_only=True, + inference_gptq=True) + infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) + + generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + + input_tokens = { + "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'), + "attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda') + } + + iters = 10 + times = [] + + for i in range(iters): + torch.cuda.synchronize() + start = time.time() + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + torch.cuda.synchronize() + end = time.time() + out_len = outputs.shape[1] + print(f" iter {i}: out len {str(out_len)}, generation time {str(end - start)} s") + times.append((end - start) / (out_len - max_input_len)) + + print_perf_stats(times, model_config, max_batch_size) + + +def check_llama(rank, world_size, port, args): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_llama_test(args) + + +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama(args): + spawn(check_llama, args.tp_size, args=args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-p', '--path', type=str, help='Model path', required=True) + parser.add_argument('-q', '--quantized_path', type=str, help='Model path', required=True) + parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size') + parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size') + parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length') + parser.add_argument('--output_len', type=int, default=128, help='Maximum output length') + + args = parser.parse_args() + + test_llama(args) diff --git a/op_builder/gptq.py b/op_builder/gptq.py new file mode 100644 index 000000000..012cf0f8a --- /dev/null +++ b/op_builder/gptq.py @@ -0,0 +1,52 @@ +import os +import torch +import re + +from .builder import Builder +from .utils import append_nvcc_threads, get_cuda_cc_flag + +class GPTQBuilder(Builder): + + NAME = "cu_gptq" + PREBUILT_IMPORT_PATH = "colossalai._C.cu_gptq" + + def __init__(self): + super().__init__(name=GPTQBuilder.NAME, + prebuilt_import_path=GPTQBuilder.PREBUILT_IMPORT_PATH) + + + def include_dirs(self): + ret = [self.csrc_abs_path("gptq"), 
self.get_cuda_home_include()] + return ret + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) for fname in [ + 'gptq/linear_gptq.cpp', + 'gptq/column_remap.cu', + 'gptq/cuda_buffers.cu', + 'gptq/q4_matmul.cu', + 'gptq/q4_matrix.cu' + ] + ] + return ret + + def cxx_flags(self): + return ['-O3'] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = ['-v', + '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', + '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK', "-lcublas", "-std=c++17" + ] + + + for arch in torch.cuda.get_arch_list(): + res = re.search(r'sm_(\d+)', arch) + if res: + arch_cap = res[1] + if int(arch_cap) >= 80: + extra_cuda_flags.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}']) + + ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags + return append_nvcc_threads(ret) \ No newline at end of file diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 53f0f958e..467f83610 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -18,3 +18,4 @@ SentencePiece ninja flash_attn==2.0.5 datasets +#auto-gptq now not support torch1.12 diff --git a/tests/test_gptq/test_gptq_linear.py b/tests/test_gptq/test_gptq_linear.py new file mode 100644 index 000000000..9b650aa78 --- /dev/null +++ b/tests/test_gptq/test_gptq_linear.py @@ -0,0 +1,150 @@ +import math +import time + +import numpy as np +import pytest +import torch +import torch.nn as nn +import transformers +from packaging import version + +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +try: + from auto_gptq.modeling._utils import autogptq_post_init + from auto_gptq.utils.import_utils import dynamically_import_QuantLinear + from exllama_kernels import prepare_buffers, set_tuning_params + + from colossalai.inference.quant.gptq import CaiQuantLinear + HAS_AUTO_GPTQ = True +except: + HAS_AUTO_GPTQ = False + print("please install AutoGPTQ from https://github.com/PanQiWei/AutoGPTQ") + +import warnings + +HAS_GPTQ_CUDA = False +try: + from colossalai.kernel.op_builder.gptq import GPTQBuilder + gptq_cuda = GPTQBuilder().load() + HAS_GPTQ_CUDA = True +except ImportError: + warnings.warn('CUDA gptq is not installed') + HAS_GPTQ_CUDA = False + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') + +max_inner_outer_dim = 1 +max_input_len = 1 +max_dq_buffer_size = 1 +gptq_temp_dq_buffer = None +gptq_temp_state_buffer = None + + +def init_buffer(cai_linear, use_act_order=False): + global max_dq_buffer_size + global max_input_len + global max_dq_buffer_size + global max_inner_outer_dim + global gptq_temp_dq_buffer + global gptq_temp_state_buffer + + max_dq_buffer_size = max(max_dq_buffer_size, cai_linear.qweight.numel() * 8) + + if use_act_order: + max_inner_outer_dim = max(max_inner_outer_dim, cai_linear.infeatures, cai_linear.outfeatures) + + if use_act_order: + max_input_len = 4096 + # The temp_state buffer is required to reorder X in the act-order case. + # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. 
+ gptq_temp_state_buffer = torch.zeros((max_input_len, max_inner_outer_dim), + dtype=torch.float16, + device=torch.cuda.current_device()) + gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()) + + gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), gptq_temp_state_buffer, gptq_temp_dq_buffer) + # Using the default from exllama repo here. + matmul_recons_thd = 8 + matmul_fused_remap = False + matmul_no_half2 = False + gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) + + +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ, + reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq") +def test_gptq_linear(): + + infeature = 1024 + outfeature = 1024 + group_size = 128 + wbits = 4 + + inps = torch.ones(1, 1, infeature).to(torch.float16).to(torch.cuda.current_device()) + batch_inps = torch.randn(1, 16, infeature).to(torch.float16).to(torch.cuda.current_device()) + + device = torch.device("cuda:0") + + linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=wbits) + + linear = linear_class( + bits=4, + group_size=group_size, + infeatures=infeature, + outfeatures=outfeature, + bias=False, + ) + + torch.manual_seed(42) + + linear.qweight = torch.randint(-100, 100, size=linear.qweight.shape, dtype=torch.int32) + linear.scales = linear.scales + 0.002 + + linear = linear.to(device) + + cai_linear = CaiQuantLinear(wbits, group_size, infeature, outfeature, True) + cai_linear.qweight.data.copy_(linear.qweight) + cai_linear.scales = cai_linear.scales + 0.002 + cai_linear = cai_linear.to(device) + + linear = autogptq_post_init(linear, use_act_order=False) + + max_inner_outer_dim = max(infeature, outfeature) + max_dq_buffer_size = linear.infeatures * linear.outfeatures + max_input_len = 2048 + buffers = { + "temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device), + "temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device) + } + + prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"]) + + # Using the default from exllama repo here. + matmul_recons_thd = 8 + matmul_fused_remap = False + matmul_no_half2 = False + set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) + + with torch.no_grad(): + gptq_out = linear(inps) + batch_gptq_out = linear(batch_inps) + torch.cuda.synchronize() + cai_out = cai_linear(inps) + torch.cuda.synchronize() + + batch_cai_out = cai_linear(batch_inps) + torch.cuda.synchronize() + + assert torch.allclose(cai_out, gptq_out, rtol=1e-01, atol=1e-01) + assert torch.allclose(batch_cai_out, batch_gptq_out, rtol=1e-01, atol=1e-01) + + +if __name__ == "__main__": + + test_gptq_linear()
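As a companion to the test above, here is a minimal sketch of driving the fused Triton path directly through the gptq_fused_linear_triton export that this PR adds to colossalai.kernel.triton. Shapes follow the test; qweight and qzeros are random and only meant to exercise the kernel rather than produce meaningful numbers, and whether a given shape hits a pre-tuned config is left to the autotuner.

    import torch
    from colossalai.kernel.triton import gptq_fused_linear_triton

    bits, group_size, infeature, outfeature = 4, 128, 1024, 1024
    device = torch.device("cuda:0")

    x = torch.randn(16, infeature, dtype=torch.float16, device=device)
    qweight = torch.randint(0, 2**31 - 1, (infeature // 8, outfeature), dtype=torch.int32, device=device)
    qzeros = torch.randint(0, 2**31 - 1, (infeature // group_size, outfeature // 8), dtype=torch.int32, device=device)
    scales = 0.01 * torch.rand(infeature // group_size, outfeature, dtype=torch.float16, device=device)
    bias = torch.zeros(outfeature, dtype=torch.float16, device=device)
    residual = torch.zeros(16, outfeature, dtype=torch.float16, device=device)

    out = gptq_fused_linear_triton(
        x, qweight, scales, qzeros, bias, residual,
        bits, 2**bits - 1, group_size,          # bits, maxq, gptq_group_size
        qkv_fused=False, add_bias=True, add_residual=False,
    )
    print(out.shape)   # torch.Size([16, 1024])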