mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-12 13:21:47 +00:00
fix
This commit is contained in:
parent
6705dad41b
commit
3fab92166e
@ -442,8 +442,7 @@ class RingAttention(torch.autograd.Function):
|
|||||||
Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
|
Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
|
||||||
"""
|
"""
|
||||||
sp_size = dist.get_world_size(sp_group)
|
sp_size = dist.get_world_size(sp_group)
|
||||||
dist.get_rank(sp_group)
|
tp_size = dist.get_world_size(tp_group)
|
||||||
dist.get_world_size(tp_group)
|
|
||||||
|
|
||||||
if inner_ring_size is None:
|
if inner_ring_size is None:
|
||||||
if torch.cuda.device_count() >= dist.get_world_size():
|
if torch.cuda.device_count() >= dist.get_world_size():
|
||||||
@ -476,21 +475,38 @@ class RingAttention(torch.autograd.Function):
|
|||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
groups = int(world_size / sp_size)
|
groups = int(world_size / sp_size)
|
||||||
|
|
||||||
# Create inner ring groups
|
if tp_size > 1:
|
||||||
for group_id in range(groups):
|
for group_id in range(groups):
|
||||||
for i in range(inner_ring_size):
|
for i in range(inner_ring_size):
|
||||||
ranks = list(range(i + (group_id * sp_size), (1 + group_id) * sp_size, inner_ring_size))
|
ranks = list(range(i + (group_id * sp_size), (1 + group_id) * sp_size, inner_ring_size))
|
||||||
group = dist.new_group(ranks)
|
group = dist.new_group(ranks)
|
||||||
|
if rank in ranks:
|
||||||
|
inner_ring_group = group
|
||||||
|
for group_id in range(groups):
|
||||||
|
for i in range(num_rings):
|
||||||
|
ranks = list(range(i + group_id * num_rings, world_size, sp_size))
|
||||||
|
group = dist.new_group(ranks)
|
||||||
|
if rank in ranks:
|
||||||
|
inter_ring_group = group
|
||||||
|
else:
|
||||||
|
for i in range(sp_size // 2):
|
||||||
|
ranks = list(range((i) * num_rings, (i + 1) * num_rings, 1))
|
||||||
if rank in ranks:
|
if rank in ranks:
|
||||||
|
print(
|
||||||
|
"rank:",
|
||||||
|
rank,
|
||||||
|
"inner ranks:",
|
||||||
|
ranks,
|
||||||
|
)
|
||||||
|
group = dist.new_group(ranks)
|
||||||
inner_ring_group = group
|
inner_ring_group = group
|
||||||
|
for group_id in range(num_rings):
|
||||||
# Create inter ring groups
|
for i in range(num_rings):
|
||||||
for group_id in range(groups):
|
ranks = list(range(i + group_id * num_rings, world_size, inner_ring_size))
|
||||||
for i in range(num_rings):
|
ranks = [0, 1, 4, 5] if rank == 0 or rank == 1 or rank == 4 or rank == 5 else [2, 3, 6, 7]
|
||||||
ranks = list(range(i + group_id * num_rings, world_size, sp_size))
|
if rank in ranks:
|
||||||
group = dist.new_group(ranks)
|
group = dist.new_group(ranks)
|
||||||
if rank in ranks:
|
inter_ring_group = group
|
||||||
inter_ring_group = group
|
|
||||||
|
|
||||||
return inner_ring_group, inter_ring_group
|
return inner_ring_group, inter_ring_group
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user