mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 13:00:52 +00:00
[autoparallel] update CommSpec (#1667)
This commit is contained in:
@@ -27,7 +27,11 @@ def test_one_step_transform():
|
||||
# device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0), 0), DistSpec:
|
||||
# shard_sequence: S0,R,R
|
||||
# device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1), 0)}
|
||||
rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0)
|
||||
rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, {
|
||||
"forward": 0,
|
||||
"backward": 0,
|
||||
"total": 0
|
||||
})
|
||||
|
||||
assert '[R, S1, R]' in [
|
||||
str(all_gather_sharding_spec.sharding_sequence) for all_gather_sharding_spec in rst_dict.keys()
|
||||
@@ -48,7 +52,11 @@ def test_one_step_transform():
|
||||
# device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:0, shard_dim:2, logical_process_axis: 0), 0), DistSpec:
|
||||
# shard_sequence: S0,R,S1
|
||||
# device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:2, logical_process_axis: 1), 0)}
|
||||
rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec_all2all, 0)
|
||||
rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec_all2all, {
|
||||
"forward": 0,
|
||||
"backward": 0,
|
||||
"total": 0
|
||||
})
|
||||
|
||||
assert '[S01, R, R]' in [
|
||||
str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys()
|
||||
@@ -72,7 +80,11 @@ def test_one_step_transform():
|
||||
# device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1), 0), DistSpec:
|
||||
# shard_sequence: S0,R,S1
|
||||
# device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:2, logical_process_axis:1), 0)}
|
||||
rst_dict_shard = shape_consistency_manager.get_all_shard_spec(sharding_spec_shard, 0)
|
||||
rst_dict_shard = shape_consistency_manager.get_all_shard_spec(sharding_spec_shard, {
|
||||
"forward": 0,
|
||||
"backward": 0,
|
||||
"total": 0
|
||||
})
|
||||
|
||||
assert '[S01, R, R]' in [
|
||||
str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys()
|
||||
|
Reference in New Issue
Block a user