Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 20:10:17 +00:00
fix: reshape attn_output with (bsz, q_len) instead of input_shape in the LLaMA flash-attention forward
@@ -607,7 +607,8 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
             )
         else:
-            attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+            # attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+            attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()

         attn_output = self.o_proj(attn_output)
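For context, here is a minimal sketch of the shape issue this one-line change appears to address. The concrete sizes and the sp_size=4 sequence sharding are assumptions for illustration, not taken from the commit: the point is that reshaping with a stale input_shape can silently produce a tensor whose last dimension no longer matches o_proj's expected input size, whereas reshape(bsz, q_len, -1) uses the current batch and sequence sizes directly.

import torch

# Assumed sizes for illustration only.
bsz, q_len, num_heads, head_dim = 2, 16, 8, 64

# attn_output after attention and transpose: (bsz, q_len, num_heads, head_dim)
attn_output = torch.randn(bsz, q_len, num_heads, head_dim)

# Suppose input_shape was captured while the sequence dimension was still
# sharded across sp_size=4 ranks (a hypothetical stale value):
stale_input_shape = (bsz, q_len // 4)

# The old reshape does not raise; it silently yields the wrong hidden size.
wrong = attn_output.reshape(*stale_input_shape, -1).contiguous()
print(wrong.shape)  # torch.Size([2, 4, 2048]) -- would break a 512-wide o_proj

# The fixed reshape recovers the expected (bsz, q_len, num_heads * head_dim).
right = attn_output.reshape(bsz, q_len, -1).contiguous()
print(right.shape)  # torch.Size([2, 16, 512])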