mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[Inference] Adapt Baichuan2-13B TP (#5659)
* adapt to baichuan2 13B * add baichuan2 13B TP * update baichuan tp logic * rm unused code * Fix TP logic * fix alibi slopes tp logic * rm nn.Module * Polished the code. * change BAICHUAN_MODEL_NAME_OR_PATH * Modified the logic for loading Baichuan weights. * fix typos
This commit is contained in:
@@ -1,11 +1,14 @@
|
||||
# This code is adapted from huggingface baichuan model: hhttps://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
|
||||
import itertools
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
||||
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
|
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
from colossalai.kernel.triton import (
|
||||
context_attention_unpadded,
|
||||
@@ -16,6 +19,18 @@ from colossalai.kernel.triton import (
|
||||
rotary_embedding,
|
||||
)
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.shardformer.layer.parallel_module import ParallelModule
|
||||
from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
|
||||
use_flash_attn2 = True
|
||||
except ImportError:
|
||||
use_flash_attn2 = False
|
||||
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
|
||||
@@ -78,14 +93,18 @@ def baichuan_rmsnorm_forward(
|
||||
return rms_layernorm(hidden_states, self.weight.data, eps, norm_output, residual)
|
||||
|
||||
|
||||
class NopadBaichuanAttention(nn.Module):
|
||||
class NopadBaichuanAttention(ParallelModule):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
attn_qproj_w: torch.Tensor = None,
|
||||
attn_kproj_w: torch.Tensor = None,
|
||||
attn_vproj_w: torch.Tensor = None,
|
||||
attn_oproj_w: torch.Tensor = None,
|
||||
attn_oproj: ParallelModule = None,
|
||||
num_heads: int = None,
|
||||
hidden_size: int = None,
|
||||
process_group: ProcessGroup = None,
|
||||
helper_layout: Layout = None,
|
||||
):
|
||||
"""This layer will replace the BaichuanAttention.
|
||||
|
||||
@@ -94,26 +113,35 @@ class NopadBaichuanAttention(nn.Module):
|
||||
attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
|
||||
attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
|
||||
attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
|
||||
attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None.
|
||||
attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj weight. Defaults to None.
|
||||
"""
|
||||
super().__init__()
|
||||
self.o_proj_weight = attn_oproj_w
|
||||
ParallelModule.__init__(self)
|
||||
self.o_proj = attn_oproj
|
||||
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.process_group = process_group
|
||||
qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
|
||||
self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
|
||||
|
||||
self.helper_layout = helper_layout
|
||||
|
||||
self.alibi_slopes = None
|
||||
self.use_alibi_attn = False
|
||||
if self.hidden_size == 5120:
|
||||
# Used for Baichuan13B
|
||||
if config.hidden_size == 5120:
|
||||
slopes_start = self.process_group.rank() * num_heads
|
||||
self.use_alibi_attn = True
|
||||
self.alibi_slopes = get_alibi_slopes(self.num_heads, device=attn_qproj_w.device)
|
||||
|
||||
qkv_weight_list = [attn_qproj_w, attn_kproj_w, attn_vproj_w]
|
||||
self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
|
||||
self.alibi_slopes = get_alibi_slopes(config.num_attention_heads, device=attn_qproj_w.device)[
|
||||
slopes_start : slopes_start + num_heads
|
||||
].contiguous()
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Module, *args, **kwargs) -> "NopadBaichuanAttention":
|
||||
def from_native_module(
|
||||
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
) -> "NopadBaichuanAttention":
|
||||
"""Used for initialize the weight of NopadBaichuanAttention by origin BaichuanAttention.
|
||||
|
||||
Args:
|
||||
@@ -121,24 +149,76 @@ class NopadBaichuanAttention(nn.Module):
|
||||
"""
|
||||
|
||||
config = module.config
|
||||
q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((module.hidden_size, 3, -1)).transpose(0, 1)
|
||||
|
||||
q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((3, module.hidden_size, module.hidden_size))
|
||||
attn_qproj_w = q_proj_w
|
||||
attn_kproj_w = k_proj_w
|
||||
attn_vproj_w = v_proj_w
|
||||
attn_oproj = module.o_proj
|
||||
|
||||
attn_qproj_w = q_proj_w.transpose(0, 1)
|
||||
attn_kproj_w = k_proj_w.transpose(0, 1)
|
||||
attn_vproj_w = v_proj_w.transpose(0, 1)
|
||||
attn_oproj_w = module.o_proj.weight.transpose(0, 1)
|
||||
helper_layout = (
|
||||
module.W_pack.weight.dist_layout
|
||||
) # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
|
||||
|
||||
attn_layer = NopadBaichuanAttention(
|
||||
config=config,
|
||||
attn_qproj_w=attn_qproj_w,
|
||||
attn_kproj_w=attn_kproj_w,
|
||||
attn_vproj_w=attn_vproj_w,
|
||||
attn_oproj_w=attn_oproj_w,
|
||||
attn_oproj=attn_oproj,
|
||||
num_heads=module.num_heads,
|
||||
hidden_size=module.hidden_size,
|
||||
process_group=process_group,
|
||||
helper_layout=helper_layout,
|
||||
)
|
||||
|
||||
return attn_layer
|
||||
|
||||
def _load_from_state_dict(
|
||||
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
):
|
||||
for hook in self._load_state_dict_pre_hooks.values():
|
||||
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
|
||||
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
|
||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||
|
||||
key = "qkv_weight"
|
||||
qkv_w = state_dict[prefix + "W_pack.weight"]
|
||||
|
||||
in_features = qkv_w.size(1)
|
||||
out_features = qkv_w.size(0) // 3
|
||||
|
||||
qkv_w.data = qkv_w.view((3, out_features, -1)).transpose(0, 1).reshape(out_features, in_features * 3)
|
||||
|
||||
device_mesh = self.helper_layout.device_mesh
|
||||
sharding_spec = self.helper_layout.sharding_spec
|
||||
qkv_w = distribute_tensor(qkv_w, device_mesh, sharding_spec)
|
||||
|
||||
qkv_w = qkv_w.transpose(0, 1).reshape(3, in_features, -1)
|
||||
input_param = nn.Parameter(
|
||||
qkv_w
|
||||
) # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
|
||||
|
||||
param = local_state[key]
|
||||
|
||||
try:
|
||||
with torch.no_grad():
|
||||
param.copy_(input_param)
|
||||
except Exception as ex:
|
||||
error_msgs.append(
|
||||
'While copying the parameter named "{}", '
|
||||
"whose dimensions in the model are {} and "
|
||||
"whose dimensions in the checkpoint are {}, "
|
||||
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
|
||||
)
|
||||
|
||||
strict = False # to avoid unexpected_keys
|
||||
super()._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -292,56 +372,38 @@ class NopadBaichuanAttention(nn.Module):
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(-1, self.hidden_size)
|
||||
attn_output = torch.mm(attn_output, self.o_proj_weight)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False"
|
||||
|
||||
|
||||
# NOTE This will cause difference as out length increases.
|
||||
class NopadBaichuanMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
mlp_gproj_w: torch.Tensor = None,
|
||||
mlp_uproj_w: torch.Tensor = None,
|
||||
mlp_dproj_w: torch.Tensor = None,
|
||||
):
|
||||
"""This layer will replace the BaichuanAttention.
|
||||
|
||||
Args:
|
||||
mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
|
||||
mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None.
|
||||
mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
|
||||
"""
|
||||
super().__init__()
|
||||
self.gate_up_weight = torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0)
|
||||
self.down_proj_weight = mlp_dproj_w
|
||||
|
||||
class NopadBaichuanMLP(NopadLlamaMLP):
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
|
||||
def from_native_module(
|
||||
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
) -> ParallelModule:
|
||||
"""Used for initialize the weight of NopadBaichuanMLP by origin MLP(Baichuan).
|
||||
|
||||
Args:
|
||||
module (nn.Module): The origin MLP(Baichuan) layer.
|
||||
"""
|
||||
|
||||
mlp_gproj_w = module.gate_proj.weight.transpose(0, 1)
|
||||
mlp_uproj_w = module.up_proj.weight.transpose(0, 1)
|
||||
mlp_dproj_w = module.down_proj.weight.transpose(0, 1)
|
||||
mlp_gproj_w = module.gate_proj.weight
|
||||
assert is_distributed_tensor(
|
||||
module.gate_proj.weight
|
||||
), "gate_proj.weight must be dtensor so we could get the layout of the weight"
|
||||
mlp_uproj_w = module.up_proj.weight
|
||||
mlp_dproj = module.down_proj
|
||||
|
||||
mlp_layer = NopadBaichuanMLP(
|
||||
config=None,
|
||||
mlp_gproj_w=mlp_gproj_w,
|
||||
mlp_uproj_w=mlp_uproj_w,
|
||||
mlp_dproj_w=mlp_dproj_w,
|
||||
mlp_dproj=mlp_dproj,
|
||||
process_group=process_group,
|
||||
)
|
||||
|
||||
return mlp_layer
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
|
||||
"""
|
||||
hidden_states = hidden_states.expand(2, -1, -1)
|
||||
gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
|
||||
act_out = inference_ops.silu_and_mul(gate_up_proj_out)
|
||||
return torch.mm(act_out, self.down_proj_weight)
|
||||
|
Reference in New Issue
Block a user