Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-09 21:09:18 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -2,7 +2,7 @@ import torch
 
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
+from colossalai.tensor.sharding_spec import ShardingSpec
 
 physical_mesh_id = torch.arange(0, 16)
 mesh_shape = (4, 4)
@@ -16,7 +16,6 @@ shape_consistency_manager = ShapeConsistencyManager()
 
 
 def test_one_step_transform():
-
     dim_partition_dict = {0: [0], 1: [1]}
     # DistSpec:
     # shard_sequence: S0,S1,R
@@ -28,16 +27,14 @@ def test_one_step_transform():
     # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0), 0), DistSpec:
     # shard_sequence: S0,R,R
     # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1), 0)}
-    rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, {
-        "forward": 0,
-        "backward": 0,
-        "total": 0
-    })
+    rst_dict = shape_consistency_manager.get_all_all_gather_spec(
+        sharding_spec, {"forward": 0, "backward": 0, "total": 0}
+    )
 
-    assert '[R, S1, R]' in [
+    assert "[R, S1, R]" in [
         str(all_gather_sharding_spec.sharding_sequence) for all_gather_sharding_spec in rst_dict.keys()
     ]
-    assert '[S0, R, R]' in [
+    assert "[S0, R, R]" in [
         str(all_gather_sharding_spec.sharding_sequence) for all_gather_sharding_spec in rst_dict.keys()
     ]
 
@@ -53,19 +50,17 @@ def test_one_step_transform():
     # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:0, shard_dim:2, logical_process_axis: 0), 0), DistSpec:
     # shard_sequence: S0,R,S1
     # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:2, logical_process_axis: 1), 0)}
-    rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec_all2all, {
-        "forward": 0,
-        "backward": 0,
-        "total": 0
-    })
+    rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(
+        sharding_spec_all2all, {"forward": 0, "backward": 0, "total": 0}
+    )
 
-    assert '[S01, R, R]' in [
+    assert "[S01, R, R]" in [
         str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys()
     ]
-    assert '[R, S1, S0]' in [
+    assert "[R, S1, S0]" in [
         str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys()
     ]
-    assert '[S0, R, S1]' in [
+    assert "[S0, R, S1]" in [
         str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys()
     ]
 
@@ -81,19 +76,17 @@ def test_one_step_transform():
     # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1), 0), DistSpec:
     # shard_sequence: S0,R,S1
     # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:2, logical_process_axis:1), 0)}
-    rst_dict_shard = shape_consistency_manager.get_all_shard_spec(sharding_spec_shard, {
-        "forward": 0,
-        "backward": 0,
-        "total": 0
-    })
+    rst_dict_shard = shape_consistency_manager.get_all_shard_spec(
+        sharding_spec_shard, {"forward": 0, "backward": 0, "total": 0}
+    )
 
-    assert '[S01, R, R]' in [
+    assert "[S01, R, R]" in [
         str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys()
     ]
-    assert '[S0, S1, R]' in [
+    assert "[S0, S1, R]" in [
         str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys()
     ]
-    assert '[S0, R, S1]' in [
+    assert "[S0, R, S1]" in [
         str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys()
     ]
 
@@ -113,10 +106,11 @@ def test_shape_consistency():
     sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target)
 
     transform_path, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
-        sharding_spec_source, sharding_spec_target)
+        sharding_spec_source, sharding_spec_target
+    )
 
-    transform_path_str = '->'.join([str(sharding_spec.sharding_sequence) for sharding_spec in transform_path])
-    assert transform_path_str == '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]'
+    transform_path_str = "->".join([str(sharding_spec.sharding_sequence) for sharding_spec in transform_path])
+    assert transform_path_str == "[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]"
 
     # all-gather(S01) -> S0
     assert comm_action_sequence[0].comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
@@ -134,12 +128,15 @@ def test_shape_consistency():
     assert comm_action_sequence[2].shard_dim == 0
     assert comm_action_sequence[2].logical_process_axis == 1
 
-    assert shape_consistency_manager.cached_spec_pairs_transform_path[('[R, S01, R]',
-                                                                       '[S01, R, R]')][0] == transform_path
-    assert shape_consistency_manager.cached_spec_pairs_transform_path[('[R, S01, R]',
-                                                                       '[S01, R, R]')][1] == comm_action_sequence
+    assert (
+        shape_consistency_manager.cached_spec_pairs_transform_path[("[R, S01, R]", "[S01, R, R]")][0] == transform_path
+    )
+    assert (
+        shape_consistency_manager.cached_spec_pairs_transform_path[("[R, S01, R]", "[S01, R, R]")][1]
+        == comm_action_sequence
+    )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_one_step_transform()
     test_shape_consistency()
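For context: the test being reformatted exercises ColossalAI's ShapeConsistencyManager, which computes a path of collective communications (all-gather, all-to-all, shard) that converts a tensor from a source ShardingSpec to a target ShardingSpec on a DeviceMesh. A minimal sketch of that flow follows, mirroring the objects visible in this diff; the entire_shape value, the dim-partition dicts, and the DeviceMesh construction call are illustrative assumptions rather than lines taken from the file.

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec

# 16 physical devices arranged as a logical 4 x 4 mesh, as in the test above.
physical_mesh_id = torch.arange(0, 16)
mesh_shape = (4, 4)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)  # assumed constructor usage

entire_shape = torch.Size((64, 32, 16))  # assumed global tensor shape, for illustration only
shape_consistency_manager = ShapeConsistencyManager()

# A dim-partition dict maps tensor dim -> list of mesh axes it is sharded over:
# {1: [0, 1]} reads as [R, S01, R]; {0: [0, 1]} reads as [S01, R, R].
sharding_spec_source = ShardingSpec(device_mesh, entire_shape, {1: [0, 1]})
sharding_spec_target = ShardingSpec(device_mesh, entire_shape, {0: [0, 1]})

# Returns the intermediate sharding specs, the communication actions, and the total cost.
transform_path, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
    sharding_spec_source, sharding_spec_target
)
print("->".join(str(spec.sharding_sequence) for spec in transform_path))
# Per the assertion in this test, the expected path is:
# [R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]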