This commit is contained in:
wangbluo 2024-10-10 15:17:06 +08:00
parent b635dd0669
commit f98384aef6

View File

@@ -468,27 +468,29 @@ class RingAttention(torch.autograd.Function):
f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!", f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!",
ranks=[0], ranks=[0],
) )
num_rings = sp_size // inner_ring_size
inner_ring_group = None inner_ring_group = None
inter_ring_group = None inter_ring_group = None
world_size = dist.get_world_size() world_size = dist.get_world_size()
rank = dist.get_rank() rank = dist.get_rank()
inner_rings = world_size // sp_size num_ring_size = world_size // num_rings
num_rings = sp_size // inner_ring_size num_inner_group = num_ring_size // inner_ring_size
if tp_size > 1: if tp_size > 1:
for i in range(inner_rings): # inner_ring_size = 2
for j in range(sp_size // tp_size): for i in range(num_rings):
# find inner ring group in one sp group for j in range(num_inner_group):
ranks = list(range(j + i * sp_size, j + (i + 1) * sp_size, tp_size)) # find inner ring group in one sp groups
ranks = list(range(j + i * num_ring_size, j + (i + 1) * num_ring_size, tp_size))
group = dist.new_group(ranks) group = dist.new_group(ranks)
if rank in ranks: if rank in ranks:
inner_ring_group = group inner_ring_group = group
for i in range(inner_rings): for i in range(num_rings):
for j in range(sp_size // tp_size): for j in range(num_inner_group):
ranks = list(range(j + i * (sp_size // tp_size), inner_rings + (i + 1) * sp_size, sp_size)) start = j + (i * num_inner_group)
ranks = list(range(start, start + num_ring_size + 1, num_ring_size))
group = dist.new_group(ranks) group = dist.new_group(ranks)
if rank in ranks: if rank in ranks:
inter_ring_group = group inter_ring_group = group