This commit is contained in:
flybird11111
2025-04-24 15:44:20 +08:00
parent 686982764c
commit e891501c55
3 changed files with 9 additions and 8 deletions

View File

@@ -607,7 +607,8 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
)
else:
-attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+# attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)