diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 15ad09baa..2cc6f3163 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -500,8 +500,8 @@ class RingAttention(torch.autograd.Function):
         k,
         v,
         sp_group,
-        tp_group: Optional[dist.ProcessGroup],
         attention_mask_type,
+        tp_group=None,
         cu_seqlens=None,
         max_seqlen=None,
         valid_indices=None,
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index 01d47bcd0..6be75a3c6 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -866,7 +866,7 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
                 key,
                 value,
                 sp_group,
-                tp_group,
+                tp_group=tp_group,
                 **attention_mask,
                 dropout_p=dropout_p,
                 scale=scale,
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index b5f505dce..08f4bc90d 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -571,7 +571,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 key_states,
                 value_states,
                 sp_group,
-                tp_group,
+                tp_group=tp_group,
                 **attention_mask,
                 inner_ring_size=shard_config.inner_ring_size,
             )
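
A minimal sketch (not the real RingAttention API) of why the call sites switch to `tp_group=tp_group`: with `tp_group` moved behind `attention_mask_type` and given a default, a positional `tp_group` would now bind to the wrong parameter. The stub below only mirrors the new parameter order from this patch; all names and values are placeholders.

# Stand-in for the updated signature (placeholder, not ColossalAI code):
# tp_group now follows attention_mask_type and defaults to None.
def attention_stub(q, k, v, sp_group, attention_mask_type, tp_group=None, **kwargs):
    return {"attention_mask_type": attention_mask_type, "tp_group": tp_group}


# A positional call in the old argument order, attention_stub(q, k, v, sp_group,
# tp_group, mask_type), would now bind tp_group to attention_mask_type.
# Passing it by keyword keeps the binding explicit and order-independent,
# matching the updated gpt2.py and llama.py call sites above.
out = attention_stub("q", "k", "v", "sp_group", "causal", tp_group=None)
assert out["attention_mask_type"] == "causal" and out["tp_group"] is None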