Revert "[sync] sync feature/shardformer with develop"

This commit is contained in:
Frank Lee
2023-06-09 09:41:27 +08:00
committed by GitHub
parent 24651fdd4f
commit ddcf58cacf
48 changed files with 445 additions and 3876 deletions

View File

@@ -20,12 +20,16 @@ def check_layer(rank, world_size, port):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
logical_pg_dict = {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]}
logical_process_groups = device_mesh.process_groups_dict
for axis in range(len(mesh_shape)):
tensor = torch.ones(4).cuda()
pg = device_mesh.get_process_group(axis=axis)
dist.all_reduce(tensor, op=ReduceOp.SUM, group=pg)
assert tensor.equal(tensor_to_check)
for mesh_dim, pgs in logical_pg_dict.items():
for index, pg in enumerate(pgs):
if rank in pg:
tensor = torch.ones(4).cuda()
group = logical_process_groups[mesh_dim][index][1]
dist.all_reduce(tensor, op=ReduceOp.SUM, group=group)
assert tensor.equal(tensor_to_check)
gpc.destroy()