adapted to pad_context_forward

This commit is contained in:
yuehuayingxueluo
2024-01-09 13:52:53 +08:00
committed by FrankLeeeee
parent 47e53eaa1c
commit fa4fbdbffb
9 changed files with 42 additions and 41 deletions

View File

@@ -16,7 +16,7 @@ from transformers.models.llama.modeling_llama import (
from colossalai.inference.modeling.layers.attention import convert_kvcache, copy_to_cache
from colossalai.inference.struct import BatchInfo
from flash_attn.bert_padding import index_first_axis, pad_input # noqa
from flash_attn.bert_padding import index_first_axis # noqa
def rotate_half(x):
@@ -167,20 +167,8 @@ def llama_attn_forward(
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
)
else:
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_output = pad_decoding_forward(
query_states,
key_states,
value_states,
k_cache,
v_cache,
sequence_lengths,
block_tables,
attention_mask,
self.layer_idx,
self.attention_dropout,
self.training,
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
)
attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
@@ -215,9 +203,6 @@ def pad_decoding_forward(
lengths: torch.Tensor, # [num_seqs]: input_lengths + output_lengths
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
attn_mask: torch.Tensor = None,
layer_id: int = 0,
attention_dropout: float = None,
training: bool = False,
):
bsz, query_length, num_heads, head_size = query.shape
seq_len = max(lengths)
@@ -247,9 +232,7 @@ def pad_decoding_forward(
attn_mask = attn_mask.masked_fill(padding_mask.bool(), torch.finfo(query.dtype).min)
attn_weights += attn_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=attention_dropout, training=training)
attn_output = torch.matmul(attn_weights, value)
if attn_output.size() != (bsz, num_heads, 1, head_size):
@@ -277,8 +260,6 @@ def pad_context_forward(
block_size = k_cache.shape[-1]
assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
block_tables.shape[-1] * block_size
shape = (bsz, seq_len, num_heads, head_size)
input_shape = shape[:2]
# Copy kv to memory(rotary embedded)
copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables)