Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 03:20:52 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
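The hunks below are formatter output rather than hand edits: string literals move from single quotes to double quotes, and hanging-indent argument lists are collapsed into a packed or trailing-comma layout. That is what a Black-style hook would produce, although the actual hook configuration is not shown in this excerpt; applying the hooks repo-wide is typically a single `pre-commit run --all-files`, which matches the "run all files" wording in the commit title. A minimal sketch of the pattern, using a hypothetical `make_spec` helper rather than the real ColossalAI API:

def make_spec(pattern: str, gather_dim: int, logical_process_axis: int) -> dict:
    """Hypothetical stand-in for a CommSpec-style constructor; only illustrates the reformat."""
    return {"pattern": pattern, "gather_dim": gather_dim, "logical_process_axis": logical_process_axis}


# Before the hook update (hanging indent, single quotes):
# spec = make_spec('allgather',
#                  gather_dim=1,
#                  logical_process_axis=1)

# After running the updated hooks (double quotes, Black-style call layout):
spec = make_spec("allgather", gather_dim=1, logical_process_axis=1)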
@@ -29,10 +29,9 @@ def check_all_gather(device_mesh, rank):
     sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
 
     # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1)
-    comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
-                         sharding_spec,
-                         gather_dim=1,
-                         logical_process_axis=1)
+    comm_spec = CommSpec(
+        CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=1, logical_process_axis=1
+    )
     sharded_tensor_to_comm = sharded_tensor_to_comm = comm_spec.covert_spec_to_action(sharded_tensor_to_comm)
 
     assert sharded_tensor_to_comm.equal(tensor_to_check)
@@ -101,11 +100,9 @@ def check_all_to_all(device_mesh, rank):
     sharding_spec = ShardingSpec(device_mesh, torch.Size((4, 2)), dim_partition_dict=dim_partition_dict)
 
     # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1)
-    comm_spec = CommSpec(CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD,
-                         sharding_spec,
-                         gather_dim=0,
-                         shard_dim=1,
-                         logical_process_axis=0)
+    comm_spec = CommSpec(
+        CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD, sharding_spec, gather_dim=0, shard_dim=1, logical_process_axis=0
+    )
     tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
 
     assert tensor_to_comm.equal(tensor_to_check)
@@ -181,7 +178,7 @@ def check_all_reduce_in_flatten_device_mesh(device_mesh, rank):
 
 def check_comm(rank, world_size, port):
     disable_existing_loggers()
-    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
 
     physical_mesh_id = torch.arange(0, 4)
     assert rank == dist.get_rank()
@@ -214,5 +211,5 @@ def test_comm_spec():
     spawn(check_comm, world_size)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_comm_spec()
@@ -20,10 +20,9 @@ def check_all_gather(process_groups_dict, rank):
     tensor_to_check = torch.cat((torch.ones(2, 2), torch.zeros(2, 2)), 1).cuda()
 
     # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1)
-    comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
-                         process_groups_dict,
-                         gather_dim=1,
-                         logical_process_axis=1)
+    comm_spec = CommSpec(
+        CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, process_groups_dict, gather_dim=1, logical_process_axis=1
+    )
     sharded_tensor_to_comm = sharded_tensor_to_comm = comm_spec.covert_spec_to_action(sharded_tensor_to_comm)
 
     assert sharded_tensor_to_comm.equal(tensor_to_check)
@@ -38,10 +37,9 @@ def check_shard(process_groups_dict, rank):
     tensor_to_shard = torch.cat((sharded_tensor_to_comm_0, sharded_tensor_to_comm_1), 1)
 
     # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1)
-    comm_spec = CommSpec(CollectiveCommPattern.SPLIT_FWD_GATHER_BWD,
-                         process_groups_dict,
-                         shard_dim=1,
-                         logical_process_axis=1)
+    comm_spec = CommSpec(
+        CollectiveCommPattern.SPLIT_FWD_GATHER_BWD, process_groups_dict, shard_dim=1, logical_process_axis=1
+    )
     tensor_to_shard = comm_spec.covert_spec_to_action(tensor_to_shard)
 
     if rank in (0, 2):
@@ -79,11 +77,13 @@ def check_all_to_all(process_groups_dict, rank):
     tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda()
 
     # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1)
-    comm_spec = CommSpec(CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD,
-                         process_groups_dict,
-                         gather_dim=0,
-                         shard_dim=1,
-                         logical_process_axis=0)
+    comm_spec = CommSpec(
+        CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD,
+        process_groups_dict,
+        gather_dim=0,
+        shard_dim=1,
+        logical_process_axis=0,
+    )
     tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
 
     assert tensor_to_comm.equal(tensor_to_check)
@@ -124,7 +124,7 @@ def check_all_reduce_bwd(process_groups_dict, rank):
 
 def check_comm(rank, world_size, port):
     disable_existing_loggers()
-    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
 
     physical_mesh_id = torch.arange(0, 4)
     assert rank == dist.get_rank()
@@ -157,5 +157,5 @@ def test_comm_spec():
     spawn(check_comm, world_size)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_comm_spec()
@@ -8,7 +8,6 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 
 class TestModel(torch.nn.Module):
-
     def __init__(self, in_features, out_features):
         super().__init__()
         self.linear_1 = torch.nn.Linear(in_features, out_features)
@@ -22,9 +21,9 @@ class TestModel(torch.nn.Module):
 
 def check_dtensor(rank, world_size, port):
     disable_existing_loggers()
-    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    test_model = TestModel(8, 8).to('cuda')
-    original_tensor = torch.rand(4, 8).to('cuda')
+    launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    test_model = TestModel(8, 8).to("cuda")
+    original_tensor = torch.rand(4, 8).to("cuda")
     compare_output = test_model(original_tensor)
 
     device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
@@ -39,7 +38,7 @@ def check_dtensor(rank, world_size, port):
     elif rank in (2, 3):
         assert d_tensor.equal(original_tensor.narrow(0, 2, 2))
     else:
-        raise ValueError(f'rank {rank} is not in the device mesh')
+        raise ValueError(f"rank {rank} is not in the device mesh")
     assert to_global(d_tensor).equal(original_tensor)
     output = test_model(d_tensor)
 
@@ -48,7 +47,7 @@ def check_dtensor(rank, world_size, port):
     elif rank in (2, 3):
         assert output.equal(compare_output.narrow(0, 2, 2))
     else:
-        raise ValueError(f'rank {rank} is not in the device mesh')
+        raise ValueError(f"rank {rank} is not in the device mesh")
 
     new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
     d_tensor = redistribute(d_tensor, device_mesh, new_sharding_spec)
@@ -62,7 +61,7 @@ def check_dtensor(rank, world_size, port):
     elif rank == 3:
         assert d_tensor.equal(original_tensor.narrow(0, 3, 1))
     else:
-        raise ValueError(f'rank {rank} is not in the device mesh')
+        raise ValueError(f"rank {rank} is not in the device mesh")
 
     dtensor_from_local = distribute_tensor(original_tensor, device_mesh, new_sharding_spec)
 
@@ -75,7 +74,7 @@ def check_dtensor(rank, world_size, port):
     elif rank == 3:
         assert dtensor_from_local.equal(original_tensor.narrow(0, 3, 1))
    else:
-        raise ValueError(f'rank {rank} is not in the device mesh')
+        raise ValueError(f"rank {rank} is not in the device mesh")
 
 
 @rerun_if_address_is_in_use()
@@ -84,5 +83,5 @@ def test_dtensor():
     spawn(check_dtensor, world_size)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_dtensor()
@@ -26,9 +26,10 @@ def test_dtensor_sharding_spec():
     assert dim_spec_list_0[2].dim_diff(dim_spec_list_1[2]) == 0
     assert dim_spec_list_0[3].dim_diff(dim_spec_list_1[3]) == 0
 
-    assert sharding_spec_0.spec_diff(sharding_spec_1) == \
-           reduce(operator.add, [dim_spec_list_0[i].dim_diff(dim_spec_list_1[i]) for i in range(dims)], 0)
+    assert sharding_spec_0.spec_diff(sharding_spec_1) == reduce(
+        operator.add, [dim_spec_list_0[i].dim_diff(dim_spec_list_1[i]) for i in range(dims)], 0
+    )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_dtensor_sharding_spec()
@@ -20,7 +20,7 @@ mesh_shape = (2, 2)
 
 def check_one_step_transform(rank, world_size, port):
     disable_existing_loggers()
-    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     # [[0, 1],
     # [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
@@ -34,10 +34,10 @@ def check_one_step_transform(rank, world_size, port):
 
     rst_dict = layout_converter.all_gather_transform_layouts(layout)
 
-    assert '[R, S1, R]' in [
+    assert "[R, S1, R]" in [
         str(all_gather_layout.sharding_spec.sharding_sequence) for all_gather_layout in rst_dict.keys()
     ]
-    assert '[S0, R, R]' in [
+    assert "[S0, R, R]" in [
         str(all_gather_layout.sharding_spec.sharding_sequence) for all_gather_layout in rst_dict.keys()
     ]
 
@@ -50,13 +50,13 @@ def check_one_step_transform(rank, world_size, port):
 
     rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all)
 
-    assert '[S01, R, R]' in [
+    assert "[S01, R, R]" in [
         str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys()
     ]
-    assert '[R, S1, S0]' in [
+    assert "[R, S1, S0]" in [
         str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys()
     ]
-    assert '[S0, R, S1]' in [
+    assert "[S0, R, S1]" in [
         str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys()
     ]
 
@@ -69,20 +69,20 @@ def check_one_step_transform(rank, world_size, port):
 
     rst_dict_shard = layout_converter.shard_transform_layout(shard_layout)
 
-    assert '[S01, R, R]' in [
+    assert "[S01, R, R]" in [
         str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys()
     ]
-    assert '[S0, S1, R]' in [
+    assert "[S0, S1, R]" in [
         str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys()
     ]
-    assert '[S0, R, S1]' in [
+    assert "[S0, R, S1]" in [
         str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys()
     ]
 
 
 def check_layout_converting(rank, world_size, port):
     disable_existing_loggers()
-    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     dim_partition_source = {1: [0, 1]}
     dim_partition_target = {0: [0, 1]}
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
@@ -102,8 +102,8 @@ def check_layout_converting(rank, world_size, port):
     transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
 
     # check transform path
-    transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path])
-    assert transform_path_str == '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]'
+    transform_path_str = "->".join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path])
+    assert transform_path_str == "[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]"
 
     # check comm action sequence
     # all-gather(S01) -> S0
@@ -123,18 +123,18 @@ def check_layout_converting(rank, world_size, port):
     assert comm_action_sequence[2].logical_process_axis == 1
 
     # checkout chached_spec_pairs_transform_path
-    assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][0] == transform_path
-    assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence
+    assert layout_converter.cached_solution[("[R, S01, R]", "[S01, R, R]")][0] == transform_path
+    assert layout_converter.cached_solution[("[R, S01, R]", "[S01, R, R]")][1] == comm_action_sequence
 
     comm_cost = layout_converter.get_total_comm_cost(source_layout, target_layout)
 
-    assert comm_cost['forward'] == comm_cost['backward']
-    assert math.floor(comm_cost['total']) == math.floor(comm_cost['forward'] + comm_cost['backward'])
+    assert comm_cost["forward"] == comm_cost["backward"]
+    assert math.floor(comm_cost["total"]) == math.floor(comm_cost["forward"] + comm_cost["backward"])
 
 
 def check_layout_converting_apply(rank, world_size, port):
     disable_existing_loggers()
-    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
 
     dim_partition_source = {1: [0, 1]}
     dim_partition_target = {0: [0, 1]}
@@ -173,5 +173,5 @@ def test_layout_converter():
     spawn(check_layout_converting_apply, world_size)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_layout_converter()
@@ -17,12 +17,13 @@ def check_mix_gather_S0S1(device_mesh, rank):
     f_target_pair = (f, [0])
     b_target_pair = (b, [1])
     gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
-    tensor_slice = [4, 2]    # (4, 2)
+    tensor_slice = [4, 2]  # (4, 2)
     rank_slice = 4
     f_start = (rank // rank_slice) * tensor_slice[0]
     b_start = (rank % rank_slice) * tensor_slice[1]
-    tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0],
-                                     b_start:b_start + tensor_slice[1]].contiguous().cuda()
+    tensor_to_comm = (
+        tensor_to_check[f_start : f_start + tensor_slice[0], b_start : b_start + tensor_slice[1]].contiguous().cuda()
+    )
 
     dim_partition_dict = {0: [0], 1: [1]}
 
@@ -31,12 +32,14 @@ def check_mix_gather_S0S1(device_mesh, rank):
     # device_mesh_shape: (2, 4)
     source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
 
-    comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD,
-                         sharding_spec=source_spec,
-                         gather_dim=gather_dim,
-                         logical_process_axis=logical_process_axes,
-                         forward_only=True,
-                         mix_gather=True)
+    comm_spec = CommSpec(
+        CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD,
+        sharding_spec=source_spec,
+        gather_dim=gather_dim,
+        logical_process_axis=logical_process_axes,
+        forward_only=True,
+        mix_gather=True,
+    )
     tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
 
     assert tensor_to_comm.equal(tensor_to_check)
@@ -48,12 +51,13 @@ def check_two_all_gather_S0S1(device_mesh, rank):
 
     dim_partition_dict = {0: [0], 1: [1]}
 
-    tensor_slice = [tensor_width // 2, tensor_width // 4]    # (4, 2)
+    tensor_slice = [tensor_width // 2, tensor_width // 4]  # (4, 2)
     rank_slice = 4
     f_start = (rank // rank_slice) * tensor_slice[0]
     b_start = (rank % rank_slice) * tensor_slice[1]
-    tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0],
-                                     b_start:b_start + tensor_slice[1]].contiguous().cuda()
+    tensor_to_comm = (
+        tensor_to_check[f_start : f_start + tensor_slice[0], b_start : b_start + tensor_slice[1]].contiguous().cuda()
+    )
 
     # DistSpec:
     # shard_sequence: S0,S1
@@ -61,10 +65,9 @@ def check_two_all_gather_S0S1(device_mesh, rank):
     sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
 
     # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0)
-    comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
-                         sharding_spec,
-                         gather_dim=0,
-                         logical_process_axis=0)
+    comm_spec = CommSpec(
+        CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=0, logical_process_axis=0
+    )
 
     tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
 
@@ -75,10 +78,9 @@ def check_two_all_gather_S0S1(device_mesh, rank):
     sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
 
     # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1)
-    comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
-                         sharding_spec,
-                         gather_dim=1,
-                         logical_process_axis=1)
+    comm_spec = CommSpec(
+        CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=1, logical_process_axis=1
+    )
 
     tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
 
@@ -95,8 +97,9 @@ def check_mix_gather_S1S0(device_mesh, rank):
     rank_slice = 4
     f_start = (rank % rank_slice) * tensor_slice[0]
     b_start = (rank // rank_slice) * tensor_slice[1]
-    tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0],
-                                     b_start:b_start + tensor_slice[1]].contiguous().cuda()
+    tensor_to_comm = (
+        tensor_to_check[f_start : f_start + tensor_slice[0], b_start : b_start + tensor_slice[1]].contiguous().cuda()
+    )
 
     dim_partition_dict = {0: [1], 1: [0]}
 
@@ -105,12 +108,14 @@ def check_mix_gather_S1S0(device_mesh, rank):
     # device_mesh_shape: (2, 4)
     source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
 
-    comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD,
-                         sharding_spec=source_spec,
-                         gather_dim=gather_dim,
-                         logical_process_axis=logical_process_axes,
-                         forward_only=True,
-                         mix_gather=True)
+    comm_spec = CommSpec(
+        CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD,
+        sharding_spec=source_spec,
+        gather_dim=gather_dim,
+        logical_process_axis=logical_process_axes,
+        forward_only=True,
+        mix_gather=True,
+    )
     tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
 
     assert tensor_to_comm.equal(tensor_to_check)
@@ -120,12 +125,13 @@ def check_two_all_gather_S1S0(device_mesh, rank):
     tensor_width = 8
    tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda()
 
-    tensor_slice = [tensor_width // 4, tensor_width // 2]    # (4, 2)
+    tensor_slice = [tensor_width // 4, tensor_width // 2]  # (4, 2)
     rank_slice = 4
     f_start = (rank % rank_slice) * tensor_slice[0]
     b_start = (rank // rank_slice) * tensor_slice[1]
-    tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0],
-                                     b_start:b_start + tensor_slice[1]].contiguous().cuda()
+    tensor_to_comm = (
+        tensor_to_check[f_start : f_start + tensor_slice[0], b_start : b_start + tensor_slice[1]].contiguous().cuda()
+    )
 
     dim_partition_dict = {0: [1], 1: [0]}
 
@@ -135,10 +141,9 @@ def check_two_all_gather_S1S0(device_mesh, rank):
     sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
 
     # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:1)
-    comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
-                         sharding_spec,
-                         gather_dim=0,
-                         logical_process_axis=1)
+    comm_spec = CommSpec(
+        CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=0, logical_process_axis=1
+    )
 
     tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
 
@@ -149,10 +154,9 @@ def check_two_all_gather_S1S0(device_mesh, rank):
     sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
 
     # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:0)
-    comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
-                         sharding_spec,
-                         gather_dim=1,
-                         logical_process_axis=0)
+    comm_spec = CommSpec(
+        CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=1, logical_process_axis=0
+    )
 
     tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
 
@@ -165,7 +169,7 @@ def check_mix_gather_S01R(device_mesh, rank):
     f_target_pair = (f, [0, 1])
     b_target_pair = (b, [])
     gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
-    tensor_to_comm = tensor_to_check[rank:rank + 1, :].contiguous().cuda()
+    tensor_to_comm = tensor_to_check[rank : rank + 1, :].contiguous().cuda()
 
     dim_partition_dict = {0: [0, 1]}
     # DistSpec:
@@ -173,12 +177,14 @@ def check_mix_gather_S01R(device_mesh, rank):
     # device_mesh_shape: (2, 4)
     source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
 
-    comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD,
-                         sharding_spec=source_spec,
-                         gather_dim=gather_dim,
-                         logical_process_axis=logical_process_axes,
-                         forward_only=True,
-                         mix_gather=True)
+    comm_spec = CommSpec(
+        CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD,
+        sharding_spec=source_spec,
+        gather_dim=gather_dim,
+        logical_process_axis=logical_process_axes,
+        forward_only=True,
+        mix_gather=True,
+    )
     tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
 
     assert tensor_to_comm.equal(tensor_to_check)
@@ -189,7 +195,7 @@ def check_two_all_gather_S01R(device_mesh, rank):
     tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda()
 
     rank_stride = tensor_width // 8
-    tensor_to_comm = tensor_to_check[rank:rank + rank_stride, :].contiguous().cuda()
+    tensor_to_comm = tensor_to_check[rank : rank + rank_stride, :].contiguous().cuda()
 
     dim_partition_dict = {0: [0, 1]}
 
@@ -199,10 +205,9 @@ def check_two_all_gather_S01R(device_mesh, rank):
     sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
 
     # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0)
-    comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
-                         sharding_spec,
-                         gather_dim=0,
-                         logical_process_axis=1)
+    comm_spec = CommSpec(
+        CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=0, logical_process_axis=1
+    )
 
     tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
 
@@ -214,10 +219,9 @@ def check_two_all_gather_S01R(device_mesh, rank):
     sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
 
     # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:1)
-    comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
-                         sharding_spec,
-                         gather_dim=0,
-                         logical_process_axis=0)
+    comm_spec = CommSpec(
+        CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=0, logical_process_axis=0
+    )
 
     tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
 
@@ -231,7 +235,7 @@ def check_mix_gather_RS01(device_mesh, rank):
     f_target_pair = (f, [])
     b_target_pair = (b, [0, 1])
     gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
-    tensor_to_comm = tensor_to_check[:, rank:rank + 1].contiguous().cuda()
+    tensor_to_comm = tensor_to_check[:, rank : rank + 1].contiguous().cuda()
 
     dim_partition_dict = {1: [0, 1]}
     # DistSpec:
@@ -239,12 +243,14 @@ def check_mix_gather_RS01(device_mesh, rank):
     # device_mesh_shape: (2, 4)
     source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
 
-    comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD,
-                         sharding_spec=source_spec,
-                         gather_dim=gather_dim,
-                         logical_process_axis=logical_process_axes,
-                         forward_only=True,
-                         mix_gather=True)
+    comm_spec = CommSpec(
+        CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD,
+        sharding_spec=source_spec,
+        gather_dim=gather_dim,
+        logical_process_axis=logical_process_axes,
+        forward_only=True,
+        mix_gather=True,
+    )
     tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
 
     assert tensor_to_comm.equal(tensor_to_check)
@@ -255,7 +261,7 @@ def check_two_all_gather_RS01(device_mesh, rank):
     tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda()
 
     rank_stride = tensor_width // 8
-    tensor_to_comm = tensor_to_check[:, rank:rank + rank_stride].contiguous().cuda()
+    tensor_to_comm = tensor_to_check[:, rank : rank + rank_stride].contiguous().cuda()
 
     dim_partition_dict = {1: [0, 1]}
 
@@ -265,10 +271,9 @@ def check_two_all_gather_RS01(device_mesh, rank):
     sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
 
     # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:0)
-    comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
-                         sharding_spec,
-                         gather_dim=1,
-                         logical_process_axis=1)
+    comm_spec = CommSpec(
+        CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=1, logical_process_axis=1
+    )
 
     tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
 
@@ -280,10 +285,9 @@ def check_two_all_gather_RS01(device_mesh, rank):
     sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
 
     # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1)
-    comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
-                         sharding_spec,
-                         gather_dim=1,
-                         logical_process_axis=0)
+    comm_spec = CommSpec(
+        CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=1, logical_process_axis=0
+    )
 
     tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
 
@@ -292,7 +296,7 @@ def check_two_all_gather_RS01(device_mesh, rank):
 
 def check_comm(rank, world_size, port):
     disable_existing_loggers()
-    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
 
     physical_mesh_id = torch.arange(0, 8)
     assert rank == dist.get_rank()
@@ -326,5 +330,5 @@ def test_mix_gather():
     spawn(check_comm, world_size)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_mix_gather()
@@ -2,7 +2,7 @@ import torch
 
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
+from colossalai.tensor.sharding_spec import ShardingSpec
 
 physical_mesh_id = torch.arange(0, 16)
 mesh_shape = (4, 4)
@@ -16,7 +16,6 @@ shape_consistency_manager = ShapeConsistencyManager()
 
 
 def test_one_step_transform():
-
     dim_partition_dict = {0: [0], 1: [1]}
     # DistSpec:
     # shard_sequence: S0,S1,R
@@ -28,16 +27,14 @@ def test_one_step_transform():
     # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0), 0), DistSpec:
     # shard_sequence: S0,R,R
     # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1), 0)}
-    rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, {
-        "forward": 0,
-        "backward": 0,
-        "total": 0
-    })
+    rst_dict = shape_consistency_manager.get_all_all_gather_spec(
+        sharding_spec, {"forward": 0, "backward": 0, "total": 0}
+    )
 
-    assert '[R, S1, R]' in [
+    assert "[R, S1, R]" in [
         str(all_gather_sharding_spec.sharding_sequence) for all_gather_sharding_spec in rst_dict.keys()
     ]
-    assert '[S0, R, R]' in [
+    assert "[S0, R, R]" in [
         str(all_gather_sharding_spec.sharding_sequence) for all_gather_sharding_spec in rst_dict.keys()
     ]
 
@@ -53,19 +50,17 @@ def test_one_step_transform():
     # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:0, shard_dim:2, logical_process_axis: 0), 0), DistSpec:
     # shard_sequence: S0,R,S1
     # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:2, logical_process_axis: 1), 0)}
-    rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec_all2all, {
-        "forward": 0,
-        "backward": 0,
-        "total": 0
-    })
+    rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(
+        sharding_spec_all2all, {"forward": 0, "backward": 0, "total": 0}
+    )
 
-    assert '[S01, R, R]' in [
+    assert "[S01, R, R]" in [
         str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys()
     ]
-    assert '[R, S1, S0]' in [
+    assert "[R, S1, S0]" in [
        str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys()
     ]
-    assert '[S0, R, S1]' in [
+    assert "[S0, R, S1]" in [
         str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys()
     ]
 
@@ -81,19 +76,17 @@ def test_one_step_transform():
     # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1), 0), DistSpec:
     # shard_sequence: S0,R,S1
     # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:2, logical_process_axis:1), 0)}
-    rst_dict_shard = shape_consistency_manager.get_all_shard_spec(sharding_spec_shard, {
-        "forward": 0,
-        "backward": 0,
-        "total": 0
-    })
+    rst_dict_shard = shape_consistency_manager.get_all_shard_spec(
+        sharding_spec_shard, {"forward": 0, "backward": 0, "total": 0}
+    )
 
-    assert '[S01, R, R]' in [
+    assert "[S01, R, R]" in [
         str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys()
     ]
-    assert '[S0, S1, R]' in [
+    assert "[S0, S1, R]" in [
         str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys()
     ]
-    assert '[S0, R, S1]' in [
+    assert "[S0, R, S1]" in [
         str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys()
     ]
 
@@ -113,10 +106,11 @@ def test_shape_consistency():
     sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target)
 
     transform_path, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
-        sharding_spec_source, sharding_spec_target)
+        sharding_spec_source, sharding_spec_target
+    )
 
-    transform_path_str = '->'.join([str(sharding_spec.sharding_sequence) for sharding_spec in transform_path])
-    assert transform_path_str == '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]'
+    transform_path_str = "->".join([str(sharding_spec.sharding_sequence) for sharding_spec in transform_path])
+    assert transform_path_str == "[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]"
 
     # all-gather(S01) -> S0
     assert comm_action_sequence[0].comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
@@ -134,12 +128,15 @@ def test_shape_consistency():
     assert comm_action_sequence[2].shard_dim == 0
     assert comm_action_sequence[2].logical_process_axis == 1
 
-    assert shape_consistency_manager.cached_spec_pairs_transform_path[('[R, S01, R]',
-                                                                       '[S01, R, R]')][0] == transform_path
-    assert shape_consistency_manager.cached_spec_pairs_transform_path[('[R, S01, R]',
-                                                                       '[S01, R, R]')][1] == comm_action_sequence
+    assert (
+        shape_consistency_manager.cached_spec_pairs_transform_path[("[R, S01, R]", "[S01, R, R]")][0] == transform_path
+    )
+    assert (
+        shape_consistency_manager.cached_spec_pairs_transform_path[("[R, S01, R]", "[S01, R, R]")][1]
+        == comm_action_sequence
+    )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_one_step_transform()
     test_shape_consistency()
@@ -4,14 +4,14 @@ import torch
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 
 def check_apply(rank, world_size, port):
     disable_existing_loggers()
-    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
 
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
@@ -72,5 +72,5 @@ def test_apply():
     spawn(check_apply, world_size)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_apply()
@@ -1,7 +1,7 @@
 import torch
 
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
+from colossalai.tensor.sharding_spec import ShardingSpec
 
 
 def test_sharding_spec():
@@ -21,5 +21,5 @@ def test_sharding_spec():
     assert str(sharding_spec.sharding_sequence) == "[S01, R, R]"
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_sharding_spec()