mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 13:00:52 +00:00
[inference] Optimize the usage of the mid tensors space in flash attn (#5304)
* opt flash attn
* opt tmp tensor
* fix benchmark_llama
* fix code style
* fix None logic for output tensor
* fix adapted to get_xine_cache
* add comment
* fix ci bugs
* fix some codes
* rm duplicated codes
* rm duplicated codes
* fix code style
* add _get_dtype in config.py
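The change threads three pieces of per-batch state from llama_model_forward through each decoder layer into llama_attn_forward: a preallocated output_tensor, the softmax scale sm_scale, and batch.fd_inter_tensor, so the Triton kernels write into buffers that are allocated once and reused instead of creating intermediate tensors on every call. A minimal sketch of the idea, with illustrative shapes (the diff only shows that FDIntermTensors exposes mid_output and mid_output_lse; the constructor, the kv_split_num parameter, and the exact layout below are assumptions):

import torch

# Illustrative sketch only; not the actual FDIntermTensors implementation.
class PreallocatedFlashDecodingBuffers:
    def __init__(self, max_batch_size, num_heads, head_dim, kv_split_num, dtype, device):
        # Partial attention outputs, one slot per KV split (the "mid tensors").
        self.mid_output = torch.empty(
            (max_batch_size, num_heads, kv_split_num, head_dim), dtype=dtype, device=device
        )
        # Log-sum-exp of each split, needed to merge the partial outputs.
        self.mid_output_lse = torch.empty(
            (max_batch_size, num_heads, kv_split_num), dtype=dtype, device=device
        )

# Allocated once per batch and handed to every layer and every decoding step,
# rather than re-created inside each flash-decoding kernel launch.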
@@ -4,6 +4,7 @@ from typing import List, Optional, Tuple
 import torch
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
 
+from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.modeling.layers.attention import PagedAttention
 from colossalai.inference.struct import BatchInfo
 from colossalai.kernel.triton import (
@@ -50,7 +51,6 @@ def llama_causal_lm_forward(
     batch: BatchInfo = None,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
-    padding_id: int = None,
 ):
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     hidden_states = llama_model_forward(
@@ -58,7 +58,6 @@ def llama_causal_lm_forward(
         batch=batch,
         k_caches=k_caches,
         v_caches=v_caches,
-        padding_id=padding_id,
     )
     logits = self.lm_head(hidden_states)
     return logits
@@ -70,11 +69,10 @@ def llama_model_forward(
     batch: BatchInfo = None,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
-    padding_id: int = None,
 ):
     input_ids = batch.get_batch_inputs()
     block_tables = batch.get_block_table_tensor()
-    attention_mask = batch.get_attn_mask(padding_id)
+    attention_mask = batch.get_attn_mask()
 
     if attention_mask is not None:
         if HAS_TRITON:
@@ -84,6 +82,7 @@ def llama_model_forward(
     else:
         sequence_lengths = batch.get_sequence_lengths()
 
+    batch_size, _ = input_ids.shape
     kv_seq_len = sequence_lengths.max().item()
 
     if attention_mask is not None:
@@ -102,7 +101,22 @@ def llama_model_forward(
 
     hidden_states = self.embed_tokens(input_ids)
 
-    cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, hidden_states.dtype)
+    # When testing, the performance of get_xine_cache is lower than that of get_cos_sin.
+    # cos = get_xine_cache(sequence_lengths, self._cos_cached, batch.is_prompts)
+    # sin = get_xine_cache(sequence_lengths, self._sin_cached, batch.is_prompts)
+    # cos_sin = (cos, sin)
+
+    cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, batch.dtype)
+
+    if batch.is_prompts:
+        output_tensor = torch.zeros(
+            (sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
+        )
+    else:
+        output_tensor = torch.zeros(
+            (batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
+        )
+    sm_scale = 1.0 / (batch.head_dim**0.5)
 
     for layer_id, decoder_layer in enumerate(self.layers):
         hidden_states = decoder_layer(
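The two allocation branches above differ only in shape: prefill packs all prompt tokens of the batch contiguously (no padding), so the output needs sum(sequence_lengths) rows, while decode produces exactly one new token per sequence. A small worked example with assumed sizes (num_heads=32 and head_dim=128 are placeholders):

import torch

# Assumed example values; only the shape logic mirrors the hunk above.
sequence_lengths = torch.tensor([5, 9, 3])   # three sequences in the batch
num_heads, head_dim = 32, 128
batch_size = sequence_lengths.numel()

# Prefill: 5 + 9 + 3 = 17 packed token rows for the whole batch.
prefill_output = torch.zeros((sequence_lengths.sum().item(), num_heads, head_dim))

# Decode: one new token per sequence.
decode_output = torch.zeros((batch_size, 1, num_heads, head_dim))

sm_scale = 1.0 / (head_dim**0.5)   # standard 1/sqrt(head_dim) softmax scaling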
@@ -116,6 +130,9 @@ def llama_model_forward(
             attention_mask=attention_mask,
             kv_seq_len=kv_seq_len,
             cos_sin=cos_sin,
+            fd_inter_tensor=batch.fd_inter_tensor,
+            output_tensor=output_tensor,
+            sm_scale=sm_scale,
         )
 
     hidden_states = self.norm(hidden_states)
@@ -131,10 +148,13 @@ def llama_decoder_layer_forward(
     k_cache: torch.Tensor = None,
     v_cache: torch.Tensor = None,
     is_prompts: bool = True,
-    sequence_lengths: int = None,
+    sequence_lengths: torch.Tensor = None,
     attention_mask: torch.Tensor = None,
     kv_seq_len: int = 0,
     cos_sin: Tuple[torch.Tensor] = None,
+    fd_inter_tensor: FDIntermTensors = None,
+    output_tensor: torch.Tensor = None,
+    sm_scale: int = None,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     residual = hidden_states
 
@@ -151,6 +171,9 @@ def llama_decoder_layer_forward(
         attention_mask=attention_mask,
         kv_seq_len=kv_seq_len,
         cos_sin=cos_sin,
+        fd_inter_tensor=fd_inter_tensor,
+        output_tensor=output_tensor,
+        sm_scale=sm_scale,
     )
 
     hidden_states = residual + hidden_states
@@ -178,6 +201,9 @@ def llama_attn_forward(
     attention_mask: torch.Tensor = None,
     kv_seq_len: int = 0,
     cos_sin: Tuple[torch.Tensor] = None,
+    fd_inter_tensor: FDIntermTensors = None,
+    output_tensor: torch.Tensor = None,
+    sm_scale: int = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
 
@@ -206,7 +232,17 @@ def llama_attn_forward(
 
         if is_prompts:
             attn_output = context_attention_unpadded(
-                query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
+                q=query_states,
+                k=key_states,
+                v=value_states,
+                k_cache=k_cache,
+                v_cache=v_cache,
+                context_lengths=sequence_lengths,
+                block_tables=block_tables,
+                block_size=block_size,
+                output=output_tensor,
+                max_seq_len=kv_seq_len,
+                sm_scale=sm_scale,
             )
             if attention_mask is not None:
                 attn_output = pad_input(attn_output, indices, bsz, q_len)
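Prefill runs attention on the unpadded (packed) tokens and only re-pads at the end when an attention mask was supplied; the indices produced during unpadding record where each real token sits in the padded layout. A self-contained sketch of that pack / re-pad round trip in plain torch (the variable names are illustrative, not the repository's unpading_input / pad_input helpers):

import torch

# Toy batch: 2 sequences padded to length 4; 1 marks real tokens.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
hidden = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)

# Pack: keep only real tokens, remember their flat positions.
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
packed = hidden.reshape(-1, 3).index_select(0, indices)      # shape (5, 3)

# ... the attention kernel works on the packed tokens ...

# Re-pad: scatter the packed rows back to their original positions.
repadded = torch.zeros(2 * 4, 3)
repadded[indices] = packed
repadded = repadded.reshape(2, 4, 3)
assert torch.equal(repadded * attention_mask.unsqueeze(-1), repadded)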
@@ -214,7 +250,17 @@ def llama_attn_forward(
             copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
             copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
             attn_output = flash_decoding_attention(
-                query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
+                q=query_states,
+                k_cache=k_cache,
+                v_cache=v_cache,
+                kv_seq_len=sequence_lengths,
+                block_tables=block_tables,
+                block_size=block_size,
+                max_seq_len_in_batch=kv_seq_len,
+                output=output_tensor,
+                mid_output=fd_inter_tensor.mid_output,
+                mid_output_lse=fd_inter_tensor.mid_output_lse,
+                sm_scale=sm_scale,
             )
             attn_output = attn_output.squeeze(1)
     else:
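flash_decoding_attention splits each sequence's KV cache into chunks, computes a partial attention output plus a log-sum-exp per chunk into mid_output and mid_output_lse, and then merges the chunks; those two buffers are the "mid tensors" the commit preallocates and reuses. A single-head torch reference of the merge math, under assumed sizes (the real kernel operates on the blocked cache; this only shows the reduction idea):

import torch

torch.manual_seed(0)
d, n, chunk = 64, 256, 64                 # head_dim, kv length, split size
q = torch.randn(d)
K, V = torch.randn(n, d), torch.randn(n, d)
scale = d ** -0.5

# Phase 1: per-chunk partial attention + log-sum-exp (the "mid tensors").
mid_output, mid_output_lse = [], []
for start in range(0, n, chunk):
    scores = (K[start:start + chunk] @ q) * scale            # (chunk,)
    mid_output.append(torch.softmax(scores, dim=0) @ V[start:start + chunk])
    mid_output_lse.append(torch.logsumexp(scores, dim=0))
mid_output = torch.stack(mid_output)                          # (num_chunks, d)
mid_output_lse = torch.stack(mid_output_lse)                  # (num_chunks,)

# Phase 2: merge the chunks using their log-sum-exp weights.
weights = torch.softmax(mid_output_lse, dim=0)                # exp(lse_c - lse_total)
out = (weights.unsqueeze(-1) * mid_output).sum(dim=0)

ref = torch.softmax((K @ q) * scale, dim=0) @ V
assert torch.allclose(out, ref, atol=1e-4)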
@@ -285,6 +331,16 @@ def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_
 
+@torch.no_grad()
+def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
+    """
+    Get cos and sin for the cache, and return nopad format.
+    Args:
+        lengths: shape(num_seqs,), stores length of each sequence.
+        cos_cache: shape(max_rotary_position(e.g. 2048), head_dim), cos cache constructed in model.
+        sin_cache: shape(max_rotary_position(e.g. 2048), head_dim), sin cache constructed in model.
+        is_prompts: bool, mark if in prefill mode.
+        dtype: The data type of this inference process.
+    """
+
+    if is_prompts:
+        index_arrays = [torch.arange(length) for length in lengths]
+    else:
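The decode branch of get_cos_sin is cut off above. For orientation, a plausible stand-in (an illustration, not the repository's implementation): prefill gathers cos/sin for every position 0..len-1 of each sequence and concatenates them in nopad form, while decode only needs the current position (len - 1) of each sequence.

import torch

# Illustrative stand-in for get_cos_sin; not the repository code.
def get_cos_sin_sketch(lengths, cos_cache, sin_cache, is_prompts, dtype):
    if is_prompts:
        # Prefill: positions 0..len-1 for every sequence, concatenated (nopad).
        index = torch.cat([torch.arange(length) for length in lengths.tolist()])
    else:
        # Decode: only the current position (len - 1) of each sequence.
        index = lengths - 1
    return cos_cache[index].to(dtype), sin_cache[index].to(dtype)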