mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
precision alignment
This commit is contained in:
committed by
FrankLeeeee
parent
62968588d1
commit
9489dc64d8
@@ -67,19 +67,8 @@ def llama_model_forward(
|
||||
block_tables = batch.get_block_table_tensor()
|
||||
sequence_lengths = batch.get_sequence_lengths()
|
||||
|
||||
seq_length = input_ids.shape[1]
|
||||
device = input_ids.device
|
||||
|
||||
if batch.is_prompts:
|
||||
past_key_values_length = 0
|
||||
else:
|
||||
past_key_values_length = sequence_lengths[0].item() - 1
|
||||
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
|
||||
# Here, we generate position_ids through the input tensor, which can align with the output precision of the transformer.
|
||||
position_ids = generate_padding_position_id(input_ids)
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
for layer_id, decoder_layer in enumerate(self.layers):
|
||||
@@ -142,7 +131,7 @@ def llama_attn_forward(
|
||||
k_cache: torch.Tensor = None,
|
||||
v_cache: torch.Tensor = None,
|
||||
is_prompts: bool = True,
|
||||
sequence_lengths: int = None,
|
||||
sequence_lengths: torch.Tensor = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@@ -150,7 +139,9 @@ def llama_attn_forward(
|
||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2] + block_tables.shape[1]
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if not is_prompts:
|
||||
kv_seq_len = kv_seq_len + sequence_lengths[0].item()
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
@@ -166,10 +157,8 @@ def llama_attn_forward(
|
||||
key_states = key_states.view(-1, self.num_heads, self.head_dim)
|
||||
value_states = value_states.view(-1, self.num_heads, self.head_dim)
|
||||
|
||||
k_cache.shape[-1]
|
||||
|
||||
# TODO: The code below will be uncommented after the development of attention-related kernel is completed.
|
||||
# memcpy_to_block(key_states, value_states, k_cache, v_cache, block_tables, block_size, sequence_lengths)
|
||||
|
||||
# if is_prompts:
|
||||
# attn_output = context_attention_unpadded(query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
|
||||
# else:
|
||||
@@ -177,10 +166,16 @@ def llama_attn_forward(
|
||||
# decoding_attention(query_states, k_cache, v_cache, block_tables, sequence_lengths, attn_output, block_tables.shape[1], block_size)
|
||||
|
||||
attn_output = query_states
|
||||
|
||||
attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
def generate_padding_position_id(input_ids: torch.Tensor) -> torch.Tensor:
|
||||
padding_id = 2
|
||||
attention_mask = input_ids.ne(padding_id).long()
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
return position_ids
|
||||
|
Reference in New Issue
Block a user