From 6705dad41b001b75a94f1f92da48df666f5e7b55 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Wed, 25 Sep 2024 19:02:21 +0800
Subject: [PATCH] fix

---
 colossalai/shardformer/layer/attn.py     | 2 +-
 colossalai/shardformer/modeling/gpt2.py  | 2 +-
 colossalai/shardformer/modeling/llama.py | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 2cc6f3163..b7738c7e2 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -501,7 +501,6 @@ class RingAttention(torch.autograd.Function):
         v,
         sp_group,
         attention_mask_type,
-        tp_group=None,
         cu_seqlens=None,
         max_seqlen=None,
         valid_indices=None,
@@ -510,6 +509,7 @@
         deterministic=False,
         return_softmax=False,
         inner_ring_size=None,
+        tp_group=None,
         **kwargs,
     ):
         """
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index 6be75a3c6..43b3d2c5e 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -866,11 +866,11 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
                 key,
                 value,
                 sp_group,
-                tp_group=tp_group,
                 **attention_mask,
                 dropout_p=dropout_p,
                 scale=scale,
                 inner_ring_size=shard_config.inner_ring_size,
+                tp_group=tp_group,
             )
         else:
             attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 08f4bc90d..aa761fd21 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -571,9 +571,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
             key_states,
             value_states,
             sp_group,
-            tp_group=tp_group,
             **attention_mask,
             inner_ring_size=shard_config.inner_ring_size,
+            tp_group=tp_group,
         )
 
     elif shard_config.enable_flash_attention:
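
Note on why the reorder is safe: the gpt2.py and llama.py call sites already pass tp_group by keyword (tp_group=tp_group), so moving the defaulted tp_group parameter to the tail of RingAttention.attention's signature cannot change how any call binds; it only makes the declared order match the trailing position used at the call sites. A minimal, self-contained sketch of that property follows; attention_stub is a hypothetical stand-in whose parameter names are taken from the diff above, not the real ColossalAI API.

    # Toy stand-in; not the actual colossalai RingAttention.attention signature.
    def attention_stub(q, k, v, sp_group, attention_mask_type,
                       inner_ring_size=None, tp_group=None, **kwargs):
        # Keyword arguments bind by name, so a caller writing tp_group=...
        # is unaffected by where the parameter sits in the signature.
        return inner_ring_size, tp_group

    # A keyword call like the updated gpt2.py / llama.py call sites resolves
    # identically before and after the parameter reorder:
    assert attention_stub("q", "k", "v", "sp", "causal",
                          inner_ring_size=2, tp_group="tp") == (2, "tp")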