Revert "[sync] sync feature/shardformer with develop"

This commit is contained in:
Frank Lee
2023-06-09 09:41:27 +08:00
committed by GitHub
parent 24651fdd4f
commit ddcf58cacf
48 changed files with 445 additions and 3876 deletions

View File

@@ -1,19 +1,20 @@
import torch
from colossalai.device.device_mesh import DeviceMesh
import torch
def test_device_mesh():
physical_mesh_id = torch.arange(0, 16)
physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
mesh_shape = (4, 4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7],
# [8, 9, 10,11],
# [12,13,14,15]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
assert device_mesh.global_rank_to_local_rank(5) == [1, 1]
assert device_mesh.global_rank_to_local_rank(11) == [2, 3]
assert device_mesh.get_ranks_in_process_group(axis=1, global_rank=2) == [0, 1, 2, 3]
assert device_mesh.convert_map[5] == [1, 1]
assert device_mesh.convert_map[11] == [2, 3]
assert device_mesh.global_rank_to_process_groups_with_logical_rank(0)[0] == [[0, 0], [1, 0], [2, 0], [3, 0]]
assert device_mesh.global_rank_to_process_groups_with_logical_rank(2)[1] == [[0, 0], [0, 1], [0, 2], [0, 3]]
assert device_mesh.global_rank_to_process_groups_with_global_rank(2)[1] == [0, 1, 2, 3]
if __name__ == '__main__':