[autoparallel] change the following nodes strategies generation logic (#1636)

* [autoparallel] change the following nodes strategies generation logic

* fix unit test
This commit is contained in:
YuliangLiu0306
2022-09-27 11:20:52 +08:00
committed by GitHub
parent 59f100510a
commit 03978aad45
2 changed files with 9 additions and 8 deletions

View File

@@ -59,7 +59,7 @@ def test_strategies_constructor():
assert strategies_constructor.leaf_strategies[0][0].name == 'Replica Placeholder'
# Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec.
assert strategies_constructor.leaf_strategies[1][0].name == '[R, R, R, R] -> [R, R, R, R]'
assert strategies_constructor.leaf_strategies[1][0].name == '[R, R, R, R] -> [R, R, R, R]_0'
# Third node is conv.
conv_check_list = deepcopy(CONV_STRATEGIES_LIST)
@@ -79,7 +79,7 @@ def test_strategies_constructor():
# Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec.
mul = nodes[1]
assert strategies_constructor.strategy_map[mul][0].name == '[R, R, R, R] -> [R, R, R, R]'
assert strategies_constructor.strategy_map[mul][0].name == '[R, R, R, R] -> [R, R, R, R]_0'
# Third node is conv.
conv = nodes[2]