Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-28 04:55:25 +00:00)
[inference] Refactor inference architecture (#5057)
* [inference] support only TP (#4998): support TP-only inference, enable TP, and add support for BLOOM (#5008)
* [refactor] refactor GPTQ and SmoothQuant LLaMA (#5012): fix import errors for torch-int, accelerate, and the SmoothQuant CUDA kernels
* [Inference Refactor] merge ChatGLM2 with PP and TP (#5023)
* [Refactor] remove useless inference code (#5022): remove unused code, fix the quant model and a test import bug, move the original inference code to legacy, and fix ChatGLM2
* [Refactor] refactor policy search and quant type controlling in inference (#5035)
* [inference] update README (#5051): fix the architecture description and tables
* [inference] update example (#5053): update the example and run.sh, fix a rebase bug and other errors, add some features, update the interface, README, and benchmark, and add requirements-infer

Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
colossalai/inference/quant/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .smoothquant.models.llama import SmoothLlamaForCausalLM
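With this package-level __init__, the SmoothQuant LLaMA wrapper can be imported from the package root. A trivial check (not part of the diff); how the class is then constructed or loaded is not shown here:

from colossalai.inference.quant import SmoothLlamaForCausalLM  # re-exported from .smoothquant.models.llama

print(SmoothLlamaForCausalLM.__name__)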
@@ -2,3 +2,4 @@ from .cai_gptq import HAS_AUTO_GPTQ

 if HAS_AUTO_GPTQ:
     from .cai_gptq import CaiGPTQLinearOp, CaiQuantLinear
+    from .gptq_manager import GPTQManager
colossalai/inference/quant/gptq/gptq_manager.py (new file, 61 lines)
@@ -0,0 +1,61 @@
import torch


class GPTQManager:
    def __init__(self, quant_config, max_input_len: int = 1):
        self.max_dq_buffer_size = 1
        self.max_inner_outer_dim = 1
        self.bits = quant_config.bits
        self.use_act_order = quant_config.desc_act
        self.max_input_len = 1
        self.gptq_temp_state_buffer = None
        self.gptq_temp_dq_buffer = None
        self.quant_config = quant_config

    def post_init_gptq_buffer(self, model: torch.nn.Module) -> None:
        from .cai_gptq import CaiQuantLinear

        HAS_GPTQ_CUDA = False
        try:
            from colossalai.kernel.op_builder.gptq import GPTQBuilder

            gptq_cuda = GPTQBuilder().load()
            HAS_GPTQ_CUDA = True
        except ImportError:
            warnings.warn("CUDA gptq is not installed")
            HAS_GPTQ_CUDA = False

        for name, submodule in model.named_modules():
            if isinstance(submodule, CaiQuantLinear):
                self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8)

                if self.use_act_order:
                    self.max_inner_outer_dim = max(
                        self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures
                    )
                self.bits = submodule.bits
        if not (HAS_GPTQ_CUDA and self.bits == 4):
            return

        max_input_len = 1
        if self.use_act_order:
            max_input_len = self.max_input_len
        # The temp_state buffer is required to reorder X in the act-order case.
        # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
        self.gptq_temp_state_buffer = torch.zeros(
            (max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
        )
        self.gptq_temp_dq_buffer = torch.zeros(
            (1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()
        )

        gptq_cuda.prepare_buffers(
            torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer
        )
        # Using the default from exllama repo here.
        matmul_recons_thd = 8
        matmul_fused_remap = False
        matmul_no_half2 = False
        gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

        torch.cuda.empty_cache()
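A minimal usage sketch (not part of the diff) for the new GPTQManager. It assumes an AutoGPTQ-style quantization config exposing bits and desc_act, and a model whose quantized layers are already CaiQuantLinear modules; the SimpleNamespace stand-in below is purely illustrative.

from types import SimpleNamespace

from colossalai.inference.quant.gptq.gptq_manager import GPTQManager  # path as added in this commit

# Illustrative stand-in for an AutoGPTQ-style config; only the two fields GPTQManager reads.
quant_config = SimpleNamespace(bits=4, desc_act=False)

manager = GPTQManager(quant_config, max_input_len=2048)
# `model` is assumed to already contain CaiQuantLinear submodules:
# manager.post_init_gptq_buffer(model)  # allocates the temp_state/temp_dq buffers and sets the kernel tuning params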
@@ -4,9 +4,7 @@ try:
     HAS_TORCH_INT = True
 except ImportError:
     HAS_TORCH_INT = False
-    raise ImportError(
-        "Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int"
-    )
+    print("Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")

 if HAS_TORCH_INT:
     from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP
@@ -9,7 +9,6 @@ from functools import partial
 from os.path import isdir, isfile, join
 from typing import Dict, List, Optional, Union

-import accelerate
 import numpy as np
 import torch
 import torch.nn as nn
@@ -21,8 +20,16 @@ from transformers.modeling_utils import no_init_weights
 from transformers.utils.generic import ContextManagers
 from transformers.utils.hub import PushToHubMixin, cached_file

-from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager
+from colossalai.inference.kv_cache.batch_infer_state import BatchInferState, MemoryManager
+
+try:
+    import accelerate
+
+    HAS_ACCELERATE = True
+except ImportError:
+    HAS_ACCELERATE = False
+    print("accelerate is not installed.")
+

 SUPPORTED_MODELS = ["llama"]
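The hunks above and below repeat the same degrade-gracefully import pattern for torch-int, accelerate, and the SmoothQuant CUDA extension: set a HAS_* flag instead of raising at import time. A reusable helper in that spirit is sketched here purely as an illustration; the commit itself keeps the explicit try/except blocks.

import importlib
from types import ModuleType
from typing import Optional, Tuple


def optional_import(name: str, hint: str = "") -> Tuple[Optional[ModuleType], bool]:
    # Mirror the HAS_TORCH_INT / HAS_ACCELERATE flags: never raise, just report availability.
    try:
        return importlib.import_module(name), True
    except ImportError:
        if hint:
            print(hint)
        return None, False


accelerate, HAS_ACCELERATE = optional_import("accelerate", "accelerate is not installed.")
torch_int, HAS_TORCH_INT = optional_import("torch_int", "please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")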
@@ -1,17 +1,25 @@
 # modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py

 import torch
-from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32
-from torch_int.functional.quantization import quantize_per_tensor_absmax

+try:
+    from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32
+    from torch_int.functional.quantization import quantize_per_tensor_absmax
+
+    HAS_TORCH_INT = True
+except ImportError:
+    HAS_TORCH_INT = False
+    print("Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")
+
+
 try:
     from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder

     smoothquant_cuda = SmoothquantBuilder().load()
     HAS_SMOOTHQUANT_CUDA = True
-except ImportError:
+except:
     HAS_SMOOTHQUANT_CUDA = False
-    raise ImportError("CUDA smoothquant linear is not installed")
+    print("CUDA smoothquant linear is not installed")


 class W8A8BFP32O32LinearSiLU(torch.nn.Module):
@@ -138,21 +146,23 @@ class W8A8BFP32OFP32Linear(torch.nn.Module):
         )
         self.register_buffer(
             "bias",
-            torch.zeros(self.out_features, dtype=torch.float32, requires_grad=False),
+            torch.zeros((1, self.out_features), dtype=torch.float32, requires_grad=False),
         )
         self.register_buffer("a", torch.tensor(alpha))

     def _apply(self, fn):
         # prevent the bias from being converted to half
         super()._apply(fn)
-        self.bias = self.bias.to(torch.float32)
+        if self.bias is not None:
+            self.bias = self.bias.to(torch.float32)
         return self

     def to(self, *args, **kwargs):
         super().to(*args, **kwargs)
         self.weight = self.weight.to(*args, **kwargs)
-        self.bias = self.bias.to(*args, **kwargs)
-        self.bias = self.bias.to(torch.float32)
+        if self.bias is not None:
+            self.bias = self.bias.to(*args, **kwargs)
+            self.bias = self.bias.to(torch.float32)
         return self

     @torch.no_grad()
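The _apply override above works because .half(), .cuda(), and .to() all funnel every parameter and buffer through Module._apply; re-casting afterwards keeps the bias in float32 for the int8 kernels even when the rest of the model is cast to half precision. A minimal standalone illustration (not the class from this diff):

import torch
import torch.nn as nn


# Minimal illustration of the pattern used above: keep a float32 "bias" buffer
# even when the module is moved with .half().
class KeepsFP32Bias(nn.Module):
    def __init__(self, out_features: int):
        super().__init__()
        self.register_buffer("weight", torch.empty(out_features, out_features, dtype=torch.int8))
        self.register_buffer("bias", torch.zeros((1, out_features), dtype=torch.float32))

    def _apply(self, fn):
        # .half(), .cuda(), .to(...) all route through _apply; re-cast the bias afterwards.
        super()._apply(fn)
        if self.bias is not None:
            self.bias = self.bias.to(torch.float32)
        return self


m = KeepsFP32Bias(8).half()
print(m.bias.dtype)  # torch.float32, whereas an ordinary fp32 buffer would now be float16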
@@ -8,7 +8,6 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
 from transformers import PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.configuration_llama import LlamaConfig
@@ -18,12 +17,11 @@ from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaMLP,
     LlamaRotaryEmbedding,
-    repeat_kv,
     rotate_half,
 )
 from transformers.utils import add_start_docstrings_to_model_forward

-from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
+from colossalai.inference.kv_cache.batch_infer_state import BatchInferState
 from colossalai.kernel.triton import (
     copy_kv_cache_to_dest,
     int8_rotary_embedding_fwd,
@@ -31,10 +29,31 @@ from colossalai.kernel.triton import (
     smooth_token_attention_fwd,
 )

+try:
+    from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
+
+    HAS_TORCH_INT = True
+except ImportError:
+    HAS_TORCH_INT = False
+    print("Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")
+
+
 from .base_model import BaseSmoothForCausalLM
 from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear


+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 class LLamaSmoothquantAttention(nn.Module):
     def __init__(
         self,
@@ -116,7 +135,6 @@ class LLamaSmoothquantAttention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        rotary_emb: Tuple[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
@@ -131,8 +149,7 @@ class LLamaSmoothquantAttention(nn.Module):
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)

-        cos = rotary_emb[0]
-        sin = rotary_emb[1]
+        cos, sin = infer_state.position_cos, infer_state.position_sin

         int8_rotary_embedding_fwd(
             query_states.view(-1, self.num_heads, self.head_dim),
@@ -348,7 +365,6 @@ class LlamaSmoothquantDecoderLayer(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        rotary_emb: Tuple[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
@@ -378,7 +394,6 @@ class LlamaSmoothquantDecoderLayer(nn.Module):
         # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
-            rotary_emb=rotary_emb,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
@@ -650,15 +665,15 @@ def llama_model_forward(
         raise NotImplementedError("not implement gradient_checkpointing and training options ")

     if past_key_values_length == 0:
-        position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
+        infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
             position_ids.view(-1).shape[0], -1
         )
-        position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
+        infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
             position_ids.view(-1).shape[0], -1
         )
     else:
-        position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1)
-        position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1)
+        infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1)
+        infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1)

     # decoder layers
     all_hidden_states = () if output_hidden_states else None
@@ -673,7 +688,6 @@ def llama_model_forward(

         layer_outputs = decoder_layer(
             hidden_states,
-            rotary_emb=(position_cos, position_sin),
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
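A standalone shape check (not part of the diff) for the repeat_kv helper added above; the function body is copied verbatim so the snippet runs without the rest of llama.py:

import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


kv = torch.randn(2, 4, 16, 64)  # (batch, num_key_value_heads, seqlen, head_dim)
out = repeat_kv(kv, n_rep=8)
print(out.shape)  # torch.Size([2, 32, 16, 64]) -> num_attention_heads = 4 * 8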
colossalai/inference/quant/smoothquant/models/parallel_linear.py (new file, 264 lines)
@@ -0,0 +1,264 @@
from typing import List, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.lazy import LazyInitContext
from colossalai.shardformer.layer import ParallelModule

from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear


def split_row_copy(smooth_linear, para_linear, tp_size=1, tp_rank=0, split_num=1):
    qweights = smooth_linear.weight.split(smooth_linear.out_features // split_num, dim=0)
    if smooth_linear.bias is not None:
        bias = smooth_linear.bias.split(smooth_linear.out_features // split_num, dim=0)

    smooth_split_out_features = para_linear.out_features // split_num

    for i in range(split_num):
        para_linear.weight[i * smooth_split_out_features : (i + 1) * smooth_split_out_features, :] = qweights[i][
            tp_rank * smooth_split_out_features : (tp_rank + 1) * smooth_split_out_features, :
        ]

        if para_linear.bias is not None:
            para_linear.bias[:, i * smooth_split_out_features : (i + 1) * smooth_split_out_features] = bias[i][
                :, tp_rank * smooth_split_out_features : (tp_rank + 1) * smooth_split_out_features
            ]


def split_column_copy(smooth_linear, para_linear, tp_rank=0, split_num=1):
    qweights = smooth_linear.weight.split(smooth_linear.in_features // split_num, dim=-1)

    smooth_split_in_features = para_linear.in_features // split_num

    for i in range(split_num):
        para_linear.weight[:, i * smooth_split_in_features : (i + 1) * smooth_split_in_features] = qweights[i][
            :, tp_rank * smooth_split_in_features : (tp_rank + 1) * smooth_split_in_features
        ]

    if smooth_linear.bias is not None:
        para_linear.bias.copy_(smooth_linear.bias)


class RowW8A8B8O8Linear(W8A8B8O8Linear, ParallelModule):
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__(in_features, out_features, alpha, beta)
        self.process_group = None
        self.tp_size = 1
        self.tp_rank = 0

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        out_features = module.out_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        if out_features < tp_size:
            return module

        if out_features % tp_size != 0:
            raise ValueError(
                f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!"
            )
        linear_1d = RowW8A8B8O8Linear(module.in_features, module.out_features // tp_size)
        linear_1d.tp_size = tp_size
        linear_1d.tp_rank = tp_rank
        linear_1d.process_group = process_group
        linear_1d.a = module.a.clone().detach()
        linear_1d.b = module.b.clone().detach()
        split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
        return linear_1d


class ColW8A8B8O8Linear(W8A8B8O8Linear, ParallelModule):
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__(in_features, out_features, alpha, beta)
        self.process_group = None
        self.tp_size = 1
        self.tp_rank = 0

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        in_features = module.in_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        if in_features < tp_size:
            return module

        if in_features % tp_size != 0:
            raise ValueError(
                f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
            )
        linear_1d = ColW8A8B8O8Linear(module.in_features // tp_size, module.out_features)
        linear_1d.tp_size = tp_size
        linear_1d.tp_rank = tp_rank
        linear_1d.process_group = process_group
        linear_1d.a = torch.tensor(module.a)
        linear_1d.b = torch.tensor(module.b)

        split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
        if linear_1d.bias is not None:
            linear_1d.bias = linear_1d.bias // tp_size

        return linear_1d

    @torch.no_grad()
    def forward(self, x):
        output = super().forward(x)
        if self.tp_size > 1:
            dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
        return output


class RowW8A8BFP32O32LinearSiLU(W8A8BFP32O32LinearSiLU, ParallelModule):
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__(in_features, out_features, alpha, beta)
        self.process_group = None
        self.tp_size = 1
        self.tp_rank = 0

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        out_features = module.out_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        if out_features < tp_size:
            return module

        if out_features % tp_size != 0:
            raise ValueError(
                f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!"
            )
        linear_1d = RowW8A8BFP32O32LinearSiLU(module.in_features, module.out_features // tp_size)
        linear_1d.tp_size = tp_size
        linear_1d.tp_rank = tp_rank
        linear_1d.process_group = process_group
        linear_1d.a = module.a.clone().detach()

        split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
        return linear_1d


class RowW8A8BFP32OFP32Linear(W8A8BFP32OFP32Linear, ParallelModule):
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__(in_features, out_features, alpha, beta)
        self.process_group = None
        self.tp_size = 1
        self.tp_rank = 0

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        out_features = module.out_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        if out_features < tp_size:
            return module

        if out_features % tp_size != 0:
            raise ValueError(
                f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!"
            )
        linear_1d = RowW8A8BFP32OFP32Linear(module.in_features, module.out_features // tp_size)
        linear_1d.tp_size = tp_size
        linear_1d.tp_rank = tp_rank
        linear_1d.process_group = process_group
        linear_1d.a = module.a.clone().detach()

        split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
        return linear_1d


class ColW8A8BFP32OFP32Linear(W8A8BFP32OFP32Linear, ParallelModule):
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__(in_features, out_features, alpha, beta)
        self.process_group = None
        self.tp_size = 1
        self.tp_rank = 0

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        in_features = module.in_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        if in_features < tp_size:
            return module

        if in_features % tp_size != 0:
            raise ValueError(
                f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
            )
        linear_1d = ColW8A8BFP32OFP32Linear(module.in_features // tp_size, module.out_features)
        linear_1d.tp_size = tp_size
        linear_1d.tp_rank = tp_rank
        linear_1d.process_group = process_group
        linear_1d.a = module.a.clone().detach()

        split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
        if linear_1d.bias is not None:
            linear_1d.bias = linear_1d.bias / tp_size

        return linear_1d

    @torch.no_grad()
    def forward(self, x):
        output = super().forward(x)
        if self.tp_size > 1:
            dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
        return output
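A minimal usage sketch (not part of the diff) for the new tensor-parallel linears. It assumes torch.distributed is already initialized on every participating rank and that the torch-int kernels are installed; the constructor arguments for the non-parallel base class are inferred from the subclass __init__ shown above and may differ in the actual linear.py.

import torch.distributed as dist

from colossalai.inference.quant.smoothquant.models.linear import W8A8BFP32OFP32Linear
from colossalai.inference.quant.smoothquant.models.parallel_linear import RowW8A8BFP32OFP32Linear

# Assumes dist.init_process_group(...) has been called on every rank (e.g. via colossalai.launch).
pg = dist.group.WORLD

full = W8A8BFP32OFP32Linear(4096, 11008)  # constructor signature assumed from the parallel subclasses
# Row sharding: each rank keeps out_features // world_size output rows; split_row_copy fills them in.
sharded = RowW8A8BFP32OFP32Linear.from_native_module(full, process_group=pg)
print(type(sharded).__name__, sharded.out_features)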