diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 7af8271eb..411b4dbcc 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -478,6 +478,7 @@ class RingAttention(torch.autograd.Function): num_ring_size = world_size // num_rings if tp_size > 1: + # Create inner ring groups ranks = [] for i in range(num_rings): start = i * num_ring_size @@ -494,7 +495,7 @@ class RingAttention(torch.autograd.Function): group = dist.new_group(inner_rank) if rank in inner_rank: inner_ring_group = group - + # Create inter ring groups for i in range(num_ring_size): inter_rank = [i + j * num_ring_size for j in range(num_rings)] group = dist.new_group(inter_rank)