This commit is contained in:
wangbluo
2024-10-15 11:56:49 +08:00
parent 3dc08c8a5a
commit 6be9862aaf
6 changed files with 23 additions and 6 deletions

View File

@@ -572,6 +572,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
**attention_mask,
inner_ring_size=shard_config.inner_ring_size,
pg_mesh=shard_config.pg_mesh,
sp_axis=shard_config.sp_axis,
)
elif shard_config.enable_flash_attention: