This commit is contained in:
wangbluo
2024-09-25 19:02:21 +08:00
parent 91ed32c256
commit 6705dad41b
3 changed files with 3 additions and 3 deletions

View File

@@ -571,9 +571,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
key_states,
value_states,
sp_group,
tp_group=tp_group,
**attention_mask,
inner_ring_size=shard_config.inner_ring_size,
tp_group=tp_group,
)
elif shard_config.enable_flash_attention: