[autoparallel] apply repeat block to reduce solving time (#2912)

This commit is contained in:
YuliangLiu0306
2023-02-28 11:03:30 +08:00
committed by GitHub
parent a848091141
commit 197d0bf4ed
6 changed files with 57 additions and 28 deletions

View File

@@ -112,11 +112,13 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc
This method is used to solve the best solution for the given graph.
The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
'''
graph_analyser = GraphAnalyser(gm)
liveness_list = graph_analyser.liveness_analysis()
# temporarily we use all nodes as liveness list, we count the backward memory cost together with
# forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
# graph_analyser = GraphAnalyser(gm)
# liveness_list = graph_analyser.liveness_analysis()
cost_graph = CostGraph(strategy_constructor.leaf_strategies)
cost_graph.simplify_graph()
solver = Solver(gm.graph, strategy_constructor, cost_graph, graph_analyser, memory_budget=memory_budget)
solver = Solver(gm.graph, strategy_constructor, cost_graph, memory_budget=memory_budget)
ret = solver.call_solver_serialized_args()
solution = list(ret[0])