mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 21:40:02 +00:00
fix
This commit is contained in:
@@ -515,13 +515,6 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
|
||||
if is_share_sp_tp(sp_mode):
|
||||
q_len *= sp_size
|
||||
|
||||
# if sp_mode == "all_to_all":
|
||||
# # query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
# # key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
# # value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
# # bsz, q_len, _ = query_states.size()
|
||||
# # hidden_states = all_to_all_comm(hidden_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
@@ -548,7 +541,6 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
|
||||
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
|
||||
# cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
@@ -607,14 +599,12 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
|
||||
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
else:
|
||||
# attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
# return attn_output, attn_weights, past_key_value
|
||||
return attn_output, attn_weights
|
||||
|
||||
return forward
|
||||
|
Reference in New Issue
Block a user