wangbluo
2024-10-14 18:01:53 +08:00
parent d891e50617
commit 23199e34cc
4 changed files with 19 additions and 55 deletions

@@ -563,8 +563,6 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
-        tp_group = shard_config.tensor_parallel_process_group
         if sp_mode == "ring_attn":
             attn_output = RingAttention.attention(
                 query_states,
@@ -573,7 +571,6 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 sp_group,
                 **attention_mask,
                 inner_ring_size=shard_config.inner_ring_size,
-                tp_group=tp_group,
             )
         elif shard_config.enable_flash_attention:
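
For readability, below is a minimal sketch of the ring-attention branch as it reads after this commit. It is reconstructed only from the context lines in the hunks above; the surrounding forward() body and the variables query_states, key_states, value_states, sp_group, attention_mask, and shard_config are assumed to be set up exactly as in the original file.

# Sketch (not the verbatim file): the ring-attention call after tp_group was dropped.
# All names come from the diff context above; everything else about the surrounding
# function is assumed.
if sp_mode == "ring_attn":
    attn_output = RingAttention.attention(
        query_states,
        key_states,
        value_states,
        sp_group,
        **attention_mask,                              # mask-related kwargs forwarded as-is
        inner_ring_size=shard_config.inner_ring_size,  # ring size taken from the shard config
        # tp_group is no longer passed here; the matching
        # tp_group = shard_config.tensor_parallel_process_group lookup was removed above.
    )
elif shard_config.enable_flash_attention:
    ...  # flash-attention path, unchanged by this commit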