[Inference] Adapt to baichuan2 13B (#5614)
* adapt to baichuan2 13B
* adapt to baichuan2 13B
* change BAICHUAN_MODEL_NAME_OR_PATH
* fix test_decoding_attn.py
* Modifications based on review comments.
* change BAICHUAN_MODEL_NAME_OR_PATH
* mv attn mask processes to test flash decoding
* mv get_alibi_slopes baichuan modeling
* fix bugs in test_baichuan.py
@@ -1,19 +1,83 @@
# This code is adapted from huggingface baichuan model: https://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn

from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaAttention
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
    context_attention_unpadded,
    copy_k_to_blocked_cache,
    decoding_fused_rotary_embedding,
    flash_decoding_attention,
    rms_layernorm,
    rotary_embedding,
)
from colossalai.logging import get_dist_logger

logger = get_dist_logger(__name__)

try:
    from flash_attn import flash_attn_varlen_func

    use_flash_attn2 = True
except ImportError:
    use_flash_attn2 = False
    logger.warning("flash_attn2 has not been installed yet, we will use triton flash attn instead.")

inference_ops = InferenceOpsLoader().load()


# alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57
def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
    slopes = torch.pow(base, powers)
    if closest_power_of_2 != num_heads:
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
        )
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes
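
For reference, when the head count is a power of two the slopes above reduce to the geometric sequence 2^(-8*i/num_heads) for i = 1..num_heads; a minimal illustrative check of that property (an editor's sketch, using only the function and imports already defined above):

# Editor's sketch: sanity-check the power-of-two case on CPU.
slopes = get_alibi_slopes(8, device=torch.device("cpu"))
expected = torch.tensor([2.0 ** -(i + 1) for i in range(8)])  # 2**-1, 2**-2, ..., 2**-8
assert torch.allclose(slopes, expected)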


def baichuan_rmsnorm_forward(
    self,
    hidden_states: torch.Tensor,
    norm_output: torch.Tensor,
    residual: torch.Tensor = None,
    use_cuda_kernel: bool = True,
):
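    # Both CUDA-kernel branches return a pair: the normalized hidden states and the tensor
    # the caller should keep as the running residual for the next layer.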
    # Used to address the issue of inconsistent epsilon variable names in baichuan2 7b and 13b.
    if hasattr(self, "variance_epsilon"):
        eps = self.variance_epsilon
    elif hasattr(self, "epsilon"):
        eps = self.epsilon
    else:
        raise TypeError(
            "Currently, the variable name for the epsilon of baichuan7B/13B should be 'variance_epsilon' or 'epsilon'."
        )

    if use_cuda_kernel:
        if residual is not None:
            inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, eps)
            return hidden_states, residual

        if norm_output is None:
            norm_output = torch.empty_like(hidden_states)
        inference_ops.rms_layernorm(norm_output, hidden_states, self.weight.data, eps)
        return norm_output, hidden_states
    else:
        return rms_layernorm(hidden_states, self.weight.data, eps, norm_output, residual)


class NopadBaichuanAttention(nn.Module):
    def __init__(
        self,
@@ -39,9 +103,11 @@ class NopadBaichuanAttention(nn.Module):
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads

        # Used to adapt llama_base_attn_forward
        self.num_key_value_heads = self.num_heads
        self.alibi_slopes = None
        self.use_alibi_attn = False
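        # Baichuan2-13B (hidden_size 5120) uses ALiBi positional bias instead of rotary
        # embeddings, so its slopes are precomputed here; the 7B variant keeps RoPE.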
        if self.hidden_size == 5120:
            self.use_alibi_attn = True
            self.alibi_slopes = get_alibi_slopes(self.num_heads, device=attn_qproj_w.device)

        qkv_weight_list = [attn_qproj_w, attn_kproj_w, attn_vproj_w]
        self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
@@ -112,26 +178,124 @@ class NopadBaichuanAttention(nn.Module):
            high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
        """

        return NopadLlamaAttention.forward(
            self,
            hidden_states=hidden_states,
            block_tables=block_tables,
            k_cache=k_cache,
            v_cache=v_cache,
            sequence_lengths=sequence_lengths,
            cos_sin=cos_sin,
            fd_inter_tensor=fd_inter_tensor,
            is_prompts=is_prompts,
            is_verifier=is_verifier,
            tokens_to_verify=tokens_to_verify,
            kv_seq_len=kv_seq_len,
            output_tensor=output_tensor,
            sm_scale=sm_scale,
            use_cuda_kernel=use_cuda_kernel,
            cu_seqlens=cu_seqlens,
            high_precision=high_precision,
        )
        token_nums = hidden_states.size(0)
        # fused qkv
        hidden_states = hidden_states.expand(3, -1, -1)
        query_states, key_states, value_states = (
            torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
        )

        block_size = k_cache.size(-2)
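
        # Dispatch: prefill (is_prompts) uses flash-attn-2 varlen attention when the CUDA
        # kernels, FP16/BF16 inputs, and the non-ALiBi path allow it, and otherwise falls
        # back to the Triton context attention kernel; decoding copies KV into the blocked
        # cache (via CUDA or Triton kernels) and then runs the Triton flash-decoding kernel.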
        if is_prompts:
            if (
                not is_verifier
                and use_cuda_kernel
                and query_states.dtype != torch.float32
                and use_flash_attn2
                and not self.use_alibi_attn
            ):
                # flash attn 2 currently only supports FP16/BF16.
                inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
                inference_ops.context_kv_cache_memcpy(
                    key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
                )
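
                # flash_attn_varlen_func consumes the padding-free packed layout directly,
                # using cu_seqlens to delimit the individual sequences in the batch.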
                attn_output = flash_attn_varlen_func(
                    query_states,
                    key_states,
                    value_states,
                    cu_seqlens_q=cu_seqlens,
                    cu_seqlens_k=cu_seqlens,
                    max_seqlen_q=kv_seq_len,
                    max_seqlen_k=kv_seq_len,
                    dropout_p=0.0,
                    softmax_scale=sm_scale,
                    causal=True,
                )
                attn_output = attn_output.view(token_nums, -1)
            else:
                if not self.use_alibi_attn:
                    rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
                attn_output = context_attention_unpadded(
                    q=query_states,
                    k=key_states,
                    v=value_states,
                    k_cache=k_cache,
                    v_cache=v_cache,
                    context_lengths=sequence_lengths,
                    block_tables=block_tables,
                    block_size=block_size,
                    output=output_tensor,
                    alibi_slopes=self.alibi_slopes,
                    max_seq_len=kv_seq_len,
                    sm_scale=sm_scale,
                )
        else:
            q_len = tokens_to_verify + 1 if is_verifier else 1

            if use_cuda_kernel:
                if not self.use_alibi_attn:
                    inference_ops.rotary_embedding_and_cache_copy(
                        query_states,
                        key_states,
                        value_states,
                        cos_sin[0],
                        cos_sin[1],
                        k_cache,
                        v_cache,
                        sequence_lengths,
                        block_tables,
                        high_precision,
                    )
                else:
                    inference_ops.decode_kv_cache_memcpy(
                        key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
                    )
            else:
                if not is_verifier and not self.use_alibi_attn:
                    decoding_fused_rotary_embedding(
                        query_states,
                        key_states,
                        value_states,
                        cos_sin[0],
                        cos_sin[1],
                        k_cache,
                        v_cache,
                        block_tables,
                        sequence_lengths,
                    )
                else:
                    if not self.use_alibi_attn:
                        rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
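                    # copy_k_to_blocked_cache writes one step of states into the blocked
                    # (paged) KV cache; the same kernel is reused for the value cache below.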
                    copy_k_to_blocked_cache(
                        key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
                    )
                    copy_k_to_blocked_cache(
                        value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
                    )

            attn_output = flash_decoding_attention(
                q=query_states,
                k_cache=k_cache,
                v_cache=v_cache,
                kv_seq_len=sequence_lengths,
                block_tables=block_tables,
                block_size=block_size,
                max_seq_len_in_batch=kv_seq_len,
                output=output_tensor,
                mid_output=fd_inter_tensor.mid_output,
                mid_output_lse=fd_inter_tensor.mid_output_lse,
                alibi_slopes=self.alibi_slopes,
                sm_scale=sm_scale,
                q_len=q_len,
            )

        attn_output = attn_output.view(-1, self.hidden_size)
        attn_output = torch.mm(attn_output, self.o_proj_weight)

        return attn_output


# NOTE: This will cause differences as the output length increases.
class NopadBaichuanMLP(nn.Module):