mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-27 15:57:16 +00:00
fix
This commit is contained in:
parent
b635dd0669
commit
f98384aef6
@ -468,27 +468,29 @@ class RingAttention(torch.autograd.Function):
|
|||||||
f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!",
|
f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!",
|
||||||
ranks=[0],
|
ranks=[0],
|
||||||
)
|
)
|
||||||
|
num_rings = sp_size // inner_ring_size
|
||||||
inner_ring_group = None
|
inner_ring_group = None
|
||||||
inter_ring_group = None
|
inter_ring_group = None
|
||||||
|
|
||||||
world_size = dist.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
|
|
||||||
inner_rings = world_size // sp_size
|
num_ring_size = world_size // num_rings
|
||||||
num_rings = sp_size // inner_ring_size
|
num_inner_group = num_ring_size // inner_ring_size
|
||||||
|
|
||||||
if tp_size > 1:
|
if tp_size > 1:
|
||||||
for i in range(inner_rings):
|
# inner_ring_size = 2
|
||||||
for j in range(sp_size // tp_size):
|
for i in range(num_rings):
|
||||||
# find inner ring group in one sp group
|
for j in range(num_inner_group):
|
||||||
ranks = list(range(j + i * sp_size, j + (i + 1) * sp_size, tp_size))
|
# find inner ring group in one sp groups
|
||||||
|
ranks = list(range(j + i * num_ring_size, j + (i + 1) * num_ring_size, tp_size))
|
||||||
group = dist.new_group(ranks)
|
group = dist.new_group(ranks)
|
||||||
if rank in ranks:
|
if rank in ranks:
|
||||||
inner_ring_group = group
|
inner_ring_group = group
|
||||||
for i in range(inner_rings):
|
for i in range(num_rings):
|
||||||
for j in range(sp_size // tp_size):
|
for j in range(num_inner_group):
|
||||||
ranks = list(range(j + i * (sp_size // tp_size), inner_rings + (i + 1) * sp_size, sp_size))
|
start = j + (i * num_inner_group)
|
||||||
|
ranks = list(range(start, start + num_ring_size + 1, num_ring_size))
|
||||||
group = dist.new_group(ranks)
|
group = dist.new_group(ranks)
|
||||||
if rank in ranks:
|
if rank in ranks:
|
||||||
inter_ring_group = group
|
inter_ring_group = group
|
||||||
|
Loading…
Reference in New Issue
Block a user