Merge pull request #6071 from wangbluo/ring_attention

[Ring Attention] fix the 2d ring attn when using multiple machine
This commit is contained in:
Wang Binluo
2024-10-15 15:17:21 +08:00
committed by GitHub
6 changed files with 41 additions and 31 deletions

View File

@@ -569,9 +569,10 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
query_states,
key_states,
value_states,
sp_group,
sp_axis=shard_config.sp_axis,
**attention_mask,
inner_ring_size=shard_config.inner_ring_size,
pg_mesh=shard_config.pg_mesh,
)
elif shard_config.enable_flash_attention: