From 0002ae5956bfbab15810ce95ea8db70865fb80f1 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Fri, 11 Oct 2024 14:16:21 +0800 Subject: [PATCH] fix --- colossalai/shardformer/layer/attn.py | 37 +++++++++++++++++----------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 7c25bee1a..7af8271eb 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -476,24 +476,31 @@ class RingAttention(torch.autograd.Function): rank = dist.get_rank() num_ring_size = world_size // num_rings - num_inner_group = num_ring_size // inner_ring_size if tp_size > 1: + ranks = [] for i in range(num_rings): - for j in range(num_inner_group): - # find inner ring group in one sp groups - start = j + i * num_ring_size - ranks = list(range(start, start + tp_size * inner_ring_size, tp_size)) - group = dist.new_group(ranks) - if rank in ranks: - inner_ring_group = group - for i in range(num_rings): - for j in range(num_inner_group): - start = j + (i * num_inner_group) - ranks = list(range(start, start + num_ring_size + 1, num_ring_size)) - group = dist.new_group(ranks) - if rank in ranks: - inter_ring_group = group + start = i * num_ring_size + end = (i + 1) * num_ring_size + for idx in range(start, end): + inner_rank = [] + for k in range(inner_ring_size): + current_num = idx + k * tp_size + if current_num >= end: + break + inner_rank.append(current_num) + if len(inner_rank) == inner_ring_size and inner_rank not in ranks: + ranks.append(inner_rank) + group = dist.new_group(inner_rank) + if rank in inner_rank: + inner_ring_group = group + + for i in range(num_ring_size): + inter_rank = [i + j * num_ring_size for j in range(num_rings)] + group = dist.new_group(inter_rank) + if rank in inter_rank: + inter_ring_group = group + else: # Create inner ring groups for i in range(inner_ring_size):