[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00
Committed by: GitHub
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions
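
The formatting churn in the diffs below is mechanical: it is the kind of rewriting a formatter such as black performs when the hooks are run over the whole tree. As a rough sketch of the kind of configuration the commit message describes, and not the repository's actual .pre-commit-config.yaml (the hook repositories, revisions, and the CUDA exclude pattern below are assumptions), a black/isort/clang-format setup that skips CUDA sources might look like this:

# Hypothetical .pre-commit-config.yaml sketch; revisions and patterns are illustrative only.
repos:
  - repo: https://github.com/PyCQA/isort
    rev: 5.12.0
    hooks:
      - id: isort

  - repo: https://github.com/psf/black
    rev: 23.9.1
    hooks:
      - id: black

  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v16.0.6
    hooks:
      - id: clang-format
        # "[misc] ignore cuda for clang-format": leave .cu/.cuh sources untouched
        exclude: \.cuh?$

With such a config, pre-commit run --all-files rewrites every tracked file in one pass, which is where the single-to-double quote and call-wrapping changes in the hunks that follow come from.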

View File

@@ -20,10 +20,9 @@ def check_all_gather(process_groups_dict, rank):
     tensor_to_check = torch.cat((torch.ones(2, 2), torch.zeros(2, 2)), 1).cuda()

     # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1)
-    comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
-                         process_groups_dict,
-                         gather_dim=1,
-                         logical_process_axis=1)
+    comm_spec = CommSpec(
+        CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, process_groups_dict, gather_dim=1, logical_process_axis=1
+    )
     sharded_tensor_to_comm = sharded_tensor_to_comm = comm_spec.covert_spec_to_action(sharded_tensor_to_comm)

     assert sharded_tensor_to_comm.equal(tensor_to_check)
@@ -38,10 +37,9 @@ def check_shard(process_groups_dict, rank):
     tensor_to_shard = torch.cat((sharded_tensor_to_comm_0, sharded_tensor_to_comm_1), 1)

     # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1)
-    comm_spec = CommSpec(CollectiveCommPattern.SPLIT_FWD_GATHER_BWD,
-                         process_groups_dict,
-                         shard_dim=1,
-                         logical_process_axis=1)
+    comm_spec = CommSpec(
+        CollectiveCommPattern.SPLIT_FWD_GATHER_BWD, process_groups_dict, shard_dim=1, logical_process_axis=1
+    )
     tensor_to_shard = comm_spec.covert_spec_to_action(tensor_to_shard)

     if rank in (0, 2):
@@ -79,11 +77,13 @@ def check_all_to_all(process_groups_dict, rank):
     tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda()

     # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1)
-    comm_spec = CommSpec(CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD,
-                         process_groups_dict,
-                         gather_dim=0,
-                         shard_dim=1,
-                         logical_process_axis=0)
+    comm_spec = CommSpec(
+        CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD,
+        process_groups_dict,
+        gather_dim=0,
+        shard_dim=1,
+        logical_process_axis=0,
+    )
     tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)

     assert tensor_to_comm.equal(tensor_to_check)
@@ -124,7 +124,7 @@ def check_all_reduce_bwd(process_groups_dict, rank):

 def check_comm(rank, world_size, port):
     disable_existing_loggers()
-    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

     physical_mesh_id = torch.arange(0, 4)
     assert rank == dist.get_rank()
@@ -157,5 +157,5 @@ def test_comm_spec():
     spawn(check_comm, world_size)


-if __name__ == '__main__':
+if __name__ == "__main__":
     test_comm_spec()

View File

@@ -8,7 +8,6 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn


 class TestModel(torch.nn.Module):
-
     def __init__(self, in_features, out_features):
         super().__init__()
         self.linear_1 = torch.nn.Linear(in_features, out_features)
@@ -22,9 +21,9 @@ class TestModel(torch.nn.Module):

 def check_dtensor(rank, world_size, port):
     disable_existing_loggers()
-    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    test_model = TestModel(8, 8).to('cuda')
-    original_tensor = torch.rand(4, 8).to('cuda')
+    launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    test_model = TestModel(8, 8).to("cuda")
+    original_tensor = torch.rand(4, 8).to("cuda")
     compare_output = test_model(original_tensor)

     device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
@@ -39,7 +38,7 @@ def check_dtensor(rank, world_size, port):
     elif rank in (2, 3):
         assert d_tensor.equal(original_tensor.narrow(0, 2, 2))
     else:
-        raise ValueError(f'rank {rank} is not in the device mesh')
+        raise ValueError(f"rank {rank} is not in the device mesh")

     assert to_global(d_tensor).equal(original_tensor)
     output = test_model(d_tensor)
@@ -48,7 +47,7 @@ def check_dtensor(rank, world_size, port):
     elif rank in (2, 3):
         assert output.equal(compare_output.narrow(0, 2, 2))
     else:
-        raise ValueError(f'rank {rank} is not in the device mesh')
+        raise ValueError(f"rank {rank} is not in the device mesh")

     new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
     d_tensor = redistribute(d_tensor, device_mesh, new_sharding_spec)
@@ -62,7 +61,7 @@ def check_dtensor(rank, world_size, port):
     elif rank == 3:
         assert d_tensor.equal(original_tensor.narrow(0, 3, 1))
     else:
-        raise ValueError(f'rank {rank} is not in the device mesh')
+        raise ValueError(f"rank {rank} is not in the device mesh")

     dtensor_from_local = distribute_tensor(original_tensor, device_mesh, new_sharding_spec)

@@ -75,7 +74,7 @@ def check_dtensor(rank, world_size, port):
     elif rank == 3:
         assert dtensor_from_local.equal(original_tensor.narrow(0, 3, 1))
     else:
-        raise ValueError(f'rank {rank} is not in the device mesh')
+        raise ValueError(f"rank {rank} is not in the device mesh")


 @rerun_if_address_is_in_use()
@@ -84,5 +83,5 @@ def test_dtensor():
     spawn(check_dtensor, world_size)


-if __name__ == '__main__':
+if __name__ == "__main__":
     test_dtensor()

View File

@@ -26,9 +26,10 @@ def test_dtensor_sharding_spec():
     assert dim_spec_list_0[2].dim_diff(dim_spec_list_1[2]) == 0
     assert dim_spec_list_0[3].dim_diff(dim_spec_list_1[3]) == 0

-    assert sharding_spec_0.spec_diff(sharding_spec_1) == \
-        reduce(operator.add, [dim_spec_list_0[i].dim_diff(dim_spec_list_1[i]) for i in range(dims)], 0)
+    assert sharding_spec_0.spec_diff(sharding_spec_1) == reduce(
+        operator.add, [dim_spec_list_0[i].dim_diff(dim_spec_list_1[i]) for i in range(dims)], 0
+    )


-if __name__ == '__main__':
+if __name__ == "__main__":
     test_dtensor_sharding_spec()

View File

@@ -20,7 +20,7 @@ mesh_shape = (2, 2)

 def check_one_step_transform(rank, world_size, port):
     disable_existing_loggers()
-    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     # [[0, 1],
     # [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
@@ -34,10 +34,10 @@ def check_one_step_transform(rank, world_size, port):

     rst_dict = layout_converter.all_gather_transform_layouts(layout)

-    assert '[R, S1, R]' in [
+    assert "[R, S1, R]" in [
         str(all_gather_layout.sharding_spec.sharding_sequence) for all_gather_layout in rst_dict.keys()
     ]
-    assert '[S0, R, R]' in [
+    assert "[S0, R, R]" in [
         str(all_gather_layout.sharding_spec.sharding_sequence) for all_gather_layout in rst_dict.keys()
     ]

@@ -50,13 +50,13 @@ def check_one_step_transform(rank, world_size, port):

     rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all)

-    assert '[S01, R, R]' in [
+    assert "[S01, R, R]" in [
         str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys()
     ]
-    assert '[R, S1, S0]' in [
+    assert "[R, S1, S0]" in [
         str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys()
     ]
-    assert '[S0, R, S1]' in [
+    assert "[S0, R, S1]" in [
         str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys()
     ]

@@ -69,20 +69,20 @@ def check_one_step_transform(rank, world_size, port):

     rst_dict_shard = layout_converter.shard_transform_layout(shard_layout)

-    assert '[S01, R, R]' in [
+    assert "[S01, R, R]" in [
         str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys()
     ]
-    assert '[S0, S1, R]' in [
+    assert "[S0, S1, R]" in [
         str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys()
     ]
-    assert '[S0, R, S1]' in [
+    assert "[S0, R, S1]" in [
         str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys()
     ]


 def check_layout_converting(rank, world_size, port):
     disable_existing_loggers()
-    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     dim_partition_source = {1: [0, 1]}
     dim_partition_target = {0: [0, 1]}
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
@@ -102,8 +102,8 @@ def check_layout_converting(rank, world_size, port):
     transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)

     # check transform path
-    transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path])
-    assert transform_path_str == '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]'
+    transform_path_str = "->".join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path])
+    assert transform_path_str == "[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]"

     # check comm action sequence
     # all-gather(S01) -> S0
@@ -123,18 +123,18 @@ def check_layout_converting(rank, world_size, port):
     assert comm_action_sequence[2].logical_process_axis == 1

     # checkout chached_spec_pairs_transform_path
-    assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][0] == transform_path
-    assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence
+    assert layout_converter.cached_solution[("[R, S01, R]", "[S01, R, R]")][0] == transform_path
+    assert layout_converter.cached_solution[("[R, S01, R]", "[S01, R, R]")][1] == comm_action_sequence

     comm_cost = layout_converter.get_total_comm_cost(source_layout, target_layout)

-    assert comm_cost['forward'] == comm_cost['backward']
-    assert math.floor(comm_cost['total']) == math.floor(comm_cost['forward'] + comm_cost['backward'])
+    assert comm_cost["forward"] == comm_cost["backward"]
+    assert math.floor(comm_cost["total"]) == math.floor(comm_cost["forward"] + comm_cost["backward"])


 def check_layout_converting_apply(rank, world_size, port):
     disable_existing_loggers()
-    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     dim_partition_source = {1: [0, 1]}
     dim_partition_target = {0: [0, 1]}

@@ -173,5 +173,5 @@ def test_layout_converter():
     spawn(check_layout_converting_apply, world_size)


-if __name__ == '__main__':
+if __name__ == "__main__":
     test_layout_converter()