Fixed a bug in the inference frame

yuehuayingxueluo
2023-12-26 21:34:27 +08:00
committed by FrankLeeeee
parent 86853a37d5
commit 62fd08ee44
8 changed files with 261 additions and 90 deletions

@@ -70,7 +70,10 @@ def llama_model_forward(
     seq_length = input_ids.shape[1]
     device = input_ids.device
-    past_key_values_length = len(block_tables.shape[1])
+    if batch.is_prompts:
+        past_key_values_length = 0
+    else:
+        past_key_values_length = sequence_lengths[0].item() - 1
     position_ids = torch.arange(
         past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
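
The line removed here computed past_key_values_length as len(block_tables.shape[1]), which fails because shape[1] is an int; the replacement distinguishes the prefill and decode stages instead. Below is a minimal sketch of that corrected logic, assuming is_prompts flags a prefill batch and sequence_lengths holds per-sequence token counts (both names come from the diff; build_position_ids itself is a hypothetical helper for illustration):

```python
import torch

def build_position_ids(seq_length: int, is_prompts: bool, sequence_lengths: torch.Tensor) -> torch.Tensor:
    # Prefill: nothing is cached yet, so positions start at 0.
    # Decode: the cache already holds all but the newest token.
    if is_prompts:
        past_key_values_length = 0
    else:
        past_key_values_length = sequence_lengths[0].item() - 1
    return torch.arange(past_key_values_length, seq_length + past_key_values_length, dtype=torch.long)

print(build_position_ids(8, True, torch.tensor([8])))   # prefill of an 8-token prompt -> positions 0..7
print(build_position_ids(1, False, torch.tensor([9])))  # next decode step -> position 8
```
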
@@ -163,26 +166,17 @@ def llama_attn_forward(
     key_states = key_states.view(-1, self.num_heads, self.head_dim)
     value_states = value_states.view(-1, self.num_heads, self.head_dim)
-    block_size = k_cache.shape[-1]
+    k_cache.shape[-1]
-    memcpy_to_block(key_states, value_states, k_cache, v_cache, block_tables, block_size)
+    # memcpy_to_block(key_states, value_states, k_cache, v_cache, block_tables, block_size, sequence_lengths)
-    if is_prompts:
-        attn_output = context_attention_unpadded(
-            query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
-        )
-    else:
-        attn_output = torch.empty(bsz, self.num_heads, self.head_dim)
-        decoding_attention(
-            query_states,
-            k_cache,
-            v_cache,
-            block_tables,
-            sequence_lengths,
-            attn_output,
-            block_tables.shape[1],
-            block_size,
-        )
+    # if is_prompts:
+    #     attn_output = context_attention_unpadded(query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
+    # else:
+    #     attn_output = torch.empty(bsz, self.num_heads, self.head_dim)
+    #     decoding_attention(query_states, k_cache, v_cache, block_tables, sequence_lengths, attn_output, block_tables.shape[1], block_size)
+    attn_output = query_states
     attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
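
This hunk disables the cache copy and both attention kernels (they survive only as comments) and forwards query_states straight into the reshape path, which reads like a temporary stand-in while the kernels are reworked. The shape-only sketch below traces that pass-through; the concrete sizes are illustrative and not taken from the commit:

```python
import torch

bsz, q_len, num_heads, head_dim = 2, 4, 8, 64           # illustrative sizes only
hidden_size = num_heads * head_dim

# query_states arrives flattened to (bsz * q_len, num_heads, head_dim),
# matching the .view(-1, self.num_heads, self.head_dim) calls above.
query_states = torch.randn(bsz * q_len, num_heads, head_dim)

attn_output = query_states                                     # pass-through stand-in
attn_output = attn_output.view(bsz, q_len, num_heads, head_dim)
attn_output = attn_output.reshape(bsz, q_len, hidden_size)     # ready for the o_proj projection
print(attn_output.shape)                                       # torch.Size([2, 4, 512])
```
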
@@ -190,19 +184,3 @@ def llama_attn_forward(
     attn_output = self.o_proj(attn_output)
     return attn_output
-def memcpy_to_block(key, value, k_cache, v_cache, block_tables, block_size):
-    block_table_list = block_tables.tolist()
-    batch_size, seq_len, num_heads, head_dim = key
-    reshape_key = key.reshape(batch_size, seq_len, block_size, num_heads, head_dim).tensor.permute(0, 2, 3, 1)
-    reshape_value = value.reshape(batch_size, seq_len, block_size, num_heads, head_dim).tensor.permute(0, 2, 3, 1)
-    if seq_len == 1:
-        for i in range(batch_size):
-            k_cache[block_table_list[i][-1], :] = reshape_key[i]
-            v_cache[block_table_list[i][-1], :] = reshape_value[i]
-    else:
-        for i in range(batch_size):
-            k_cache[block_table_list[i], :] = reshape_key[i]
-            v_cache[block_table_list[i], :] = reshape_value[i]
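
The deleted memcpy_to_block helper tried to scatter new key/value tensors into the paged KV cache through the block table, but as written it unpacks key rather than key.shape and calls a nonexistent .tensor attribute, which is presumably why it was dropped. For orientation, here is a generic sketch of a block-table cache write for a single decode-step token, assuming a (num_blocks, num_heads, head_dim, block_size) cache layout consistent with block_size = k_cache.shape[-1] in the diff; write_kv_to_cache is a hypothetical illustration, not the project's kernel:

```python
import torch

def write_kv_to_cache(key, value, k_cache, v_cache, block_tables, sequence_lengths):
    # key, value: (batch_size, num_heads, head_dim), one freshly decoded token per sequence.
    block_size = k_cache.shape[-1]
    for i in range(key.shape[0]):
        pos = sequence_lengths[i].item() - 1           # logical slot of the newest token
        block = block_tables[i, pos // block_size]     # physical block that owns the slot
        offset = pos % block_size                      # position inside that block
        k_cache[block, :, :, offset] = key[i]
        v_cache[block, :, :, offset] = value[i]

# Tiny usage example: 1 sequence, 2 heads, head_dim 4, block_size 8, 4 blocks.
k_cache = torch.zeros(4, 2, 4, 8)
v_cache = torch.zeros(4, 2, 4, 8)
block_tables = torch.tensor([[0, 1, 2, 3]])
write_kv_to_cache(torch.ones(1, 2, 4), torch.ones(1, 2, 4),
                  k_cache, v_cache, block_tables, sequence_lengths=torch.tensor([3]))
print(k_cache[0, :, :, 2])   # the third slot of block 0 now holds the new key
```
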