mirror of https://github.com/hpcaitech/ColossalAI.git
Revert "[sync] sync feature/shardformer with develop"
@@ -1,19 +1,20 @@
 import torch
 
 from colossalai.device.device_mesh import DeviceMesh
 
 
 def test_device_mesh():
-    physical_mesh_id = torch.arange(0, 16)
+    physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
     mesh_shape = (4, 4)
     # [[0, 1, 2, 3],
     # [4, 5, 6, 7],
     # [8, 9, 10,11],
     # [12,13,14,15]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-    assert device_mesh.global_rank_to_local_rank(5) == [1, 1]
-    assert device_mesh.global_rank_to_local_rank(11) == [2, 3]
-    assert device_mesh.get_ranks_in_process_group(axis=1, global_rank=2) == [0, 1, 2, 3]
+    assert device_mesh.convert_map[5] == [1, 1]
+    assert device_mesh.convert_map[11] == [2, 3]
+    assert device_mesh.global_rank_to_process_groups_with_logical_rank(0)[0] == [[0, 0], [1, 0], [2, 0], [3, 0]]
+    assert device_mesh.global_rank_to_process_groups_with_logical_rank(2)[1] == [[0, 0], [0, 1], [0, 2], [0, 3]]
+    assert device_mesh.global_rank_to_process_groups_with_global_rank(2)[1] == [0, 1, 2, 3]
 
 
 if __name__ == '__main__':
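Both sides of this hunk assert the same underlying mapping: a flat global rank corresponds to row-major coordinates on the 4x4 logical mesh. The sketch below is a hypothetical standalone reimplementation (the helper name and the row-major assumption are mine, not ColossalAI's code) showing why rank 5 maps to [1, 1], rank 11 to [2, 3], and why rank 2's group along axis 1 is [0, 1, 2, 3].

import torch

# Hypothetical sketch of the rank <-> coordinate mapping the asserts
# above exercise; not ColossalAI's implementation.
physical_mesh_id = torch.arange(0, 16)
mesh_shape = (4, 4)
logical_mesh = physical_mesh_id.reshape(mesh_shape)

def global_rank_to_local_rank(global_rank):
    # Row-major unravel of a flat rank into [row, col] mesh coordinates.
    row, col = divmod(global_rank, mesh_shape[1])
    return [row, col]

assert global_rank_to_local_rank(5) == [1, 1]
assert global_rank_to_local_rank(11) == [2, 3]

# The process group along axis 1 containing rank 2 is its mesh row:
row = global_rank_to_local_rank(2)[0]
assert logical_mesh[row].tolist() == [0, 1, 2, 3]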
@@ -20,12 +20,16 @@ def check_layer(rank, world_size, port):
     # [[0, 1,
     # [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+    logical_pg_dict = {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]}
+    logical_process_groups = device_mesh.process_groups_dict
 
-    for axis in range(len(mesh_shape)):
-        tensor = torch.ones(4).cuda()
-        pg = device_mesh.get_process_group(axis=axis)
-        dist.all_reduce(tensor, op=ReduceOp.SUM, group=pg)
-        assert tensor.equal(tensor_to_check)
+    for mesh_dim, pgs in logical_pg_dict.items():
+        for index, pg in enumerate(pgs):
+            if rank in pg:
+                tensor = torch.ones(4).cuda()
+                group = logical_process_groups[mesh_dim][index][1]
+                dist.all_reduce(tensor, op=ReduceOp.SUM, group=group)
+                assert tensor.equal(tensor_to_check)
 
     gpc.destroy()
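tensor_to_check is defined above this hunk, so its value is not visible here. A plausible reading, and the assumption behind the sketch below, is the usual invariant such tests rely on: all-reducing torch.ones over a process group of k ranks yields a tensor filled with k, and every axis group on the 2x2 mesh here has two ranks. (Judging from the indexing logical_process_groups[mesh_dim][index][1], process_groups_dict appears to map each mesh axis to a list of (rank list, process group) pairs.)

import torch

# Assumption (tensor_to_check lies outside the diff context): summing
# ones across a k-rank group gives a tensor of k's; k = 2 on a 2x2 mesh.
group_size = 2
tensor = torch.ones(4)
tensor_to_check = torch.ones(4) * group_size

# Simulate dist.all_reduce(tensor, op=ReduceOp.SUM) across the group:
reduced = sum(tensor for _ in range(group_size))
assert reduced.equal(tensor_to_check)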