This commit is contained in:
wangbluo
2024-10-15 13:26:44 +08:00
parent 6be9862aaf
commit fd92789af2
4 changed files with 9 additions and 12 deletions

View File

@@ -568,11 +568,10 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
query_states,
key_states,
value_states,
sp_group,
sp_axis=shard_config.sp_axis,
**attention_mask,
inner_ring_size=shard_config.inner_ring_size,
pg_mesh=shard_config.pg_mesh,
sp_axis=shard_config.sp_axis,
)
elif shard_config.enable_flash_attention: