From e1e86f9f1fdc7a885d9183b19567e440e3e8b34f Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Mon, 14 Oct 2024 11:45:35 +0800
Subject: [PATCH] simplify inner ring rank construction in RingAttention

---
 colossalai/shardformer/layer/attn.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 411b4dbcc..aa013e526 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -484,12 +484,7 @@ class RingAttention(torch.autograd.Function):
             start = i * num_ring_size
             end = (i + 1) * num_ring_size
             for idx in range(start, end):
-                inner_rank = []
-                for k in range(inner_ring_size):
-                    current_num = idx + k * tp_size
-                    if current_num >= end:
-                        break
-                    inner_rank.append(current_num)
+                inner_rank = [idx + k * tp_size for k in range(inner_ring_size) if idx + k * tp_size < end]
                 if len(inner_rank) == inner_ring_size and inner_rank not in ranks:
                     ranks.append(inner_rank)
                     group = dist.new_group(inner_rank)
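
Note: the refactor is behavior-preserving. Since idx + k * tp_size increases monotonically with k (tp_size is positive), the comprehension's < end filter drops exactly the trailing values that the removed loop's break skipped, and the existing len(inner_rank) == inner_ring_size guard still discards the resulting partial groups. The standalone sketch below reproduces the grouping logic around the new comprehension; the values of tp_size, inner_ring_size, and world_size are illustrative assumptions rather than values from the patch, num_ring_size is derived here only for the sketch, and dist.new_group is stubbed out so the snippet runs without initializing torch.distributed.

    # Standalone sketch of the inner-ring grouping after this patch.
    # tp_size, inner_ring_size, world_size are illustrative assumptions;
    # dist.new_group(inner_rank) is replaced by collecting the rank lists
    # so the snippet runs without torch.distributed initialization.
    tp_size = 2          # hypothetical tensor-parallel degree
    inner_ring_size = 2  # ranks per inner ring
    world_size = 8       # hypothetical total number of ranks
    num_ring_size = tp_size * inner_ring_size  # assumed ring-window size for this sketch

    ranks = []
    for i in range(world_size // num_ring_size):
        start = i * num_ring_size
        end = (i + 1) * num_ring_size
        for idx in range(start, end):
            # Same values the removed loop collected: strides of tp_size
            # starting at idx, cut off at end; partial groups are rejected
            # by the length guard below.
            inner_rank = [idx + k * tp_size for k in range(inner_ring_size) if idx + k * tp_size < end]
            if len(inner_rank) == inner_ring_size and inner_rank not in ranks:
                ranks.append(inner_rank)  # real code: group = dist.new_group(inner_rank)

    print(ranks)  # -> [[0, 2], [1, 3], [4, 6], [5, 7]]

With these assumed values, each inner ring strides across ranks by tp_size within its ring window, and the incomplete rank lists produced near end never reach dist.new_group, exactly as before the patch.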