[Inference] Adapt Baichuan2-13B TP (#5659)

* adapt to baichuan2 13B

* add baichuan2 13B TP

* update baichuan tp logic

* rm unused code

* Fix TP logic

* fix alibi slopes tp logic

* rm nn.Module

* Polished the code.

* change BAICHUAN_MODEL_NAME_OR_PATH

* Modified the logic for loading Baichuan weights.

* fix typos
yuehuayingxueluo 2024-04-30 15:47:07 +08:00 committed by GitHub
parent 808ee6e4ad
commit 5f00002e43
7 changed files with 280 additions and 98 deletions
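
For orientation, here is a hedged sketch (condensed from the updated test in this commit) of how the new tensor-parallel path is driven: the engine is built with an InferenceConfig whose tp_size follows the launched world size and with NoPaddingBaichuanModelInferPolicy passed as the model policy. The prompt below is a placeholder, and the colossalai.launch/spawn scaffolding used by the test is omitted.

```python
# Hedged usage sketch condensed from the updated test; the prompt is a placeholder and the
# distributed launch (colossalai.launch via spawn) is omitted here.
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.modeling.policy import NoPaddingBaichuanModelInferPolicy

model_path = "baichuan-inc/Baichuan2-13B-Base"
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).half().cuda()

inference_config = InferenceConfig(
    max_output_len=38,
    use_cuda_kernel=True,
    tp_size=dist.get_world_size(),  # TP degree follows the launched world size
)
engine = InferenceEngine(
    model, tokenizer, inference_config, verbose=True, model_policy=NoPaddingBaichuanModelInferPolicy()
)
engine.add_request(prompts=["Please introduce Shanghai."])  # placeholder prompt
# Generation then proceeds as in the updated test below (not all of it is shown in this diff).
```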

View File

@@ -112,11 +112,23 @@ class InferenceEngine:
model_policy (Policy): the policy to replace the model
"""
casuallm = None
if isinstance(model_or_path, str):
try:
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
arch = getattr(hf_config, "architectures")[0]
if arch in _supported_models.keys():
casuallm = _supported_models[arch](hf_config)
if isinstance(casuallm, AutoModelForCausalLM):
# NOTE(caidi) It's necessary to add half() here, otherwise Baichuan-13B will run out of GPU memory.
model = (
AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half().cuda()
)
else:
model = _supported_models[arch](hf_config)
else:
raise ValueError(f"Model {arch} is not supported.")
except Exception as e:
self.logger.error(
f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
@@ -164,7 +176,7 @@ class InferenceEngine:
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
)
if isinstance(model_or_path, str):
if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM):
from colossalai.inference.core.plugin import InferCheckpoint_io
cpt_io = InferCheckpoint_io()

View File

@@ -0,0 +1,43 @@
from typing import List, Union
import torch.nn as nn
from torch.distributed import ProcessGroup
from colossalai.shardformer.layer import Linear1D_Col
from colossalai.shardformer.layer.parallel_module import ParallelModule
class BaichuanLMHeadLinear1D_Col(Linear1D_Col):
@staticmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
module.in_features = module.weight.size(1)
module.out_features = module.weight.size(0)
module.bias = None
module.weight.data = nn.functional.normalize(module.weight)
return Linear1D_Col.from_native_module(
module,
process_group,
*args,
**kwargs,
)
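
Context for the weight normalization above, with a hedge: as I understand Baichuan2's NormHead, it L2-normalizes the lm_head rows inside forward, so normalizing the weight once here before handing it to the column-parallel linear yields the same logits at inference time. A toy check of that equivalence (plain torch, not part of the commit):

```python
# Toy equivalence check (assumption: NormHead computes linear(x, normalize(weight)) at inference).
import torch
import torch.nn as nn

vocab, hidden = 16, 8
weight = torch.randn(vocab, hidden)
x = torch.randn(2, hidden)

norm_head_logits = nn.functional.linear(x, nn.functional.normalize(weight))  # NormHead-style forward
pre_normalized = nn.functional.normalize(weight)  # done once in from_native_module above
plain_logits = nn.functional.linear(x, pre_normalized)

assert torch.allclose(norm_head_logits, plain_logits)
```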
class BaichuanWpackLinear1D_Col(Linear1D_Col):
@staticmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
in_features = module.in_features * 3
out_features = module.out_features // 3
module.weight.data = module.weight.view(3, out_features, -1).transpose(0, 1).reshape(out_features, in_features)
module.bias = None
return Linear1D_Col.from_native_module(
module,
process_group,
*args,
**kwargs,
)
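
A toy sketch, not part of the commit, of why BaichuanWpackLinear1D_Col re-lays out the fused W_pack weight before column-parallel sharding: after the view/transpose/reshape, every output row holds its q, k, and v rows side by side, so a row-wise split hands each rank aligned q/k/v slices for its own heads. Shapes assume Baichuan's W_pack = nn.Linear(hidden, 3 * hidden); torch.chunk stands in for the split that Linear1D_Col performs.

```python
# Toy re-layout demo; torch.chunk stands in for the TP shard performed by Linear1D_Col.
import torch

hidden, tp = 4, 2
q = torch.full((hidden, hidden), 1.0)
k = torch.full((hidden, hidden), 2.0)
v = torch.full((hidden, hidden), 3.0)
w_pack = torch.cat([q, k, v], dim=0)  # (3*hidden, hidden), the layout of W_pack.weight

# Same view/transpose/reshape as from_native_module above: row i becomes [q_i | k_i | v_i].
w = w_pack.view(3, hidden, -1).transpose(0, 1).reshape(hidden, 3 * hidden)

rank0 = torch.chunk(w, tp, dim=0)[0]  # the rows a column-parallel rank 0 would hold
q0, k0, v0 = rank0.reshape(hidden // tp, 3, hidden).unbind(dim=1)
assert torch.equal(q0, q[: hidden // tp])  # rank 0 keeps matching q/k/v slices for its heads
assert torch.equal(k0, k[: hidden // tp]) and torch.equal(v0, v[: hidden // tp])
```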

View File

@@ -1,11 +1,14 @@
# This code is adapted from the huggingface baichuan model: https://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
import itertools
import math
from typing import Optional, Tuple
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.distributed import ProcessGroup
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
@@ -16,6 +19,18 @@ from colossalai.kernel.triton import (
rotary_embedding,
)
from colossalai.logging import get_dist_logger
from colossalai.shardformer.layer.parallel_module import ParallelModule
from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor
logger = get_dist_logger(__name__)
try:
from flash_attn import flash_attn_varlen_func
use_flash_attn2 = True
except ImportError:
use_flash_attn2 = False
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
logger = get_dist_logger(__name__)
@@ -78,14 +93,18 @@ def baichuan_rmsnorm_forward(
return rms_layernorm(hidden_states, self.weight.data, eps, norm_output, residual)
class NopadBaichuanAttention(nn.Module):
class NopadBaichuanAttention(ParallelModule):
def __init__(
self,
config,
attn_qproj_w: torch.Tensor = None,
attn_kproj_w: torch.Tensor = None,
attn_vproj_w: torch.Tensor = None,
attn_oproj_w: torch.Tensor = None,
attn_oproj: ParallelModule = None,
num_heads: int = None,
hidden_size: int = None,
process_group: ProcessGroup = None,
helper_layout: Layout = None,
):
"""This layer will replace the BaichuanAttention.
@@ -94,26 +113,35 @@ class NopadBaichuanAttention(nn.Module):
attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None.
attn_oproj (Linear1D_Row, optional): The row-parallel o_proj layer. Defaults to None.
"""
super().__init__()
self.o_proj_weight = attn_oproj_w
ParallelModule.__init__(self)
self.o_proj = attn_oproj
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.num_heads = num_heads
self.hidden_size = hidden_size
self.head_dim = self.hidden_size // self.num_heads
self.process_group = process_group
qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
self.helper_layout = helper_layout
self.alibi_slopes = None
self.use_alibi_attn = False
if self.hidden_size == 5120:
# Used for Baichuan-13B (ALiBi attention)
if config.hidden_size == 5120:
slopes_start = self.process_group.rank() * num_heads
self.use_alibi_attn = True
self.alibi_slopes = get_alibi_slopes(self.num_heads, device=attn_qproj_w.device)
qkv_weight_list = [attn_qproj_w, attn_kproj_w, attn_vproj_w]
self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
self.alibi_slopes = get_alibi_slopes(config.num_attention_heads, device=attn_qproj_w.device)[
slopes_start : slopes_start + num_heads
].contiguous()
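
The ALiBi branch above slices the full slope table by tensor-parallel rank so that each shard keeps only the slopes of its own heads. A self-contained sketch of that slicing follows; the slope formula is the standard ALiBi recipe for a power-of-two head count and is an assumption, since get_alibi_slopes itself is not part of this diff.

```python
# Hedged sketch of per-rank ALiBi slope slicing; the slope formula is assumed (standard ALiBi,
# power-of-two head counts only) because get_alibi_slopes is not shown in this diff.
import torch

def alibi_slopes(total_heads: int) -> torch.Tensor:
    base = 2.0 ** (-8.0 / total_heads)
    return torch.tensor([base ** (i + 1) for i in range(total_heads)])

total_heads, tp_size = 8, 2
num_heads = total_heads // tp_size              # heads per TP rank, as set by the policy
for rank in range(tp_size):
    slopes_start = rank * num_heads             # mirrors process_group.rank() * num_heads above
    local = alibi_slopes(total_heads)[slopes_start : slopes_start + num_heads].contiguous()
    print(rank, local.tolist())
```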
@staticmethod
def from_native_module(module: nn.Module, *args, **kwargs) -> "NopadBaichuanAttention":
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> "NopadBaichuanAttention":
"""Used for initialize the weight of NopadBaichuanAttention by origin BaichuanAttention.
Args:
@@ -121,24 +149,76 @@ class NopadBaichuanAttention(nn.Module):
"""
config = module.config
q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((module.hidden_size, 3, -1)).transpose(0, 1)
q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((3, module.hidden_size, module.hidden_size))
attn_qproj_w = q_proj_w
attn_kproj_w = k_proj_w
attn_vproj_w = v_proj_w
attn_oproj = module.o_proj
attn_qproj_w = q_proj_w.transpose(0, 1)
attn_kproj_w = k_proj_w.transpose(0, 1)
attn_vproj_w = v_proj_w.transpose(0, 1)
attn_oproj_w = module.o_proj.weight.transpose(0, 1)
helper_layout = (
module.W_pack.weight.dist_layout
) # NOTE this is a hack to get the correct load/shard layout of qkv_weight (used in _load_from_state_dict)
attn_layer = NopadBaichuanAttention(
config=config,
attn_qproj_w=attn_qproj_w,
attn_kproj_w=attn_kproj_w,
attn_vproj_w=attn_vproj_w,
attn_oproj_w=attn_oproj_w,
attn_oproj=attn_oproj,
num_heads=module.num_heads,
hidden_size=module.hidden_size,
process_group=process_group,
helper_layout=helper_layout,
)
return attn_layer
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
key = "qkv_weight"
qkv_w = state_dict[prefix + "W_pack.weight"]
in_features = qkv_w.size(1)
out_features = qkv_w.size(0) // 3
qkv_w.data = qkv_w.view((3, out_features, -1)).transpose(0, 1).reshape(out_features, in_features * 3)
device_mesh = self.helper_layout.device_mesh
sharding_spec = self.helper_layout.sharding_spec
qkv_w = distribute_tensor(qkv_w, device_mesh, sharding_spec)
qkv_w = qkv_w.transpose(0, 1).reshape(3, in_features, -1)
input_param = nn.Parameter(
qkv_w
) # NOTE qkv_weight does not have to be a dtensor, e.g. input_param = sharded_tensor_to_param(input_param)
param = local_state[key]
try:
with torch.no_grad():
param.copy_(input_param)
except Exception as ex:
error_msgs.append(
'While copying the parameter named "{}", '
"whose dimensions in the model are {} and "
"whose dimensions in the checkpoint are {}, "
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
)
strict = False # to avoid unexpected_keys
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
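
A shape-only sketch, not part of the commit, of the round trip _load_from_state_dict performs: the checkpoint W_pack.weight is interleaved exactly as in BaichuanWpackLinear1D_Col, sharded, then folded back into the (3, hidden_size, out_per_rank) layout that qkv_weight expects; torch.chunk stands in for distribute_tensor.

```python
# Shape-only sketch of the load-time round trip; torch.chunk stands in for distribute_tensor.
import torch

hidden, tp, rank = 4, 2, 0
qkv_w = torch.arange(3 * hidden * hidden, dtype=torch.float32).reshape(3 * hidden, hidden)  # W_pack.weight

in_features = qkv_w.size(1)         # hidden
out_features = qkv_w.size(0) // 3   # hidden
interleaved = qkv_w.view(3, out_features, -1).transpose(0, 1).reshape(out_features, in_features * 3)

shard = torch.chunk(interleaved, tp, dim=0)[rank]       # local slice of the interleaved weight
qkv_weight = shard.transpose(0, 1).reshape(3, in_features, -1)
print(qkv_weight.shape)  # torch.Size([3, 4, 2]) -> (3, hidden, out_per_rank), matching self.qkv_weight
```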
def forward(
self,
hidden_states: torch.Tensor,
@@ -292,56 +372,38 @@ class NopadBaichuanAttention(nn.Module):
)
attn_output = attn_output.view(-1, self.hidden_size)
attn_output = torch.mm(attn_output, self.o_proj_weight)
attn_output = self.o_proj(attn_output)
return attn_output
def extra_repr(self) -> str:
return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False"
# NOTE This will cause output differences as the output length increases.
class NopadBaichuanMLP(nn.Module):
def __init__(
self,
mlp_gproj_w: torch.Tensor = None,
mlp_uproj_w: torch.Tensor = None,
mlp_dproj_w: torch.Tensor = None,
):
"""This layer will replace the BaichuanAttention.
Args:
mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None.
mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
"""
super().__init__()
self.gate_up_weight = torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0)
self.down_proj_weight = mlp_dproj_w
class NopadBaichuanMLP(NopadLlamaMLP):
@staticmethod
def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
"""Used for initialize the weight of NopadBaichuanMLP by origin MLP(Baichuan).
Args:
module (nn.Module): The origin MLP(Baichuan) layer.
"""
mlp_gproj_w = module.gate_proj.weight.transpose(0, 1)
mlp_uproj_w = module.up_proj.weight.transpose(0, 1)
mlp_dproj_w = module.down_proj.weight.transpose(0, 1)
mlp_gproj_w = module.gate_proj.weight
assert is_distributed_tensor(
module.gate_proj.weight
), "gate_proj.weight must be dtensor so we could get the layout of the weight"
mlp_uproj_w = module.up_proj.weight
mlp_dproj = module.down_proj
mlp_layer = NopadBaichuanMLP(
config=None,
mlp_gproj_w=mlp_gproj_w,
mlp_uproj_w=mlp_uproj_w,
mlp_dproj_w=mlp_dproj_w,
mlp_dproj=mlp_dproj,
process_group=process_group,
)
return mlp_layer
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
"""
hidden_states = hidden_states.expand(2, -1, -1)
gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
act_out = inference_ops.silu_and_mul(gate_up_proj_out)
return torch.mm(act_out, self.down_proj_weight)
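
For reference, a hedged sketch of the fused gate/up pattern the removed forward above relied on: both projections are batched into one bmm, and inference_ops.silu_and_mul is assumed to compute silu(gate_out) * up_out on the stacked result, shown here with plain torch ops. The inherited NopadLlamaMLP implementation now takes over this computation.

```python
# Hedged sketch of the fused gate/up + SwiGLU pattern in the removed forward above;
# silu_and_mul is assumed to compute silu(out[0]) * out[1] on the stacked bmm result.
import torch

tokens, hidden, inter = 3, 4, 6
x = torch.randn(tokens, hidden)
gate_w = torch.randn(hidden, inter)  # transposed gate_proj weight, as mlp_gproj_w was
up_w = torch.randn(hidden, inter)    # transposed up_proj weight, as mlp_uproj_w was

gate_up_weight = torch.stack([gate_w, up_w], dim=0)            # (2, hidden, inter)
gate_up_out = torch.bmm(x.expand(2, -1, -1), gate_up_weight)   # (2, tokens, inter)
act_out = torch.nn.functional.silu(gate_up_out[0]) * gate_up_out[1]  # assumed silu_and_mul semantics
# The real forward then projects back down: torch.mm(act_out, down_proj_weight).
```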

View File

@@ -1,6 +1,7 @@
import torch.nn as nn
from torch.nn import Parameter
from colossalai.inference.modeling.layers.baichuan_tp_linear import (
BaichuanLMHeadLinear1D_Col,
BaichuanWpackLinear1D_Col,
)
from colossalai.inference.modeling.models.nopadding_baichuan import (
NopadBaichuanAttention,
NopadBaichuanMLP,
@@ -12,6 +13,7 @@ from colossalai.inference.modeling.models.nopadding_llama import (
llama_model_forward,
)
from colossalai.inference.utils import init_to_get_rotary
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
@@ -23,30 +25,64 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy):
def module_policy(self):
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
decoder_attribute_replacement = {
"lm_head.weight": Parameter(nn.functional.normalize(self.model.lm_head.weight), requires_grad=False),
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
}
policy["BaichuanForCausalLM"] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
if getattr(self.model.config, "num_key_value_heads", False):
decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
)
else:
decoder_attribute_replacement = None
# used for replacing the Baichuan 7B/13B decoder layer
for layer_name in ["DecoderLayer", "BaichuanLayer"]:
policy[layer_name] = ModulePolicyDescription(
# used for the Baichuan 7B and 13B decoder layers (DecoderLayer / BaichuanLayer)
for DecoderLayer in ["DecoderLayer", "BaichuanLayer"]:
policy[DecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="mlp",
target_module=NopadBaichuanMLP,
),
SubModuleReplacementDescription(
suffix="self_attn.W_pack",
target_module=BaichuanWpackLinear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="self_attn",
target_module=NopadBaichuanAttention,
),
]
],
)
self.append_or_create_method_replacement(
description={"forward": llama_decoder_layer_forward}, policy=policy, target_key=layer_name
description={"forward": llama_decoder_layer_forward}, policy=policy, target_key=DecoderLayer
)
policy["BaichuanForCausalLM"] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=BaichuanLMHeadLinear1D_Col, kwargs={"gather_output": True}
)
],
)
self.append_or_create_method_replacement(
@@ -55,7 +91,6 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy):
self.append_or_create_method_replacement(
description={"forward": llama_model_forward}, policy=policy, target_key="BaichuanModel"
)
self.append_or_create_method_replacement(
description={"forward": baichuan_rmsnorm_forward}, policy=policy, target_key="RMSNorm"
)

View File

@@ -4,26 +4,29 @@ import random
import numpy as np
import pytest
import torch
import torch.distributed as dist
from torch.multiprocessing import Manager
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import colossalai
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.policy import NoPaddingBaichuanModelInferPolicy
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
# BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base"
BAICHUAN_MODEL_NAME_OR_PATH = "/home/data/models/Baichuan2-13B-Base"
BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-13B-Base"
def setup_seed(seed):
torch.manual_seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None):
def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None, policy=None):
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, trust_remote_code=True).half().cuda()
@@ -34,7 +37,6 @@ def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=Fa
]
output_len = 38
do_sample = do_sample
if do_sample:
top_p = 0.5
@@ -45,9 +47,12 @@ def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=Fa
if use_engine:
inference_config = InferenceConfig(
max_output_len=output_len, prompt_template=prompt_template, use_cuda_kernel=use_cuda_kernel
max_output_len=output_len,
prompt_template=prompt_template,
use_cuda_kernel=use_cuda_kernel,
tp_size=dist.get_world_size(),
)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy)
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs)
assert inference_engine.request_handler._has_waiting()
@@ -70,31 +75,54 @@ def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=Fa
)
outputs = model.generate(inputs, generation_config=generation_config)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
return outputs
@parameterize("prompt_template", [None, "baichuan"])
@parameterize("do_sample", [True, False])
@parameterize("use_cuda_kernel", [True, False])
def check_output_consistency(prompt_template, do_sample, use_cuda_kernel):
cai_outputs = check_inference_engine(
use_engine=True, do_sample=do_sample, use_cuda_kernel=use_cuda_kernel, prompt_template=prompt_template
)
transformer_outputs = check_inference_engine(
use_engine=False, do_sample=do_sample, use_cuda_kernel=use_cuda_kernel, prompt_template=prompt_template
)
def run_engine(world_size, **kwargs):
manager = Manager()
result_list = manager.list([-1] * world_size) # Create a shared list
for s1, s2 in zip(cai_outputs, transformer_outputs):
assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
# clear singleton flash decoding tensors
FDIntermTensors._instances = {}
spawn(run_dist, world_size, func_to_run=check_inference_engine, ret=result_list, **kwargs)
return result_list[0]
def run_dist(rank, world_size, port):
def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
check_output_consistency()
if ret:
ret[rank] = func_to_run(**kwargs)
else:
func_to_run(**kwargs)
# NOTE(caidi) If do_sample is set to True or use_cuda_kernel is set to False, the inference results will differ from those of transformers.
@parameterize("prompt_template", [None, "baichuan"])
@parameterize("do_sample", [False])
@parameterize("use_cuda_kernel", [True])
def test_tp_engine(prompt_template, do_sample, use_cuda_kernel):
kwargs1 = {
"use_engine": True,
"prompt_template": prompt_template,
"do_sample": do_sample,
"policy": NoPaddingBaichuanModelInferPolicy(),
"use_cuda_kernel": use_cuda_kernel,
}
kwargs2 = {
"use_engine": False,
"prompt_template": prompt_template,
"do_sample": do_sample,
"policy": None,
"use_cuda_kernel": use_cuda_kernel,
}
colossal_tp_1_output = run_engine(1, **kwargs1)
colossal_tp_2_output = run_engine(2, **kwargs1)
transformer_tp_1_output = run_engine(1, **kwargs2)
for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output):
assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}"
assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"
@pytest.mark.skipif(
@@ -104,7 +132,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_inference_engine():
spawn(run_dist, 1)
test_tp_engine()
if __name__ == "__main__":

View File

@@ -193,6 +193,7 @@ def test_vllm_flash_decoding_attention(
max_seq_len_across_batch = kv_seq_lengths.max().item()
output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
sm_scale = 1.0 / (HEAD_SIZE**0.5)
kv_scale = 1.0
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
@@ -250,6 +251,7 @@ def test_vllm_flash_decoding_attention(
max_seq_len_across_batch,
alibi_slopes,
"auto",
kv_scale,
)
numpy_allclose(out_ref, output, rtol=rtol, atol=atol)