Mirror of https://github.com/hpcaitech/ColossalAI.git

commit fa4fbdbffb (parent 47e53eaa1c), committed by FrankLeeeee

    adapted to pad_context_forward
@@ -16,7 +16,7 @@ from transformers.models.llama.modeling_llama import (
 from colossalai.inference.modeling.layers.attention import convert_kvcache, copy_to_cache
 from colossalai.inference.struct import BatchInfo
 
-from flash_attn.bert_padding import index_first_axis, pad_input # noqa
+from flash_attn.bert_padding import index_first_axis # noqa
 
 
 def rotate_half(x):
@@ -167,20 +167,8 @@ def llama_attn_forward(
             query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
         )
     else:
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
         attn_output = pad_decoding_forward(
-            query_states,
-            key_states,
-            value_states,
-            k_cache,
-            v_cache,
-            sequence_lengths,
-            block_tables,
-            attention_mask,
-            self.layer_idx,
-            self.attention_dropout,
-            self.training,
+            query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
         )
 
     attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
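The removed repeat_kv calls are the standard GQA head-expansion helper imported from transformers.models.llama.modeling_llama; this hunk only drops them from the call site, and the diff does not show where (or whether) that expansion happens afterwards, presumably inside the attention helpers. As a minimal, self-contained sketch of what the dropped call did (shapes follow the transformers Llama helper; the function below is a toy reimplementation, not ColossalAI code):

import torch

# Toy sketch of the repeat_kv-style GQA expansion removed from this call site:
# each KV head is repeated n_rep times so K/V match the number of query heads.
# Shapes: [bsz, num_kv_heads, seq_len, head_dim] -> [bsz, num_kv_heads * n_rep, seq_len, head_dim].
def repeat_kv_like(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    bsz, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(bsz, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(bsz, num_kv_heads * n_rep, seq_len, head_dim)

key_states = torch.randn(2, 4, 16, 64)             # 4 KV heads
print(repeat_kv_like(key_states, 8 // 4).shape)    # torch.Size([2, 8, 16, 64]), matching 8 query heads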
@@ -215,9 +203,6 @@ def pad_decoding_forward(
     lengths: torch.Tensor, # [num_seqs]: input_lengths + output_lengths
     block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
     attn_mask: torch.Tensor = None,
-    layer_id: int = 0,
-    attention_dropout: float = None,
-    training: bool = False,
 ):
     bsz, query_length, num_heads, head_size = query.shape
     seq_len = max(lengths)
@@ -247,9 +232,7 @@ def pad_decoding_forward(
         attn_mask = attn_mask.masked_fill(padding_mask.bool(), torch.finfo(query.dtype).min)
 
     attn_weights += attn_mask
-
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=attention_dropout, training=training)
     attn_output = torch.matmul(attn_weights, value)
 
     if attn_output.size() != (bsz, num_heads, 1, head_size):
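For reference, a minimal, self-contained sketch of the decoding-attention core this hunk leaves in place: an additive mask filled with the dtype minimum, a float32 softmax cast back to the query dtype, and a matmul with V, with the dropout step gone. The shapes and the toy padding mask below are illustrative assumptions, not ColossalAI's exact tensors.

import math
import torch
import torch.nn as nn

# One new token per sequence attends over a cached KV window of length kv_len.
bsz, num_heads, head_size, kv_len = 2, 8, 64, 16
query = torch.randn(bsz, num_heads, 1, head_size)
key = torch.randn(bsz, num_heads, kv_len, head_size)
value = torch.randn(bsz, num_heads, kv_len, head_size)

# Pretend the last 3 positions of the first sequence are padding.
padding_mask = torch.zeros(bsz, 1, 1, kv_len, dtype=torch.bool)
padding_mask[0, ..., -3:] = True

attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_size)
attn_mask = torch.zeros(bsz, 1, 1, kv_len).masked_fill(padding_mask, torch.finfo(query.dtype).min)
attn_weights += attn_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_output = torch.matmul(attn_weights, value)      # [bsz, num_heads, 1, head_size]
assert attn_output.size() == (bsz, num_heads, 1, head_size)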
@@ -277,8 +260,6 @@ def pad_context_forward(
     block_size = k_cache.shape[-1]
     assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
     block_tables.shape[-1] * block_size
-    shape = (bsz, seq_len, num_heads, head_size)
-    input_shape = shape[:2]
 
     # Copy kv to memory(rotary embedded)
     copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables)
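copy_to_cache is imported from colossalai.inference.modeling.layers.attention in the first hunk; judging from its call and the comment above, it scatters each sequence's K/V tokens into a block (paged) cache addressed by block_tables and the per-sequence lengths. Below is a toy, self-contained version under an assumed cache layout of [num_blocks, num_heads, head_size, block_size] (block_size last, consistent with block_size = k_cache.shape[-1] above); the real layout and implementation may differ.

import torch

# Conceptual sketch of a copy_to_cache-style helper: for each valid token position,
# look up its block in block_tables and write the token's [num_heads, head_size]
# slice into that block's slot. Sizes are illustrative only.
def toy_copy_to_cache(k, cache, lengths, block_tables):
    # k: [bsz, seq_len, num_heads, head_size]; cache: [num_blocks, num_heads, head_size, block_size]
    block_size = cache.shape[-1]
    for seq_id in range(k.shape[0]):
        for pos in range(int(lengths[seq_id])):
            block_id = block_tables[seq_id, pos // block_size]
            offset = pos % block_size
            cache[block_id, :, :, offset] = k[seq_id, pos]
    return cache

bsz, seq_len, num_heads, head_size, block_size, num_blocks = 2, 6, 4, 8, 4, 8
k = torch.randn(bsz, seq_len, num_heads, head_size)
k_cache = torch.zeros(num_blocks, num_heads, head_size, block_size)
lengths = torch.tensor([6, 3])                  # context length per sequence
block_tables = torch.tensor([[0, 1], [2, 3]])   # blocks assigned to each sequence
toy_copy_to_cache(k, k_cache, lengths, block_tables)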