[fx] PoC of runtime shape consistency application (#1607)

* [fx] PoC of runtime shape consistency application

* polish code
This commit is contained in:
YuliangLiu0306
2022-09-20 14:00:04 +08:00
committed by GitHub
parent 47b11c432c
commit 7d1bb71d5d
3 changed files with 197 additions and 1 deletions

View File

@@ -408,7 +408,8 @@ class StrategiesConstructor:
sharding_strategy_attribute = ShardingStrategy(name,
output_sharding_spec,
memory_cost=memory_cost,
resharding_costs=resharding_costs)
resharding_costs=resharding_costs,
input_shardings=tuple(input_sharding_specs))
strategies_vector.append(sharding_strategy_attribute)
self.remove_duplicated_strategy(strategies_vector)