[autoparallel] add cost graph class (#1481)

* [autoparallel] add cost graph class

* polish code
This commit is contained in:
YuliangLiu0306
2022-08-25 17:19:59 +08:00
committed by GitHub
parent 4b03c25f85
commit 413c053453
6 changed files with 141 additions and 5 deletions

View File

@@ -70,7 +70,9 @@ def test_conv_handler():
sharding_spec = ShardingSpec(device_mesh=device_mesh,
entire_shape=entire_shape,
sharding_sequence=sharding_sequence)
strategies_vector_for_input.append(sharding_spec)
strategy_name = str(sharding_spec.sharding_sequence)
sharding_strategy = ShardingStrategy(name=strategy_name, output_sharding_spec=sharding_spec)
strategies_vector_for_input.append(sharding_strategy)
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
# generate conv strategy

View File

@@ -69,7 +69,9 @@ def test_dot_handler():
sharding_spec = ShardingSpec(device_mesh=device_mesh,
entire_shape=entire_shape,
sharding_sequence=sharding_sequence)
strategies_vector_for_input.append(sharding_spec)
strategy_name = str(sharding_spec.sharding_sequence)
sharding_strategy = ShardingStrategy(name=strategy_name, output_sharding_spec=sharding_spec)
strategies_vector_for_input.append(sharding_strategy)
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
# generate dot strategy