add context_attention_unpadded

This commit is contained in:
yuehuayingxueluo
2024-01-03 18:50:26 +08:00
committed by FrankLeeeee
parent 07b5283b6a
commit 02c1bf8b2a
5 changed files with 37 additions and 29 deletions

View File

@@ -5,6 +5,7 @@ import torch
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
from colossalai.inference.struct import BatchInfo
from colossalai.kernel.triton import context_attention_unpadded
def rotate_half(x):
@@ -53,7 +54,6 @@ def llama_causal_lm_forward(
v_caches=v_caches,
)
logits = self.lm_head(hidden_states)
return logits
@@ -157,15 +157,17 @@ def llama_attn_forward(
key_states = key_states.view(-1, self.num_heads, self.head_dim)
value_states = value_states.view(-1, self.num_heads, self.head_dim)
# TODO: The code below will be uncommented after the development of attention-related kernel is completed.
# memcpy_to_block(key_states, value_states, k_cache, v_cache, block_tables, block_size, sequence_lengths)
# if is_prompts:
# attn_output = context_attention_unpadded(query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
# else:
# attn_output = torch.empty(bsz, self.num_heads, self.head_dim)
# decoding_attention(query_states, k_cache, v_cache, block_tables, sequence_lengths, attn_output, block_tables.shape[1], block_size)
_, _, _, block_size = k_cache.shape
# NOTE: context_attention_unpadded is unsed for testing accuracy and we can only use aligned inputs.
# The code below will be uncommented after the development of attention-related kernel is completed.
if is_prompts:
attn_output = context_attention_unpadded(
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
)
# else:
# attn_output = context_attention_unpadded(query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
attn_output = query_states
attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)