[autoparallel] fix linear logical convert issue (#1857)

This commit is contained in:
YuliangLiu0306
2022-11-10 17:19:22 +08:00
committed by GitHub
parent c2947dadf1
commit 1b494ad73c
4 changed files with 40 additions and 10 deletions

View File

@@ -95,7 +95,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
graph_analyser = GraphAnalyser(gm)
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, verbose=False)
ret = solver.call_solver_serialized_args()
solution = list(ret[0])
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(