[inference] Adapted to Rotary Embedding and RMS Norm (#5283)

* adapted to rotary_embedding

* adapted to nopad rms norm

* fix bugs in benchmark

* fix flash_decoding.py
Author: yuehuayingxueluo
Date: 2024-01-22 10:55:34 +08:00
Committed by: GitHub
parent 6e487e7d3c
commit bfff9254ac
5 changed files with 140 additions and 43 deletions


@@ -6,7 +6,12 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecode
 from colossalai.inference.modeling.layers.attention import PagedAttention
 from colossalai.inference.struct import BatchInfo
-from colossalai.kernel.triton import context_attention_unpadded, copy_kv_to_blocked_cache, flash_decoding_attention
+from colossalai.kernel.triton import (
+    context_attention_unpadded,
+    copy_kv_to_blocked_cache,
+    flash_decoding_attention,
+    rotary_embedding,
+)
 from colossalai.logging import get_dist_logger
 from flash_attn.bert_padding import index_first_axis, pad_input  # noqa
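
For reference, the newly imported `rotary_embedding` Triton kernel applies the rotary position update to the query and key tensors in place (the call site later in this diff assigns no return value). Below is a minimal PyTorch sketch of the rotate-half update it corresponds to; the toy shapes and the half-width cos/sin layout are illustrative assumptions, not the kernel's actual memory layout.

import torch

def rotary_embedding_ref(q, k, cos, sin):
    # q, k: [num_tokens, num_heads, head_dim]; cos, sin: [num_tokens, head_dim // 2]
    def rotate(x):
        d = x.shape[-1] // 2
        x1, x2 = x[..., :d], x[..., d:]
        c, s = cos[:, None, :], sin[:, None, :]  # broadcast over the head axis
        return torch.cat((x1 * c - x2 * s, x2 * c + x1 * s), dim=-1)
    return rotate(q), rotate(k)

q = torch.randn(5, 8, 64)   # 5 tokens, 8 heads, head_dim 64 (toy sizes)
k = torch.randn(5, 8, 64)
positions = torch.arange(5, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, 32, dtype=torch.float32) / 32))
freqs = torch.outer(positions, inv_freq)   # [5, 32], one row of angles per token
q_rot, k_rot = rotary_embedding_ref(q, k, freqs.cos(), freqs.sin())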
@@ -72,9 +77,10 @@ def llama_model_forward(
     attention_mask = batch.get_attn_mask(padding_id)
     if attention_mask is not None:
-        # TODO After the nopad version is implemented, we will use the following code to get sequence_lengths.
-        # sequence_lengths = batch.get_sequence_lengths()
-        sequence_lengths = attention_mask.sum(dim=-1, dtype=torch.int32)
+        if HAS_TRITON:
+            sequence_lengths = attention_mask.sum(dim=-1, dtype=torch.int32)
+        else:
+            sequence_lengths = batch.get_sequence_lengths()
     else:
         sequence_lengths = batch.get_sequence_lengths()
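
On the Triton path the per-sequence lengths are derived directly from the padding mask by summing over the sequence axis. A toy example (mask values invented for illustration):

import torch

# 1 marks a real token, 0 marks padding (toy batch of two sequences).
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 0, 0, 0]])
sequence_lengths = attention_mask.sum(dim=-1, dtype=torch.int32)
print(sequence_lengths)  # tensor([4, 3], dtype=torch.int32)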
@@ -96,6 +102,8 @@ def llama_model_forward(
     hidden_states = self.embed_tokens(input_ids)
+    cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, hidden_states.dtype)
     for layer_id, decoder_layer in enumerate(self.layers):
         hidden_states = decoder_layer(
             hidden_states,
@@ -107,6 +115,7 @@ def llama_model_forward(
             sequence_lengths=sequence_lengths,
             attention_mask=attention_mask,
             kv_seq_len=kv_seq_len,
+            cos_sin=cos_sin,
         )
     hidden_states = self.norm(hidden_states)
@@ -125,6 +134,7 @@ def llama_decoder_layer_forward(
     sequence_lengths: int = None,
     attention_mask: torch.Tensor = None,
     kv_seq_len: int = 0,
+    cos_sin: Tuple[torch.Tensor] = None,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     residual = hidden_states
@@ -140,6 +150,7 @@ def llama_decoder_layer_forward(
         sequence_lengths=sequence_lengths,
         attention_mask=attention_mask,
         kv_seq_len=kv_seq_len,
+        cos_sin=cos_sin,
     )
     hidden_states = residual + hidden_states
@@ -166,27 +177,16 @@ def llama_attn_forward(
     sequence_lengths: torch.Tensor = None,
     attention_mask: torch.Tensor = None,
     kv_seq_len: int = 0,
+    cos_sin: Tuple[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
-    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
-    kv_seq_len = max(sequence_lengths).item()
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-    query_states = query_states.transpose(1, 2)
-    key_states = key_states.transpose(1, 2)
-    value_states = value_states.transpose(1, 2)
-    _, _, _, block_size = k_cache.shape
-    if is_prompts:
-        if HAS_TRITON:
+    if HAS_TRITON:
+        if is_prompts:
             if attention_mask is not None:
                 query_states, key_states, value_states, indices = unpading_input(
                     query_states, key_states, value_states, attention_mask
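
With the `.transpose(1, 2)` calls dropped, the projections stay in `[bsz, q_len, num_heads, head_dim]` layout, and the next hunk flattens them into token-major `[num_tokens, num_heads, head_dim]` tensors for the unpadded kernels. A quick shape check with made-up sizes:

import torch

bsz, q_len, num_heads, head_dim = 2, 5, 8, 16             # toy sizes
prefill = torch.randn(bsz, q_len, num_heads, head_dim)
decode = torch.randn(bsz, 1, num_heads, head_dim)         # one new token per sequence

# Prefill: fold batch and sequence into a single token axis.
print(prefill.view(-1, num_heads, head_dim).shape)  # torch.Size([10, 8, 16])
# Decode: q_len is 1, so dropping that axis leaves one row per sequence.
print(decode.squeeze(dim=1).shape)                  # torch.Size([2, 8, 16])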
@@ -195,29 +195,44 @@ def llama_attn_forward(
                 query_states = query_states.view(-1, self.num_heads, self.head_dim)
                 key_states = key_states.view(-1, self.num_heads, self.head_dim)
                 value_states = value_states.view(-1, self.num_heads, self.head_dim)
+        else:
+            query_states = query_states.squeeze(dim=1)
+            key_states = key_states.squeeze(dim=1)
+            value_states = value_states.squeeze(dim=1)
+        rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+        _, _, _, block_size = k_cache.shape
+        if is_prompts:
             attn_output = context_attention_unpadded(
                 query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
             )
             if attention_mask is not None:
                 attn_output = pad_input(attn_output, indices, bsz, q_len)
         else:
-            attn_output = PagedAttention.pad_context_forward(
-                query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
-            )
-    else:
-        if HAS_TRITON:
             copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
             copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
+            # TODO Add dummy transpose and squeeze on in/out tensors of the triton flash decoding kernel
+            # in order to maintain compatibility. This part as well as the logics of handling query_states and attn_output
+            # should be revised, as we could see in previous part of `llama_attn_forward` we still have
+            # redundant transpose and the in/out of torch- and triton-version decoding kernel are not consistent.
+            query_states = query_states.transpose(1, 2)
             attn_output = flash_decoding_attention(
                 query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
             )
             attn_output = attn_output.squeeze(1)
+    else:
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+        if is_prompts:
+            attn_output = PagedAttention.pad_context_forward(
+                query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
+            )
         else:
             attn_output = PagedAttention.pad_decoding_forward(
                 query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
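
For orientation, `copy_kv_to_blocked_cache` writes each sequence's newest key/value vector into its slot of the paged (blocked) cache before `flash_decoding_attention` reads it back. A naive PyTorch sketch of that bookkeeping follows; the `[num_blocks, num_heads, head_dim, block_size]` cache layout is an assumption inferred only from the `_, _, _, block_size = k_cache.shape` unpack above, and the Python loop stands in for work the Triton kernel does in parallel.

import torch

def copy_kv_to_blocked_cache_ref(new_kv, cache, kv_lengths, block_tables):
    # new_kv:       [bsz, num_heads, head_dim]  (one decoded token per sequence)
    # cache:        [num_blocks, num_heads, head_dim, block_size]  (assumed layout)
    # kv_lengths:   [bsz] tokens per sequence, including the new one (assumption)
    # block_tables: [bsz, max_blocks_per_seq] logical-block -> physical-block ids
    block_size = cache.shape[-1]
    for seq_id in range(new_kv.shape[0]):
        pos = int(kv_lengths[seq_id]) - 1                      # slot of the newest token
        block = int(block_tables[seq_id, pos // block_size])   # which physical block
        cache[block, :, :, pos % block_size] = new_kv[seq_id]  # offset inside the block
    return cache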
@@ -232,6 +247,15 @@ def llama_attn_forward(
 @torch.no_grad()
 def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor:
     """Generate padding position_id through attention mask.
+    Args:
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+    Returns:
+        torch.Tensor: The padding position_id.
+    """
     position_ids = attention_mask.long().cumsum(-1) - 1
     position_ids.masked_fill_(attention_mask == 0, 1)
     return position_ids
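
A quick check of what this helper produces for a toy right-padded mask (values invented for illustration): real tokens get positions 0, 1, 2, ... and padded slots are filled with 1.

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[0, 1, 2, 1, 1],
#         [0, 1, 2, 3, 4]])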
@@ -239,9 +263,34 @@ def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor:
 @torch.no_grad()
 def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_mask: torch.Tensor):
     """Convert padding input to nopad input.
+    Args:
+        q (torch.Tensor): [batch_size, q_seq_len, head_num, head_dim]
+        k (torch.Tensor): [batch_size, q_seq_len, head_num, head_dim]
+        v (torch.Tensor): [batch_size, q_seq_len, head_num, head_dim]
+        attention_mask (torch.Tensor): [batch_size, sequence_length]
+    Returns:
+        Tuple[torch.Tensor]: The unpad q, k, v and The index of valid data in each batch.
+    """
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
     batch_size, kv_seq_len, num_key_value_heads, head_dim = q.shape
     q = index_first_axis(q.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
     k = index_first_axis(k.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
     v = index_first_axis(v.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
     return (q, k, v, indices)
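
In the forward pass, flash_attn's `index_first_axis` amounts to a gather along the flattened token axis, so the effect of `unpading_input` can be checked with plain indexing (toy shapes for illustration):

import torch

bsz, seq_len, num_heads, head_dim = 2, 4, 2, 3            # toy sizes
q = torch.randn(bsz, seq_len, num_heads, head_dim)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
print(indices)  # tensor([0, 1, 2, 4, 5]) -> 5 real tokens out of 8 padded slots

q_unpad = q.reshape(bsz * seq_len, num_heads, head_dim)[indices]
print(q_unpad.shape)  # torch.Size([5, 2, 3])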
+@torch.no_grad()
+def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
+    if is_prompts:
+        index_arrays = [torch.arange(length) for length in lengths]
+    else:
+        index_arrays = [(length - 1).view(-1) for length in lengths]
+    indices = torch.cat(index_arrays, dim=-1)
+    cos_output = cos_cache[indices].to(dtype=dtype)
+    sin_output = sin_cache[indices].to(dtype=dtype)
+    return (cos_output, sin_output)
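
The new `get_cos_sin` helper gathers one cos/sin row per token: every position of every prompt during prefill, and only position `length - 1` per sequence during decoding. A small sketch of the index selection with toy lengths:

import torch

lengths = torch.tensor([3, 5])

# Prefill: every position of every prompt needs its cos/sin row.
prefill_idx = torch.cat([torch.arange(int(l)) for l in lengths])
print(prefill_idx)  # tensor([0, 1, 2, 0, 1, 2, 3, 4])

# Decode: only the latest position of each sequence is needed.
decode_idx = torch.cat([(l - 1).view(-1) for l in lengths])
print(decode_idx)   # tensor([2, 4])

# cos_cache[prefill_idx] / sin_cache[prefill_idx] are then cast to the model dtype.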