Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 12:30:42 +00:00
[Inference]Adapted to the triton attn kernels (#5264)
* adapted to the triton attn kernels
* fix pad input
* adapted to copy_kv_to_blocked_cache
* fix ci test
* update kv memcpy
* remove print
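The prefill branch in the diff below strips padding before calling the triton context attention kernel and restores it afterwards. The following standalone sketch (plain PyTorch only, with hypothetical helper names unpad/repad; it is not the flash_attn or ColossalAI API) illustrates the bookkeeping that unpading_input and pad_input perform around that kernel call:

import torch

def unpad(x, mask):
    # Pack (bsz, seq_len, heads, head_dim) into (num_tokens, heads, head_dim),
    # keeping only the positions where the attention mask is 1.
    bsz, seq_len, heads, head_dim = x.shape
    indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
    return x.reshape(bsz * seq_len, heads, head_dim)[indices], indices

def repad(x, indices, bsz, seq_len):
    # Scatter the per-token outputs back into a zero-padded (bsz, seq_len, ...) tensor.
    heads, head_dim = x.shape[1], x.shape[2]
    out = x.new_zeros(bsz * seq_len, heads, head_dim)
    out[indices] = x
    return out.view(bsz, seq_len, heads, head_dim)

bsz, seq_len, heads, head_dim = 2, 4, 8, 16
mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 1]])  # prompts of length 2 and 4
q = torch.randn(bsz, seq_len, heads, head_dim)

q_packed, indices = unpad(q, mask)  # shape (6, 8, 16): only the 6 real tokens
# ... the attention kernel would consume the packed tokens here ...
restored = repad(q_packed, indices, bsz, seq_len)
assert torch.equal(restored[mask.bool()], q[mask.bool()])

The point of the layout change is that the triton kernel works on a flat token dimension, so padded positions are dropped on the way in and reinserted on the way out.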
@@ -2,19 +2,23 @@
from typing import List, Optional, Tuple

import torch
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaModel,
    repeat_kv,
)
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel

from colossalai.inference.modeling.layers.attention import PagedAttention
from colossalai.inference.struct import BatchInfo
from colossalai.kernel.triton import context_attention_unpadded, copy_kv_to_blocked_cache, flash_decoding_fwd
from colossalai.logging import get_dist_logger

from flash_attn.bert_padding import index_first_axis, pad_input  # noqa

logger = get_dist_logger(__name__)

try:
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    logger.warning(f"triton has not been installed yet, we will use torch to complete the attention calculation.")


def rotate_half(x):
    """Rotates half the hidden dims of the input."""

@@ -35,6 +39,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    return q_embed, k_embed


@torch.no_grad()
def llama_causal_lm_forward(
    self: LlamaForCausalLM,
    batch: BatchInfo = None,

@@ -54,6 +59,7 @@ def llama_causal_lm_forward(
    return logits


@torch.no_grad()
def llama_model_forward(
    self: LlamaModel,
    batch: BatchInfo = None,

@@ -63,15 +69,30 @@ def llama_model_forward(
):
    input_ids = batch.get_batch_inputs()
    block_tables = batch.get_block_table_tensor()
    sequence_lengths = batch.get_sequence_lengths()

    attention_mask = batch.get_attn_mask(padding_id)

    if batch.is_prompts:
        # Here, we generate position_ids through the input tensor, which can align with the output precision of the transformer.
        position_ids = generate_padding_position_id(attention_mask)
        if attention_mask is not None:
            # TODO After the nopad version is implemented, we will use the following code to get sequence_lengths.
            # sequence_lengths = batch.get_sequence_lengths()
            sequence_lengths = attention_mask.sum(dim=-1, dtype=torch.int32)
    else:
        position_ids = (attention_mask.sum(dim=-1) - 1).reshape(-1, 1)
        sequence_lengths = batch.get_sequence_lengths()

    kv_seq_len = sequence_lengths.max().item()

    if attention_mask is not None:
        if batch.is_prompts:
            # Here, we generate position_ids through the input tensor, which can align with the output precision of the transformer.
            position_ids = generate_padding_position_id(attention_mask)
        else:
            position_ids = (attention_mask.sum(dim=-1) - 1).reshape(-1, 1)
    else:
        if batch.is_prompts:
            position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=batch.device)
            position_ids = position_ids.unsqueeze(0)
        else:
            position_ids = torch.arange(kv_seq_len - 1, kv_seq_len, dtype=torch.long, device=batch.device)
            position_ids = position_ids.unsqueeze(0)

    hidden_states = self.embed_tokens(input_ids)

@@ -85,13 +106,14 @@ def llama_model_forward(
            is_prompts=batch.is_prompts,
            sequence_lengths=sequence_lengths,
            attention_mask=attention_mask,
            kv_seq_len=kv_seq_len,
        )

    hidden_states = self.norm(hidden_states)

    return hidden_states


@torch.no_grad()
def llama_decoder_layer_forward(
    self: LlamaDecoderLayer,
    hidden_states: torch.Tensor,

@@ -102,6 +124,7 @@ def llama_decoder_layer_forward(
    is_prompts: bool = True,
    sequence_lengths: int = None,
    attention_mask: torch.Tensor = None,
    kv_seq_len: int = 0,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    residual = hidden_states

@@ -116,6 +139,7 @@ def llama_decoder_layer_forward(
        is_prompts=is_prompts,
        sequence_lengths=sequence_lengths,
        attention_mask=attention_mask,
        kv_seq_len=kv_seq_len,
    )

    hidden_states = residual + hidden_states

@@ -130,6 +154,7 @@ def llama_decoder_layer_forward(


# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
@torch.no_grad()
def llama_attn_forward(
    self: LlamaAttention,
    hidden_states: torch.Tensor,

@@ -140,6 +165,7 @@ def llama_attn_forward(
    is_prompts: bool = True,
    sequence_lengths: torch.Tensor = None,
    attention_mask: torch.Tensor = None,
    kv_seq_len: int = 0,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

@@ -147,26 +173,44 @@ def llama_attn_forward(
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = sequence_lengths[0].item()

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    _, _, _, block_size = k_cache.shape

    if is_prompts:
        attn_output = PagedAttention.pad_context_forward(
            query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
        )
        if HAS_TRITON:
            if attention_mask is not None:
                query_states, key_states, value_states, indices = unpading_input(
                    query_states, key_states, value_states, attention_mask
                )
            else:
                query_states = query_states.view(-1, self.num_heads, self.head_dim)
                key_states = key_states.view(-1, self.num_heads, self.head_dim)
                value_states = value_states.view(-1, self.num_heads, self.head_dim)

            attn_output = context_attention_unpadded(
                query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
            )
            if attention_mask is not None:
                attn_output = pad_input(attn_output, indices, bsz, q_len)
        else:
            attn_output = PagedAttention.pad_context_forward(
                query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
            )
    else:
        attn_output = PagedAttention.pad_decoding_forward(
            query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
        )
        if HAS_TRITON:
            copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
            copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
            attn_output = flash_decoding_fwd(query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
        else:
            attn_output = PagedAttention.pad_decoding_forward(
                query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
            )

    attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

@@ -175,7 +219,18 @@ def llama_attn_forward(
    return attn_output


@torch.no_grad()
def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor:
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    return position_ids


@torch.no_grad()
def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_mask: torch.Tensor):
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    batch_size, kv_seq_len, num_key_value_heads, head_dim = q.shape
    q = index_first_axis(q.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
    k = index_first_axis(k.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
    v = index_first_axis(v.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
    return (q, k, v, indices)
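For reference, a small sanity check of what generate_padding_position_id yields for a left-padded batch; the helper is re-declared verbatim from the diff so the snippet runs on its own, and the printed values follow directly from the cumsum/masked_fill_ logic:

import torch

def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor:
    # Re-declared from the diff above so this snippet is self-contained.
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    return position_ids

mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded prompt: 2 pad tokens, 3 real tokens
                     [1, 1, 1, 1, 1]])  # unpadded prompt
print(generate_padding_position_id(mask))
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])

Real token positions count up from 0 within each sequence, while pad positions get a dummy value of 1, which is masked out of the attention computation anyway.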