diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index c49113a24..1f39a44b3 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -623,7 +623,6 @@ class RingAttention(torch.autograd.Function):
         is_packed: Optional[bool] = False,
         inner_ring_group: Optional[dist.ProcessGroup] = None,
         inter_ring_group: Optional[dist.ProcessGroup] = None,
-        tp_group: Optional[dist.ProcessGroup] = None,
     ):
 
         cu_seqlens_q = cu_seqlens_kv = cu_seqlens
@@ -1120,7 +1119,7 @@ class RingAttention(torch.autograd.Function):
         if not is_packed:
             dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)]
 
-        return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None)
+        return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None)
 
     @staticmethod
     def prepare_varlen_batch(
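For context on why the two hunks go together: `torch.autograd.Function.backward` must return one gradient (or `None`) per argument passed to `forward` (excluding `ctx`), so removing the unused `tp_group` parameter from `RingAttention.forward` also requires dropping one trailing `None` from the backward return tuple. A minimal sketch of that contract, using a hypothetical `Scale` function unrelated to the patched code:

```python
import torch


class Scale(torch.autograd.Function):
    """Toy example of the forward/backward arity contract."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, factor: float):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        # Two forward args (x, factor) -> two return values; the non-tensor
        # arg gets None. Removing a forward arg means removing its None here.
        return grad_out * ctx.factor, None


out = Scale.apply(torch.ones(2, requires_grad=True), 3.0)
out.sum().backward()
```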