Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 19:13:01 +00:00)
[Inference] Support FP16/BF16 Flash Attention 2 and add high_precision flag to rotary embedding (#5461)

* Support FP16/BF16 Flash Attention 2
* Fix bugs in test_kv_cache_memcpy.py
* Add context_kv_cache_memcpy_kernel.cu
* Remove typename MT
* Add tail process
* Add high_precision
* Add high_precision to config.py
* Remove unused code
* Change the comment for the high_precision parameter
* Update test_rotary_embdding_unpad.py
* Fix vector_copy_utils.h
* Add comment for self.high_precision when using float32
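Before the diff itself, here is a minimal PyTorch sketch of what the new `high_precision` flag is meant to control: when queries/keys are FP16 or BF16, the rotary math is carried out in FP32 and the result is cast back. This is an illustration only (it assumes the common rotate-half layout and made-up shapes), not the CUDA kernel added by this commit:

```python
import torch


def rotary_reference(
    x: torch.Tensor,      # [num_tokens, num_heads, head_dim]
    cos: torch.Tensor,    # [num_tokens, head_dim // 2]
    sin: torch.Tensor,    # [num_tokens, head_dim // 2]
    high_precision: bool = False,
) -> torch.Tensor:
    """Rotate channel pairs; optionally run the math in FP32 for FP16/BF16 inputs."""
    out_dtype = x.dtype
    if high_precision and out_dtype in (torch.float16, torch.bfloat16):
        x, cos, sin = x.float(), cos.float(), sin.float()
    x1, x2 = x.chunk(2, dim=-1)        # split head_dim into two halves
    cos = cos.unsqueeze(1)             # broadcast over the head dimension
    sin = sin.unsqueeze(1)
    y1 = x1 * cos - x2 * sin
    y2 = x2 * cos + x1 * sin
    return torch.cat((y1, y2), dim=-1).to(out_dtype)


q = torch.randn(16, 32, 128, dtype=torch.float16)
cos = torch.randn(16, 64, dtype=torch.float16)
sin = torch.randn(16, 64, dtype=torch.float16)
half = rotary_reference(q, cos, sin, high_precision=False)
fp32 = rotary_reference(q, cos, sin, high_precision=True)
print((half - fp32).abs().max())       # small but non-zero rounding difference
```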
@@ -2,6 +2,7 @@
 from typing import List, Optional, Tuple

 import torch
+import torch.nn.functional as F
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaConfig,
@@ -30,24 +31,28 @@ inference_ops = InferenceOpsLoader().load()
 logger = get_dist_logger(__name__)

 try:
-    HAS_TRITON = True
+    from flash_attn import flash_attn_varlen_func
+
+    use_flash_attn2 = True
 except ImportError:
-    HAS_TRITON = False
-    logger.warning(f"triton has not been installed yet, we will use torch to complete the attention calculation.")
+    use_flash_attn2 = False
+    logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")


 def llama_causal_lm_forward(
     self: LlamaForCausalLM,
-    batch: BatchBucket = None,
-    k_caches: List[torch.Tensor] = None,
-    v_caches: List[torch.Tensor] = None,
+    batch: BatchBucket,
+    k_caches: List[torch.Tensor],
+    v_caches: List[torch.Tensor],
+    high_precision: bool = False,
 ):
     """This function will replace the forward function of LlamaForCausalLM.

     Args:
-        batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None.
-        k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
-        v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
+        batch (BatchInfo): It stores the necessary input information for this inference.
+        k_caches (List[torch.Tensor]): It holds the GPU memory for the key cache.
+        v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache.
+        high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
     """

     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
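The guarded import at the top of the hunk above replaces the old HAS_TRITON flag with a use_flash_attn2 capability flag. A standalone sketch of how such a flag combines with the dtype check used later in this diff; `can_use_flash_attn2` is a hypothetical helper written for illustration, not part of the commit:

```python
import torch

try:
    from flash_attn import flash_attn_varlen_func  # noqa: F401

    use_flash_attn2 = True
except ImportError:
    use_flash_attn2 = False


def can_use_flash_attn2(dtype: torch.dtype, use_cuda_kernel: bool) -> bool:
    # Same gating as the later hunks: flash-attn 2 handles FP16/BF16 only,
    # so FP32 batches stay on the Triton/CUDA fallback path.
    return use_cuda_kernel and dtype != torch.float32 and use_flash_attn2


print(can_use_flash_attn2(torch.float16, use_cuda_kernel=True))  # True if flash-attn is installed
print(can_use_flash_attn2(torch.float32, use_cuda_kernel=True))  # False
```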
@@ -56,6 +61,7 @@ def llama_causal_lm_forward(
         batch=batch,
         k_caches=k_caches,
         v_caches=v_caches,
+        high_precision=high_precision,
     )
     logits = torch.mm(hidden_states, self.lm_head.weight)
     return logits
@@ -63,16 +69,18 @@ def llama_causal_lm_forward(

 def llama_model_forward(
     self: LlamaModel,
-    batch: BatchBucket = None,
-    k_caches: List[torch.Tensor] = None,
-    v_caches: List[torch.Tensor] = None,
+    batch: BatchBucket,
+    k_caches: List[torch.Tensor],
+    v_caches: List[torch.Tensor],
+    high_precision: bool = False,
 ):
     """This function will replace the forward function of LlamaModel.

     Args:
-        batch (BatchInfo, optional): It stores the necessary input information for this inference.. Defaults to None.
-        k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
-        v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
+        batch (BatchInfo): It stores the necessary input information for this inference.
+        k_caches (List[torch.Tensor]): It holds the GPU memory for the key cache.
+        v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache.
+        high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
     """
     input_ids = batch.get_1D_inputs()
     block_tables = batch.get_block_table_tensor()
@@ -86,6 +94,11 @@ def llama_model_forward(
     if batch_size >= 32 and kv_seq_len > 512:
         use_cuda_kernel = False

+    if use_cuda_kernel and batch.dtype != torch.float32 and use_flash_attn2:
+        cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
+    else:
+        cu_seqlens = None
+
     hidden_states = self.embed_tokens(input_ids)

     cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts)
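flash_attn_varlen_func works on packed (unpadded) tokens, so the FP16/BF16 path above precomputes cu_seqlens, the cumulative start offset of each sequence. A small standalone illustration of that construction:

```python
import torch
import torch.nn.functional as F

sequence_lengths = torch.tensor([3, 5, 2], dtype=torch.int32)

# Prepend a zero to the running sum, giving each sequence's start offset
# inside the packed token buffer (same construction as the hunk above).
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
```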
@@ -110,15 +123,17 @@ def llama_model_forward(
             block_tables=block_tables,
             k_cache=k_caches[layer_id],
             v_cache=v_caches[layer_id],
-            is_prompts=batch.is_prompts,
             sequence_lengths=sequence_lengths,
-            kv_seq_len=kv_seq_len,
             cos_sin=cos_sin,
             fd_inter_tensor=batch.fd_inter_tensor,
+            is_prompts=batch.is_prompts,
+            kv_seq_len=kv_seq_len,
             output_tensor=output_tensor,
             norm_output=norm_output,
             sm_scale=sm_scale,
             use_cuda_kernel=use_cuda_kernel,
+            cu_seqlens=cu_seqlens,
+            high_precision=high_precision,
         )

     if batch.is_prompts:
@@ -135,38 +150,42 @@ def llama_decoder_layer_forward(
     self: LlamaDecoderLayer,
     hidden_states: torch.Tensor,
     residual: torch.Tensor,
-    block_tables: torch.Tensor = None,
-    k_cache: torch.Tensor = None,
-    v_cache: torch.Tensor = None,
+    block_tables: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    sequence_lengths: torch.Tensor,
+    cos_sin: Tuple[torch.Tensor],
+    fd_inter_tensor: FDIntermTensors,
     is_prompts: bool = True,
-    sequence_lengths: torch.Tensor = None,
     kv_seq_len: int = 0,
-    cos_sin: Tuple[torch.Tensor] = None,
-    fd_inter_tensor: FDIntermTensors = None,
     output_tensor: torch.Tensor = None,
     norm_output: torch.Tensor = None,
     sm_scale: int = None,
     use_cuda_kernel: bool = True,
+    cu_seqlens: torch.Tensor = None,
+    high_precision: bool = False,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     """This function will replace the forward function of LlamaDecoderLayer.

     Args:
         hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
         residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj.
-        block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
-            storing mapping of token_position_id -> block_id. Defaults to None.
-        k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
-        v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
+        block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
+            storing mapping of token_position_id -> block_id.
+        k_cache (torch.Tensor): It holds the GPU memory for the key cache.
+        v_cache (torch.Tensor): It holds the GPU memory for the key cache.
+        sequence_lengths (torch.Tensor): Holding the sequence length of each sequence.
+        cos_sin (Tuple[torch.Tensor]): Holding cos and sin.
+        fd_inter_tensor (FDIntermTensors): Holding tensors used for
+            storing intermediate values in flash-decoding.
         is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
-        sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
         kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
-        cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
-        fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
-            storing intermediate values in flash-decoding. Defaults to None.
         output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
         norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
         sm_scale (int, optional): Used for flash attention. Defaults to None.
         use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
+        cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length.
+        high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
     """

     hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
@@ -176,14 +195,16 @@ def llama_decoder_layer_forward(
         block_tables=block_tables,
         k_cache=k_cache,
         v_cache=v_cache,
-        is_prompts=is_prompts,
         sequence_lengths=sequence_lengths,
-        kv_seq_len=kv_seq_len,
         cos_sin=cos_sin,
         fd_inter_tensor=fd_inter_tensor,
+        is_prompts=is_prompts,
+        kv_seq_len=kv_seq_len,
         output_tensor=output_tensor,
         sm_scale=sm_scale,
         use_cuda_kernel=use_cuda_kernel,
+        cu_seqlens=cu_seqlens,
+        high_precision=high_precision,
     )

     # Fully Connected
@@ -277,43 +298,48 @@ class NopadLlamaAttention(LlamaAttention):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        block_tables: torch.Tensor = None,
-        k_cache: torch.Tensor = None,
-        v_cache: torch.Tensor = None,
+        block_tables: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        sequence_lengths: torch.Tensor,
+        cos_sin: Tuple[torch.Tensor],
+        fd_inter_tensor: FDIntermTensors,
         is_prompts: bool = True,
-        sequence_lengths: torch.Tensor = None,
         kv_seq_len: int = 0,
-        cos_sin: Tuple[torch.Tensor] = None,
-        fd_inter_tensor: FDIntermTensors = None,
         output_tensor: torch.Tensor = None,
         sm_scale: int = None,
         use_cuda_kernel: bool = True,
+        cu_seqlens: torch.Tensor = None,
+        high_precision: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """
         Args:
            hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
-            block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
-                storing mapping of token_position_id -> block_id. Defaults to None.
-            k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
-            v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
-            is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
-            sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
-            kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
-            cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
+            block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
+                storing mapping of token_position_id -> block_id.
+            k_cache (torch.Tensor): It holds the GPU memory for the key cache.
+            v_cache (torch.Tensor): It holds the GPU memory for the key cache.
+            sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence.
+            cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin.
             fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
-                storing intermediate values in flash-decoding. Defaults to None.
+                storing intermediate values in flash-decoding.
+            is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
+            kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
             output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
             sm_scale (int, optional): Used for flash attention. Defaults to None.
             use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
+            cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length.
+            high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
         """

+        token_nums = hidden_states.size(0)
+
         if self.num_heads != self.num_key_value_heads:
             query_states = torch.mm(hidden_states, self.q_proj_weight).view(-1, self.num_heads, self.head_dim)
             key_states = torch.mm(hidden_states, self.k_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
             value_states = torch.mm(hidden_states, self.v_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
         else:
             # fused qkv
-            token_nums = hidden_states.size(0)
             hidden_states = hidden_states.expand(3, -1, -1)
             query_states, key_states, value_states = (
                 torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
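The fused-QKV branch at the end of the hunk above broadcasts the activations and runs one batched matmul against a stacked [3, hidden_size, num_heads * head_dim] weight. A quick sketch with random weights (an illustration assuming equal query and key/value head counts, as in that branch) showing it matches three separate matmuls:

```python
import torch

token_nums, hidden_size, num_heads, head_dim = 6, 64, 4, 16

hidden_states = torch.randn(token_nums, hidden_size)
# Stacked projection weights: one [hidden_size, num_heads * head_dim] matrix each for Q, K, V.
qkv_weight = torch.randn(3, hidden_size, num_heads * head_dim)

# Fused path, as in the hunk above: broadcast the activations, one batched matmul.
q, k, v = (
    torch.bmm(hidden_states.expand(3, -1, -1), qkv_weight)
    .view(3, token_nums, num_heads, head_dim)
    .unbind(0)
)

# Unfused reference: a separate matmul per projection gives the same result.
q_ref = torch.mm(hidden_states, qkv_weight[0]).view(token_nums, num_heads, head_dim)
print(torch.allclose(q, q_ref, atol=1e-5))  # True
```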
@@ -322,23 +348,41 @@ class NopadLlamaAttention(LlamaAttention):
         block_size = k_cache.size(-2)

         if is_prompts:
-            if use_cuda_kernel:
-                inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+            if use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
+                # flash attn 2 currently only supports FP16/BF16.
+                inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
+                inference_ops.context_kv_cache_memcpy(
+                    key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
+                )
+
+                attn_output = flash_attn_varlen_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    cu_seqlens_q=cu_seqlens,
+                    cu_seqlens_k=cu_seqlens,
+                    max_seqlen_q=kv_seq_len,
+                    max_seqlen_k=kv_seq_len,
+                    dropout_p=0.0,
+                    softmax_scale=sm_scale,
+                    causal=True,
+                )
+                attn_output = attn_output.view(token_nums, -1)
             else:
                 rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
-            attn_output = context_attention_unpadded(
-                q=query_states,
-                k=key_states,
-                v=value_states,
-                k_cache=k_cache,
-                v_cache=v_cache,
-                context_lengths=sequence_lengths,
-                block_tables=block_tables,
-                block_size=block_size,
-                output=output_tensor,
-                max_seq_len=kv_seq_len,
-                sm_scale=sm_scale,
-            )
+                attn_output = context_attention_unpadded(
+                    q=query_states,
+                    k=key_states,
+                    v=value_states,
+                    k_cache=k_cache,
+                    v_cache=v_cache,
+                    context_lengths=sequence_lengths,
+                    block_tables=block_tables,
+                    block_size=block_size,
+                    output=output_tensor,
+                    max_seq_len=kv_seq_len,
+                    sm_scale=sm_scale,
+                )
         else:
             if use_cuda_kernel:
                 inference_ops.rotary_embedding_and_cache_copy(
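For reference, a minimal, self-contained call of flash_attn_varlen_func with random packed FP16 data, guarded so it only runs when flash-attn and a GPU are available. The keyword arguments mirror the call in the hunk above; the shapes and the 1/sqrt(head_dim) softmax scale are assumptions made here only for illustration:

```python
import torch
import torch.nn.functional as F

try:
    from flash_attn import flash_attn_varlen_func
except ImportError:
    flash_attn_varlen_func = None

if flash_attn_varlen_func is not None and torch.cuda.is_available():
    num_heads, head_dim = 8, 64
    seq_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
    cu_seqlens = F.pad(torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0))
    total_tokens = int(seq_lens.sum())   # 8 packed tokens, no padding
    max_len = int(seq_lens.max())

    # flash-attn 2 only accepts FP16/BF16 inputs, which is why the hunk
    # above gates this path on the batch dtype.
    q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    out = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_len,
        max_seqlen_k=max_len,
        dropout_p=0.0,
        softmax_scale=head_dim**-0.5,
        causal=True,
    )
    print(out.shape)  # torch.Size([8, 8, 64])
```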
@@ -351,6 +395,7 @@ class NopadLlamaAttention(LlamaAttention):
                     v_cache,
                     sequence_lengths,
                     block_tables,
+                    high_precision,
                 )
             else:
                 decoding_fused_rotary_embedding(
@@ -436,6 +481,5 @@ class NopadLlamaMLP(LlamaMLP):
         """
         hidden_states = hidden_states.expand(2, -1, -1)
         gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
-        act_out = torch.nn.functional.silu(gate_up_proj_out[0], inplace=True)
-        tmp_out = act_out * gate_up_proj_out[1]
-        return torch.mm(tmp_out, self.down_proj_weight)
+        act_out = inference_ops.silu_and_mul(gate_up_proj_out)
+        return torch.mm(act_out, self.down_proj_weight)
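The MLP hunk swaps the separate SiLU-and-multiply for the fused inference_ops.silu_and_mul kernel. Based on the removed lines, a plain PyTorch reference of what that fused op is expected to compute looks like this (shapes are made up for illustration):

```python
import torch
import torch.nn.functional as F


def silu_and_mul_reference(gate_up: torch.Tensor) -> torch.Tensor:
    # gate_up packs [gate, up] along dim 0, like gate_up_proj_out in the hunk above.
    return F.silu(gate_up[0]) * gate_up[1]


gate_up = torch.randn(2, 4, 8)                  # [2, num_tokens, intermediate_size]
print(silu_and_mul_reference(gate_up).shape)    # torch.Size([4, 8])
```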