wangbluo 2024-09-25 19:02:21 +08:00
parent 91ed32c256
commit 6705dad41b
3 changed files with 3 additions and 3 deletions
colossalai/shardformer

@@ -501,7 +501,6 @@ class RingAttention(torch.autograd.Function):
         v,
         sp_group,
         attention_mask_type,
-        tp_group=None,
         cu_seqlens=None,
         max_seqlen=None,
         valid_indices=None,
@@ -510,6 +509,7 @@ class RingAttention(torch.autograd.Function):
         deterministic=False,
         return_softmax=False,
         inner_ring_size=None,
+        tp_group=None,
         **kwargs,
     ):
         """

@@ -866,11 +866,11 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
                 key,
                 value,
                 sp_group,
-                tp_group=tp_group,
                 **attention_mask,
                 dropout_p=dropout_p,
                 scale=scale,
                 inner_ring_size=shard_config.inner_ring_size,
+                tp_group=tp_group,
             )
         else:
             attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)

@@ -571,9 +571,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 key_states,
                 value_states,
                 sp_group,
-                tp_group=tp_group,
                 **attention_mask,
                 inner_ring_size=shard_config.inner_ring_size,
+                tp_group=tp_group,
             )
         elif shard_config.enable_flash_attention:
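
Both call sites above pass every argument by keyword (including the **attention_mask dict unpack), so moving tp_group=tp_group after inner_ring_size=... only mirrors the new signature order and does not change which values get bound. A short sketch of that keyword-order neutrality (call_stub, the placeholder attention_mask dict, and the literal values are hypothetical, not ColossalAI code):

def call_stub(q, k, v, sp_group, **kwargs):
    # Collect the keyword arguments so the two call orders can be compared.
    return kwargs

attention_mask = {"attention_mask_type": "causal"}  # stand-in for the real mask dict

before = call_stub("q", "k", "v", "sp",
                   tp_group="tp", **attention_mask,
                   dropout_p=0.1, scale=1.0, inner_ring_size=4)
after = call_stub("q", "k", "v", "sp",
                  **attention_mask, dropout_p=0.1, scale=1.0,
                  inner_ring_size=4, tp_group="tp")
assert before == after  # the reordering is behavior-neutral for keyword calls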