[autoparallel] apply repeat block to reduce solving time (#2912)

Author: YuliangLiu0306 (committed via GitHub)
Date: 2023-02-28 11:03:30 +08:00
Parent: a848091141
Commit: 197d0bf4ed
6 changed files with 57 additions and 28 deletions
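The ILP formulation behind the auto-parallel solver creates one vector of binary variables per node and per edge, so solving time grows quickly with graph size. Transformer-style models repeat the same block many times, and this commit exploits that: nodes belonging to a later repeat of a detected block alias the LP variables of the matching node in the first block, so the solver only has to decide each repeated block once. The snippet below is a minimal sketch of that aliasing idea in PuLP; the strategy counts and the s_alias map are invented for illustration and are not taken from the commit.

    from pulp import LpVariable

    strategies_len = [4, 4, 4, 4]    # hypothetical strategy counts for four nodes
    s_alias = {2: 0, 3: 1}           # nodes 2 and 3 repeat nodes 0 and 1

    s = []
    num_nodes = 0
    for i, n in enumerate(strategies_len):
        if i in s_alias:
            # aliased node: reuse the variable vector of its counterpart
            s.append(s[s_alias[i]])
        else:
            num_nodes += 1
            s.append(LpVariable.matrix(f"s[{i}]", (range(n),), cat="Binary"))

    print(num_nodes)    # 2 distinct variable vectors instead of 4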


@@ -112,11 +112,13 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc
     This method is used to solve the best solution for the given graph.
     The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
     '''
-    graph_analyser = GraphAnalyser(gm)
-    liveness_list = graph_analyser.liveness_analysis()
+    # temporarily we use all nodes as liveness list, we count the backward memory cost together with
+    # forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
+    # graph_analyser = GraphAnalyser(gm)
+    # liveness_list = graph_analyser.liveness_analysis()
     cost_graph = CostGraph(strategy_constructor.leaf_strategies)
     cost_graph.simplify_graph()
-    solver = Solver(gm.graph, strategy_constructor, cost_graph, graph_analyser, memory_budget=memory_budget)
+    solver = Solver(gm.graph, strategy_constructor, cost_graph, memory_budget=memory_budget)
     ret = solver.call_solver_serialized_args()
     solution = list(ret[0])


@@ -32,7 +32,7 @@ class Solver:
                  graph: Graph,
                  strategies_constructor: StrategiesConstructor,
                  cost_graph: CostGraph,
-                 graph_analyser: GraphAnalyser,
+                 graph_analyser: GraphAnalyser = None,
                  memory_budget: float = -1.0,
                  solution_numbers: int = 1,
                  forward_only: bool = False,
@@ -63,7 +63,10 @@ class Solver:
             self.memory_increasing_coefficient = memory_increasing_coefficient
         else:
             self.memory_increasing_coefficient = 1
-        self.liveness_list = self.graph_analyser.liveness_analysis()
+        # temporarily we use all nodes as liveness list, we count the backward memory cost together with
+        # forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
+        # self.liveness_list = self.graph_analyser.liveness_analysis()
+        self.liveness_list = self.nodes
         self.node_index_dict = self._generate_node_index_dict()
         # The last solution vector of auto sharding.
         self.last_s_val = None
@@ -140,7 +143,7 @@ class Solver:
         liveness_set = self.liveness_list
         # omit alias_set now
-        alias_set = None
+        alias_set = self.strategies_constructor.alias_set
         alias_convert_costs = None
         # prepare compute_costs, communication_costs and memory_costs
@@ -230,6 +233,7 @@ class Solver:
         # 0. Unpack flatten numpy arrays
         s_follow = following_nodes
+        s_alias = alias_set
         E = edge_pairs.reshape((-1, 2))    # noqa
         r = []
@@ -294,8 +298,11 @@ class Solver:
                 if strategies_len[i] == 1:
                     s.append([1])
                 else:
-                    num_nodes += 1
-                    s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
+                    if i not in s_alias:
+                        num_nodes += 1
+                        s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
+                    else:
+                        s.append(s[s_alias[i]])
             else:
                 if s_follow[i] < len(s):
                     s.append(s[s_follow[i]])
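Only the first occurrence of a repeated block contributes fresh s variables here; an aliased node reuses the vector of its counterpart, so num_nodes, and with it the number of binary strategy variables, stops scaling with the number of repeats. As a purely illustrative estimate (not a measurement from this commit): a graph with 12 identical blocks of 20 nodes and roughly 8 strategies per node would otherwise need 12 * 20 * 8 = 1920 strategy binaries, but needs only about 20 * 8 = 160 once the later blocks are aliased.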
@@ -311,15 +318,20 @@ class Solver:
         #############################
         e = []
         num_edges = 0
+        map_edge_to_idx = {}
         for (idx, (i, j)) in enumerate(E):
             if len(s[i]) == 1:
                 e.append(s[j])
             elif len(s[j]) == 1:
                 e.append(s[i])
             else:
-                num_edges += 1
-                e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
+                if i in s_alias and j in s_alias and (s_alias[i], s_alias[j]) in map_edge_to_idx:
+                    e.append(e[map_edge_to_idx[(s_alias[i], s_alias[j])]])
+                else:
+                    num_edges += 1
+                    e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
             assert len(e[idx]) == len(r[idx])
+            map_edge_to_idx[(i, j)] = idx
         for element in s:
             assert len(element) > 0
         # 2. Set initial value
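The pairwise e variables get the same treatment: map_edge_to_idx records which edge first materialized the variables for a given pair of aliased endpoints, and later edges between corresponding nodes of repeated blocks reuse that vector. The saving is larger than for node variables because an edge between nodes with S_i and S_j strategies needs S_i * S_j binaries; with, say, 8 strategies on each side that is 64 binaries per unique edge, now created once rather than once per repeat (the figures are illustrative). The assert against len(r[idx]) still runs for every edge, aliased or not.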
@@ -371,13 +383,12 @@ class Solver:
         # compute memory consumption with liveness set #
         #################################################
         if memory_budget > 0:
-            for liveness_stage in liveness_set:
-                mem = 0
-                for live_variable in liveness_stage.unique_live_vars:
-                    if live_variable.node not in self.node_index_dict:
-                        continue
-                    node_index = self.node_index_dict[live_variable.node]
-                    mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
+            mem = 0
+            for node in liveness_set:
+                if node not in self.node_index_dict:
+                    continue
+                node_index = self.node_index_dict[node]
+                mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
             prob += mem <= memory_budget
         # (d). specified by `cat="Binary"`
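Because the liveness analysis is disabled and liveness_set is now just the full node list (self.liveness_list = self.nodes above), the per-stage memory constraints collapse into a single bound of the form

    sum over nodes i and strategies j of  s[i][j] * m[i][j]  <=  memory_budget

where m[i][j] already folds the backward memory cost into the node memory cost, as the commit's comments state. Treating every node as simultaneously live over-approximates the true peak, which the comments flag as a temporary simplification that holds while no activation checkpointing is used in this phase.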


@@ -15,6 +15,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import (
 )
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
 from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
+from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
 from colossalai.device.device_mesh import DeviceMesh
 from ..options import DataloaderOption, SolverOptions
@@ -42,6 +43,7 @@ class StrategiesConstructor:
         self.strategy_map = {}
         self.solver_options = solver_options
         self.no_strategy_nodes = []
+        self.alias_set = None
 
     def remove_duplicated_strategy(self, strategies_vector):
         '''
@@ -59,6 +61,22 @@ class StrategiesConstructor:
         for strategy in remove_list:
             strategies_vector.remove(strategy)
 
+    def generate_alias_set(self):
+        node_list = [strategy_vector.node for strategy_vector in self.leaf_strategies]
+        common_blocks = find_repeat_blocks(node_list, self.root_module, common_length_threshold=10)
+        repeat_block_nums = len(common_blocks)
+        alias_set = {}
+
+        if repeat_block_nums == 0:
+            return alias_set
+
+        for index, common_node in enumerate(common_blocks[0]):
+            for i in range(1, repeat_block_nums):
+                alias_set[node_list.index(common_blocks[i][index])] = node_list.index(common_node)
+
+        return alias_set
+
     def build_strategies_and_cost(self):
         """
         This method is to build the strategy vector for each node in the computation graph.
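To make the shape of alias_set concrete, here is an invented example (ignoring the common_length_threshold for brevity): if find_repeat_blocks reported three matching blocks whose nodes sit at positions [4, 5, 6], [7, 8, 9] and [10, 11, 12] of node_list, the loops above would produce

    alias_set = {7: 4, 8: 5, 9: 6, 10: 4, 11: 5, 12: 6}

so every node of a later block maps to the index of the corresponding node in the first block. This is the dictionary the Solver later reads as s_alias via self.strategies_constructor.alias_set.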
@@ -175,3 +193,6 @@ class StrategiesConstructor:
                 self.leaf_strategies.remove(node.strategies_vector)
             if node in self.strategy_map:
                 self.strategy_map.pop(node)
+
+        alias_set = self.generate_alias_set()
+        self.alias_set = alias_set