[Fix] Fix spec-dec Glide LlamaModel for compatibility with transformers (#5837)

* fix glide llama model

* revise
This commit is contained in:
Yuanheng Zhao
2024-06-19 15:37:53 +08:00
committed by GitHub
parent fd1dc417d8
commit 7b249c76e5
4 changed files with 7 additions and 1 deletion

View File

@@ -319,7 +319,8 @@ class LlamaCrossAttention(nn.Module):
query_states = query_states.view(bsz, -1, self.large_num_heads, self.large_head_dim).transpose(1, 2)
# for RoPE
-        cos, sin = self.rotary_emb(query_states, seq_len=kv_seq_len + 32)
+        position_ids = position_ids + glide_input.n_spec_tokens
+        cos, sin = self.rotary_emb(query_states, position_ids)
query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids)
query_states = query_states.transpose(1, 2)
query_states = query_states.reshape(-1, self.large_num_heads, self.large_head_dim)