Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 11:32:10 +00:00)
Fixed a bug in the inference frame
committed by FrankLeeeee
parent 86853a37d5
commit 62fd08ee44
@@ -70,7 +70,10 @@ def llama_model_forward(
     seq_length = input_ids.shape[1]
     device = input_ids.device

-    past_key_values_length = len(block_tables.shape[1])
+    if batch.is_prompts:
+        past_key_values_length = 0
+    else:
+        past_key_values_length = sequence_lengths[0].item() - 1

     position_ids = torch.arange(
         past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
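For context, the hunk above replaces a broken call, len(block_tables.shape[1]), with a branch on batch.is_prompts: block_tables.shape[1] is an int, so len() on it raises a TypeError. During prefill the positions start at 0, while during decoding they resume after the tokens already held in the cache. A minimal, self-contained sketch of that position-id logic (the names is_prompts and sequence_lengths follow the diff; the helper itself and the assumption that sequence_lengths counts the token being decoded are illustrative):

    # Illustrative sketch only: mirrors the position-id logic the hunk switches to.
    # `sequence_lengths` is assumed to hold the total tokens per sequence, including the new one.
    import torch

    def build_position_ids(is_prompts: bool, seq_length: int, sequence_lengths: torch.Tensor) -> torch.Tensor:
        # Prefill: positions start at 0. Decode: resume after the tokens already cached.
        past_key_values_length = 0 if is_prompts else sequence_lengths[0].item() - 1
        return torch.arange(past_key_values_length, seq_length + past_key_values_length, dtype=torch.long)

    # Prefill of an 8-token prompt -> tensor([0, 1, ..., 7]); the next decode step -> tensor([8]).
    print(build_position_ids(True, 8, torch.tensor([8])))
    print(build_position_ids(False, 1, torch.tensor([9])))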
@@ -163,26 +166,17 @@ def llama_attn_forward(
     key_states = key_states.view(-1, self.num_heads, self.head_dim)
     value_states = value_states.view(-1, self.num_heads, self.head_dim)

-    block_size = k_cache.shape[-1]
+    k_cache.shape[-1]

-    memcpy_to_block(key_states, value_states, k_cache, v_cache, block_tables, block_size)
+    # memcpy_to_block(key_states, value_states, k_cache, v_cache, block_tables, block_size, sequence_lengths)

-    if is_prompts:
-        attn_output = context_attention_unpadded(
-            query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
-        )
-    else:
-        attn_output = torch.empty(bsz, self.num_heads, self.head_dim)
-        decoding_attention(
-            query_states,
-            k_cache,
-            v_cache,
-            block_tables,
-            sequence_lengths,
-            attn_output,
-            block_tables.shape[1],
-            block_size,
-        )
+    # if is_prompts:
+    #     attn_output = context_attention_unpadded(query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
+    # else:
+    #     attn_output = torch.empty(bsz, self.num_heads, self.head_dim)
+    #     decoding_attention(query_states, k_cache, v_cache, block_tables, sequence_lengths, attn_output, block_tables.shape[1], block_size)
+
+    attn_output = query_states

     attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
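The commented-out branch above shows the intended shape of the attention path: context_attention_unpadded handles the prefill (prompt) pass over whole sequences, while decoding_attention handles single-token decode steps against the blocked KV cache; this hunk disables both calls and passes query_states straight through for now. A rough sketch of that prefill/decode dispatch, using plain PyTorch attention as a stand-in for the custom kernels (the tensor layout and the dispatcher name are illustrative, not the project's API):

    # Illustrative stand-in for the prefill/decode dispatch; the real code calls custom
    # kernels (context_attention_unpadded / decoding_attention) against a blocked KV cache.
    import torch
    import torch.nn.functional as F

    def attention_dispatch(query, key, value, is_prompts: bool):
        # query/key/value: [batch, heads, q_len, head_dim]; during decode q_len == 1
        # and key/value would come from the KV cache rather than being passed in.
        if is_prompts:
            # Prefill: causal attention over the full prompt.
            return F.scaled_dot_product_attention(query, key, value, is_causal=True)
        # Decode: a single new query attends to every cached position.
        return F.scaled_dot_product_attention(query, key, value)

    q = torch.randn(1, 8, 4, 64)   # 1 sequence, 8 heads, 4 prompt tokens, head_dim 64
    k = v = torch.randn(1, 8, 4, 64)
    out = attention_dispatch(q, k, v, is_prompts=True)   # -> [1, 8, 4, 64]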
@@ -190,19 +184,3 @@ def llama_attn_forward(
     attn_output = self.o_proj(attn_output)

     return attn_output
-
-
-def memcpy_to_block(key, value, k_cache, v_cache, block_tables, block_size):
-    block_table_list = block_tables.tolist()
-    batch_size, seq_len, num_heads, head_dim = key
-
-    reshape_key = key.reshape(batch_size, seq_len, block_size, num_heads, head_dim).tensor.permute(0, 2, 3, 1)
-    reshape_value = value.reshape(batch_size, seq_len, block_size, num_heads, head_dim).tensor.permute(0, 2, 3, 1)
-    if seq_len == 1:
-        for i in range(batch_size):
-            k_cache[block_table_list[i][-1], :] = reshape_key[i]
-            v_cache[block_table_list[i][-1], :] = reshape_value[i]
-    else:
-        for i in range(batch_size):
-            k_cache[block_table_list[i], :] = reshape_key[i]
-            v_cache[block_table_list[i], :] = reshape_value[i]
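The deleted memcpy_to_block helper above was meant to scatter new key/value states into a paged (blocked) KV cache via each sequence's block table, but it unpacks the key tensor itself rather than key.shape and calls a nonexistent .tensor attribute. A small sketch of what such a copy can look like for a single decode step, assuming a cache laid out as [num_blocks, num_heads, head_dim, block_size] (matching block_size = k_cache.shape[-1] in the diff); this is an illustration under those assumptions, not the replacement ColossalAI actually shipped:

    # Illustrative paged-KV-cache copy for one decode step (one new token per sequence).
    # Assumed cache layout: [num_blocks, num_heads, head_dim, block_size].
    import torch

    def copy_token_to_blocked_cache(key, value, k_cache, v_cache, block_tables, sequence_lengths):
        # key/value: [batch, num_heads, head_dim] for the newly generated token.
        block_size = k_cache.shape[-1]
        for i in range(key.shape[0]):
            pos = sequence_lengths[i].item() - 1           # index of the new token in its sequence
            block_id = block_tables[i, pos // block_size]  # physical block that holds this position
            offset = pos % block_size                      # slot inside that block
            k_cache[block_id, :, :, offset] = key[i]
            v_cache[block_id, :, :, offset] = value[i]

    # Tiny usage example: 1 sequence, 2 heads, head_dim 4, block_size 8, 3 blocks allocated.
    k_cache = torch.zeros(3, 2, 4, 8)
    v_cache = torch.zeros(3, 2, 4, 8)
    block_tables = torch.tensor([[0, 1, 2]])
    copy_token_to_blocked_cache(torch.randn(1, 2, 4), torch.randn(1, 2, 4),
                                k_cache, v_cache, block_tables, torch.tensor([9]))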