mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-29 22:41:15 +00:00
[Inference] Adapt Baichuan2-13B TP (#5659)
* adapt to baichuan2 13B * add baichuan2 13B TP * update baichuan tp logic * rm unused code * Fix TP logic * fix alibi slopes tp logic * rm nn.Module * Polished the code. * change BAICHUAN_MODEL_NAME_OR_PATH * Modified the logic for loading Baichuan weights. * fix typos
This commit is contained in:
parent
808ee6e4ad
commit
5f00002e43
@ -26,7 +26,7 @@ _ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
_DEFAULT_PROMPT_TEMPLATES = {
|
||||
"llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
|
||||
"baichuan": "<reserved_106>{input_text}<reserved_107>",
|
||||
"baichuan": " <reserved_106> {input_text} <reserved_107> ",
|
||||
"vicuna": "A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user input. USER: {input_text}\nASSISTANT: ",
|
||||
}
|
||||
|
||||
|
@ -112,11 +112,23 @@ class InferenceEngine:
|
||||
model_policy (Policy): the policy to replace the model
|
||||
"""
|
||||
|
||||
casuallm = None
|
||||
if isinstance(model_or_path, str):
|
||||
try:
|
||||
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
|
||||
arch = getattr(hf_config, "architectures")[0]
|
||||
model = _supported_models[arch](hf_config)
|
||||
if arch in _supported_models.keys():
|
||||
casuallm = _supported_models[arch](hf_config)
|
||||
if isinstance(casuallm, AutoModelForCausalLM):
|
||||
# NOTE(caidi) It's necessary to add half() here, otherwise baichuan13B will overflow the memory.
|
||||
model = (
|
||||
AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half().cuda()
|
||||
)
|
||||
else:
|
||||
model = _supported_models[arch](hf_config)
|
||||
else:
|
||||
raise ValueError(f"Model {arch} is not supported.")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
|
||||
@ -164,7 +176,7 @@ class InferenceEngine:
|
||||
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
|
||||
)
|
||||
|
||||
if isinstance(model_or_path, str):
|
||||
if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM):
|
||||
from colossalai.inference.core.plugin import InferCheckpoint_io
|
||||
|
||||
cpt_io = InferCheckpoint_io()
|
||||
|
43
colossalai/inference/modeling/layers/baichuan_tp_linear.py
Normal file
43
colossalai/inference/modeling/layers/baichuan_tp_linear.py
Normal file
@ -0,0 +1,43 @@
|
||||
from typing import List, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.shardformer.layer import Linear1D_Col
|
||||
from colossalai.shardformer.layer.parallel_module import ParallelModule
|
||||
|
||||
|
||||
class BaichuanLMHeadLinear1D_Col(Linear1D_Col):
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
) -> ParallelModule:
|
||||
module.in_features = module.weight.size(1)
|
||||
module.out_features = module.weight.size(0)
|
||||
module.bias = None
|
||||
module.weight.data = nn.functional.normalize(module.weight)
|
||||
|
||||
return Linear1D_Col.from_native_module(
|
||||
module,
|
||||
process_group,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class BaichuanWpackLinear1D_Col(Linear1D_Col):
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
) -> ParallelModule:
|
||||
in_features = module.in_features * 3
|
||||
out_features = module.out_features // 3
|
||||
module.weight.data = module.weight.view(3, out_features, -1).transpose(0, 1).reshape(out_features, in_features)
|
||||
module.bias = None
|
||||
|
||||
return Linear1D_Col.from_native_module(
|
||||
module,
|
||||
process_group,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
@ -1,11 +1,14 @@
|
||||
# This code is adapted from huggingface baichuan model: hhttps://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
|
||||
import itertools
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
||||
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
|
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
from colossalai.kernel.triton import (
|
||||
context_attention_unpadded,
|
||||
@ -16,6 +19,18 @@ from colossalai.kernel.triton import (
|
||||
rotary_embedding,
|
||||
)
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.shardformer.layer.parallel_module import ParallelModule
|
||||
from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
|
||||
use_flash_attn2 = True
|
||||
except ImportError:
|
||||
use_flash_attn2 = False
|
||||
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
|
||||
@ -78,14 +93,18 @@ def baichuan_rmsnorm_forward(
|
||||
return rms_layernorm(hidden_states, self.weight.data, eps, norm_output, residual)
|
||||
|
||||
|
||||
class NopadBaichuanAttention(nn.Module):
|
||||
class NopadBaichuanAttention(ParallelModule):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
attn_qproj_w: torch.Tensor = None,
|
||||
attn_kproj_w: torch.Tensor = None,
|
||||
attn_vproj_w: torch.Tensor = None,
|
||||
attn_oproj_w: torch.Tensor = None,
|
||||
attn_oproj: ParallelModule = None,
|
||||
num_heads: int = None,
|
||||
hidden_size: int = None,
|
||||
process_group: ProcessGroup = None,
|
||||
helper_layout: Layout = None,
|
||||
):
|
||||
"""This layer will replace the BaichuanAttention.
|
||||
|
||||
@ -94,26 +113,35 @@ class NopadBaichuanAttention(nn.Module):
|
||||
attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
|
||||
attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
|
||||
attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
|
||||
attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None.
|
||||
attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj weight. Defaults to None.
|
||||
"""
|
||||
super().__init__()
|
||||
self.o_proj_weight = attn_oproj_w
|
||||
ParallelModule.__init__(self)
|
||||
self.o_proj = attn_oproj
|
||||
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.process_group = process_group
|
||||
qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
|
||||
self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
|
||||
|
||||
self.helper_layout = helper_layout
|
||||
|
||||
self.alibi_slopes = None
|
||||
self.use_alibi_attn = False
|
||||
if self.hidden_size == 5120:
|
||||
# Used for Baichuan13B
|
||||
if config.hidden_size == 5120:
|
||||
slopes_start = self.process_group.rank() * num_heads
|
||||
self.use_alibi_attn = True
|
||||
self.alibi_slopes = get_alibi_slopes(self.num_heads, device=attn_qproj_w.device)
|
||||
|
||||
qkv_weight_list = [attn_qproj_w, attn_kproj_w, attn_vproj_w]
|
||||
self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
|
||||
self.alibi_slopes = get_alibi_slopes(config.num_attention_heads, device=attn_qproj_w.device)[
|
||||
slopes_start : slopes_start + num_heads
|
||||
].contiguous()
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Module, *args, **kwargs) -> "NopadBaichuanAttention":
|
||||
def from_native_module(
|
||||
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
) -> "NopadBaichuanAttention":
|
||||
"""Used for initialize the weight of NopadBaichuanAttention by origin BaichuanAttention.
|
||||
|
||||
Args:
|
||||
@ -121,24 +149,76 @@ class NopadBaichuanAttention(nn.Module):
|
||||
"""
|
||||
|
||||
config = module.config
|
||||
q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((module.hidden_size, 3, -1)).transpose(0, 1)
|
||||
|
||||
q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((3, module.hidden_size, module.hidden_size))
|
||||
attn_qproj_w = q_proj_w
|
||||
attn_kproj_w = k_proj_w
|
||||
attn_vproj_w = v_proj_w
|
||||
attn_oproj = module.o_proj
|
||||
|
||||
attn_qproj_w = q_proj_w.transpose(0, 1)
|
||||
attn_kproj_w = k_proj_w.transpose(0, 1)
|
||||
attn_vproj_w = v_proj_w.transpose(0, 1)
|
||||
attn_oproj_w = module.o_proj.weight.transpose(0, 1)
|
||||
helper_layout = (
|
||||
module.W_pack.weight.dist_layout
|
||||
) # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
|
||||
|
||||
attn_layer = NopadBaichuanAttention(
|
||||
config=config,
|
||||
attn_qproj_w=attn_qproj_w,
|
||||
attn_kproj_w=attn_kproj_w,
|
||||
attn_vproj_w=attn_vproj_w,
|
||||
attn_oproj_w=attn_oproj_w,
|
||||
attn_oproj=attn_oproj,
|
||||
num_heads=module.num_heads,
|
||||
hidden_size=module.hidden_size,
|
||||
process_group=process_group,
|
||||
helper_layout=helper_layout,
|
||||
)
|
||||
|
||||
return attn_layer
|
||||
|
||||
def _load_from_state_dict(
|
||||
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
):
|
||||
for hook in self._load_state_dict_pre_hooks.values():
|
||||
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
|
||||
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
|
||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||
|
||||
key = "qkv_weight"
|
||||
qkv_w = state_dict[prefix + "W_pack.weight"]
|
||||
|
||||
in_features = qkv_w.size(1)
|
||||
out_features = qkv_w.size(0) // 3
|
||||
|
||||
qkv_w.data = qkv_w.view((3, out_features, -1)).transpose(0, 1).reshape(out_features, in_features * 3)
|
||||
|
||||
device_mesh = self.helper_layout.device_mesh
|
||||
sharding_spec = self.helper_layout.sharding_spec
|
||||
qkv_w = distribute_tensor(qkv_w, device_mesh, sharding_spec)
|
||||
|
||||
qkv_w = qkv_w.transpose(0, 1).reshape(3, in_features, -1)
|
||||
input_param = nn.Parameter(
|
||||
qkv_w
|
||||
) # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
|
||||
|
||||
param = local_state[key]
|
||||
|
||||
try:
|
||||
with torch.no_grad():
|
||||
param.copy_(input_param)
|
||||
except Exception as ex:
|
||||
error_msgs.append(
|
||||
'While copying the parameter named "{}", '
|
||||
"whose dimensions in the model are {} and "
|
||||
"whose dimensions in the checkpoint are {}, "
|
||||
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
|
||||
)
|
||||
|
||||
strict = False # to avoid unexpected_keys
|
||||
super()._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -292,56 +372,38 @@ class NopadBaichuanAttention(nn.Module):
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(-1, self.hidden_size)
|
||||
attn_output = torch.mm(attn_output, self.o_proj_weight)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False"
|
||||
|
||||
|
||||
# NOTE This will cause difference as out length increases.
|
||||
class NopadBaichuanMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
mlp_gproj_w: torch.Tensor = None,
|
||||
mlp_uproj_w: torch.Tensor = None,
|
||||
mlp_dproj_w: torch.Tensor = None,
|
||||
):
|
||||
"""This layer will replace the BaichuanAttention.
|
||||
|
||||
Args:
|
||||
mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
|
||||
mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None.
|
||||
mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
|
||||
"""
|
||||
super().__init__()
|
||||
self.gate_up_weight = torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0)
|
||||
self.down_proj_weight = mlp_dproj_w
|
||||
|
||||
class NopadBaichuanMLP(NopadLlamaMLP):
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
|
||||
def from_native_module(
|
||||
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
) -> ParallelModule:
|
||||
"""Used for initialize the weight of NopadBaichuanMLP by origin MLP(Baichuan).
|
||||
|
||||
Args:
|
||||
module (nn.Module): The origin MLP(Baichuan) layer.
|
||||
"""
|
||||
|
||||
mlp_gproj_w = module.gate_proj.weight.transpose(0, 1)
|
||||
mlp_uproj_w = module.up_proj.weight.transpose(0, 1)
|
||||
mlp_dproj_w = module.down_proj.weight.transpose(0, 1)
|
||||
mlp_gproj_w = module.gate_proj.weight
|
||||
assert is_distributed_tensor(
|
||||
module.gate_proj.weight
|
||||
), "gate_proj.weight must be dtensor so we could get the layout of the weight"
|
||||
mlp_uproj_w = module.up_proj.weight
|
||||
mlp_dproj = module.down_proj
|
||||
|
||||
mlp_layer = NopadBaichuanMLP(
|
||||
config=None,
|
||||
mlp_gproj_w=mlp_gproj_w,
|
||||
mlp_uproj_w=mlp_uproj_w,
|
||||
mlp_dproj_w=mlp_dproj_w,
|
||||
mlp_dproj=mlp_dproj,
|
||||
process_group=process_group,
|
||||
)
|
||||
|
||||
return mlp_layer
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
|
||||
"""
|
||||
hidden_states = hidden_states.expand(2, -1, -1)
|
||||
gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
|
||||
act_out = inference_ops.silu_and_mul(gate_up_proj_out)
|
||||
return torch.mm(act_out, self.down_proj_weight)
|
||||
|
@ -1,6 +1,7 @@
|
||||
import torch.nn as nn
|
||||
from torch.nn import Parameter
|
||||
|
||||
from colossalai.inference.modeling.layers.baichuan_tp_linear import (
|
||||
BaichuanLMHeadLinear1D_Col,
|
||||
BaichuanWpackLinear1D_Col,
|
||||
)
|
||||
from colossalai.inference.modeling.models.nopadding_baichuan import (
|
||||
NopadBaichuanAttention,
|
||||
NopadBaichuanMLP,
|
||||
@ -12,6 +13,7 @@ from colossalai.inference.modeling.models.nopadding_llama import (
|
||||
llama_model_forward,
|
||||
)
|
||||
from colossalai.inference.utils import init_to_get_rotary
|
||||
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
|
||||
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
|
||||
|
||||
@ -23,39 +25,72 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy):
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
|
||||
decoder_attribute_replacement = {
|
||||
"lm_head.weight": Parameter(nn.functional.normalize(self.model.lm_head.weight), requires_grad=False),
|
||||
}
|
||||
policy["BaichuanForCausalLM"] = ModulePolicyDescription(
|
||||
attribute_replacement=decoder_attribute_replacement,
|
||||
)
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
decoder_attribute_replacement = {
|
||||
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||||
}
|
||||
if getattr(self.model.config, "num_key_value_heads", False):
|
||||
decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
|
||||
self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
|
||||
)
|
||||
else:
|
||||
decoder_attribute_replacement = None
|
||||
|
||||
# used for relpacing Baichuan 7B/13B decoder layer
|
||||
for layer_name in ["DecoderLayer", "BaichuanLayer"]:
|
||||
policy[layer_name] = ModulePolicyDescription(
|
||||
# used for Baichuan 7B and 13B for baichuan DecoderLayer
|
||||
for DecoderLayer in ["DecoderLayer", "BaichuanLayer"]:
|
||||
policy[DecoderLayer] = ModulePolicyDescription(
|
||||
attribute_replacement=decoder_attribute_replacement,
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.gate_proj",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.up_proj",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.down_proj",
|
||||
target_module=Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp",
|
||||
target_module=NopadBaichuanMLP,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.W_pack",
|
||||
target_module=BaichuanWpackLinear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
target_module=Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn",
|
||||
target_module=NopadBaichuanAttention,
|
||||
),
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
self.append_or_create_method_replacement(
|
||||
description={"forward": llama_decoder_layer_forward}, policy=policy, target_key=layer_name
|
||||
description={"forward": llama_decoder_layer_forward}, policy=policy, target_key=DecoderLayer
|
||||
)
|
||||
|
||||
policy["BaichuanForCausalLM"] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head", target_module=BaichuanLMHeadLinear1D_Col, kwargs={"gather_output": True}
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
self.append_or_create_method_replacement(
|
||||
description={"forward": llama_causal_lm_forward}, policy=policy, target_key="BaichuanForCausalLM"
|
||||
)
|
||||
self.append_or_create_method_replacement(
|
||||
description={"forward": llama_model_forward}, policy=policy, target_key="BaichuanModel"
|
||||
)
|
||||
|
||||
self.append_or_create_method_replacement(
|
||||
description={"forward": baichuan_rmsnorm_forward}, policy=policy, target_key="RMSNorm"
|
||||
)
|
||||
|
@ -4,26 +4,29 @@ import random
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.multiprocessing import Manager
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
|
||||
from colossalai.inference.core.engine import InferenceEngine
|
||||
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
||||
from colossalai.inference.modeling.policy import NoPaddingBaichuanModelInferPolicy
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
# BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base"
|
||||
BAICHUAN_MODEL_NAME_OR_PATH = "/home/data/models/Baichuan2-13B-Base"
|
||||
BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-13B-Base"
|
||||
|
||||
|
||||
def setup_seed(seed):
|
||||
torch.manual_seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
|
||||
def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None):
|
||||
def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None, policy=None):
|
||||
setup_seed(20)
|
||||
tokenizer = AutoTokenizer.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, use_fast=False, trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, trust_remote_code=True).half().cuda()
|
||||
@ -34,7 +37,6 @@ def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=Fa
|
||||
]
|
||||
|
||||
output_len = 38
|
||||
do_sample = do_sample
|
||||
|
||||
if do_sample:
|
||||
top_p = 0.5
|
||||
@ -45,9 +47,12 @@ def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=Fa
|
||||
|
||||
if use_engine:
|
||||
inference_config = InferenceConfig(
|
||||
max_output_len=output_len, prompt_template=prompt_template, use_cuda_kernel=use_cuda_kernel
|
||||
max_output_len=output_len,
|
||||
prompt_template=prompt_template,
|
||||
use_cuda_kernel=use_cuda_kernel,
|
||||
tp_size=dist.get_world_size(),
|
||||
)
|
||||
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
|
||||
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy)
|
||||
assert inference_engine.generation_config.max_new_tokens == output_len
|
||||
inference_engine.add_request(prompts=inputs)
|
||||
assert inference_engine.request_handler._has_waiting()
|
||||
@ -70,31 +75,54 @@ def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=Fa
|
||||
)
|
||||
outputs = model.generate(inputs, generation_config=generation_config)
|
||||
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@parameterize("prompt_template", [None, "baichuan"])
|
||||
@parameterize("do_sample", [True, False])
|
||||
@parameterize("use_cuda_kernel", [True, False])
|
||||
def check_output_consistency(prompt_template, do_sample, use_cuda_kernel):
|
||||
cai_outputs = check_inference_engine(
|
||||
use_engine=True, do_sample=do_sample, use_cuda_kernel=use_cuda_kernel, prompt_template=prompt_template
|
||||
)
|
||||
transformer_outputs = check_inference_engine(
|
||||
use_engine=False, do_sample=do_sample, use_cuda_kernel=use_cuda_kernel, prompt_template=prompt_template
|
||||
)
|
||||
def run_engine(world_size, **kwargs):
|
||||
manager = Manager()
|
||||
result_list = manager.list([-1] * world_size) # Create a shared list
|
||||
|
||||
for s1, s2 in zip(cai_outputs, transformer_outputs):
|
||||
assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
|
||||
|
||||
# clear singleton flash decoding tensors
|
||||
FDIntermTensors._instances = {}
|
||||
spawn(run_dist, world_size, func_to_run=check_inference_engine, ret=result_list, **kwargs)
|
||||
return result_list[0]
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
check_output_consistency()
|
||||
|
||||
if ret:
|
||||
ret[rank] = func_to_run(**kwargs)
|
||||
else:
|
||||
func_to_run(**kwargs)
|
||||
|
||||
|
||||
# NOTE(caidi) If do_sample is set to True or use_cuda_kernel is set to False, the inference result will be different from that of the transformer.
|
||||
@parameterize("prompt_template", [None, "baichuan"])
|
||||
@parameterize("do_sample", [False])
|
||||
@parameterize("use_cuda_kernel", [True])
|
||||
def test_tp_engine(prompt_template, do_sample, use_cuda_kernel):
|
||||
kwargs1 = {
|
||||
"use_engine": True,
|
||||
"prompt_template": prompt_template,
|
||||
"do_sample": do_sample,
|
||||
"policy": NoPaddingBaichuanModelInferPolicy(),
|
||||
"use_cuda_kernel": use_cuda_kernel,
|
||||
}
|
||||
|
||||
kwargs2 = {
|
||||
"use_engine": False,
|
||||
"prompt_template": prompt_template,
|
||||
"do_sample": do_sample,
|
||||
"policy": None,
|
||||
"use_cuda_kernel": use_cuda_kernel,
|
||||
}
|
||||
|
||||
colossal_tp_1_output = run_engine(1, **kwargs1)
|
||||
colossal_tp_2_output = run_engine(2, **kwargs1)
|
||||
transformer_tp_1_output = run_engine(1, **kwargs2)
|
||||
|
||||
for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output):
|
||||
assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}"
|
||||
assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@ -104,7 +132,7 @@ def run_dist(rank, world_size, port):
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_inference_engine():
|
||||
spawn(run_dist, 1)
|
||||
test_tp_engine()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -193,6 +193,7 @@ def test_vllm_flash_decoding_attention(
|
||||
max_seq_len_across_batch = kv_seq_lengths.max().item()
|
||||
output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
|
||||
sm_scale = 1.0 / (HEAD_SIZE**0.5)
|
||||
kv_scale = 1.0
|
||||
|
||||
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
|
||||
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
|
||||
@ -250,6 +251,7 @@ def test_vllm_flash_decoding_attention(
|
||||
max_seq_len_across_batch,
|
||||
alibi_slopes,
|
||||
"auto",
|
||||
kv_scale,
|
||||
)
|
||||
numpy_allclose(out_ref, output, rtol=rtol, atol=atol)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user