wangbluo
2024-10-15 11:01:34 +08:00
parent 8ff7d0c780
commit 3dc08c8a5a
6 changed files with 18 additions and 15 deletions

@@ -571,6 +571,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 sp_group,
                 **attention_mask,
                 inner_ring_size=shard_config.inner_ring_size,
+                pg_mesh=shard_config.pg_mesh,
             )
         elif shard_config.enable_flash_attention:
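
Below is a minimal, self-contained sketch (not the actual ColossalAI implementation) of how the new pg_mesh field could flow from the shard config into the ring-attention call shown in the hunk above. The ShardConfig fields and the ring_attention signature here are illustrative assumptions, not the library's real API.

# Minimal sketch, assuming pg_mesh is a process-group mesh carried on the shard config.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ShardConfig:
    inner_ring_size: Optional[int] = None
    pg_mesh: Optional[Any] = None  # assumed: a process-group mesh describing the device layout


def ring_attention(q, k, v, sp_group, *, inner_ring_size=None, pg_mesh=None, **attention_mask):
    # Hypothetical stand-in for the ring-attention kernel: a real implementation
    # could use pg_mesh to look up the inner/inter-node ring process groups
    # instead of rebuilding them from the world group on every call.
    ...


def forward(q, k, v, sp_group, attention_mask, shard_config: ShardConfig):
    # Mirrors the added line in the diff: the config-level mesh is forwarded explicitly.
    return ring_attention(
        q,
        k,
        v,
        sp_group,
        **attention_mask,
        inner_ring_size=shard_config.inner_ring_size,
        pg_mesh=shard_config.pg_mesh,
    )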