Author: flybird11111
Date: 2025-04-24 16:20:42 +08:00
Parent: d7a9eb0f67
Commit: 2f615a49fd
5 changed files with 2 additions and 26 deletions


@@ -515,13 +515,6 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
         if is_share_sp_tp(sp_mode):
             q_len *= sp_size
-        # if sp_mode == "all_to_all":
-        # # query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
-        # # key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
-        # # value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
-        # # bsz, q_len, _ = query_states.size()
-        # # hidden_states = all_to_all_comm(hidden_states, sp_group, fp8_communication=shard_config.fp8_communication)
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
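
The commented-out block removed above refers to `all_to_all_comm`, the collective that the "all_to_all" sequence-parallel mode uses to trade a sequence shard for a head/hidden shard around the QKV projections. A minimal sketch of what such a helper does is shown below; only the name and call sites come from the diff, and the body here is an assumption for illustration, not ColossalAI's actual implementation.

```python
# Illustrative sketch only: `all_to_all_comm` appears in the diff above, but this
# standalone version (name, defaults for scatter_dim/gather_dim) is an assumption
# about its behavior, not ColossalAI's implementation.
import torch
import torch.distributed as dist


def all_to_all_sketch(x: torch.Tensor, group, scatter_dim: int = 2, gather_dim: int = 1) -> torch.Tensor:
    """Scatter `x` along `scatter_dim` across ranks and gather it along `gather_dim`."""
    world_size = dist.get_world_size(group)
    # Split the local tensor along the dimension to be scattered (e.g. hidden size)...
    send_chunks = [t.contiguous() for t in x.chunk(world_size, dim=scatter_dim)]
    recv_chunks = [torch.empty_like(send_chunks[0]) for _ in send_chunks]
    dist.all_to_all(recv_chunks, send_chunks, group=group)
    # ...and concatenate what was received along the dimension to be gathered
    # (e.g. sequence length), turning a sequence shard into a head/hidden shard.
    return torch.cat(recv_chunks, dim=gather_dim)
```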
@@ -548,7 +541,6 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        # cos, sin = self.rotary_emb(value_states, position_ids)
         cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
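
This hunk drops the per-layer `self.rotary_emb(value_states, position_ids)` call in favor of the `cos, sin` pair that newer Transformers versions precompute once at the model level and pass down as `position_embeddings`. For reference, a minimal sketch of what `apply_rotary_pos_emb` computes; the real helper is imported from `transformers`, and this version is illustrative only.

```python
# Minimal sketch of rotary position embedding application, following the
# Hugging Face llama convention; not the imported `apply_rotary_pos_emb` itself.
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap the two halves of the head dimension and negate the second one.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_sketch(q, k, cos, sin, unsqueeze_dim=1):
    # `cos`/`sin` arrive precomputed as `position_embeddings`, which is why the
    # per-layer `self.rotary_emb(...)` call is removed in this commit.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```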
@@ -607,14 +599,12 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
             )
         else:
-            # attn_output = attn_output.reshape(*input_shape, -1).contiguous()
             attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
         attn_output = self.o_proj(attn_output)
         if not output_attentions:
             attn_weights = None
-        # return attn_output, attn_weights, past_key_value
         return attn_output, attn_weights
     return forward
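
The last hunk shows the two output paths after the attention kernel: in "all_to_all" mode the output is redistributed with `all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2, ...)`, so each rank again holds its sequence shard with the full hidden dimension, while the other modes simply flatten the heads for `o_proj`. A hedged sketch of that output path, reusing the illustrative `all_to_all_sketch` helper from above; the exact reshape placement inside the "all_to_all" branch is an assumption.

```python
# Hedged sketch of the output path in the hunk above. `all_to_all_sketch` is the
# illustrative helper defined earlier; the reshape placement in the "all_to_all"
# branch is an assumption, not copied from the ColossalAI source.
def project_attn_output(attn_output, o_proj, sp_mode, sp_group, bsz, q_len):
    if sp_mode == "all_to_all":
        # Flatten heads, then re-shard the sequence dim and re-gather the hidden dim
        # (scatter_dim=1, gather_dim=2, as in the diff).
        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = all_to_all_sketch(attn_output, sp_group, scatter_dim=1, gather_dim=2)
    else:
        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
    return o_proj(attn_output)
```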