mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 02:26:51 +00:00
[autoparallel] apply repeat block to reduce solving time (#2912)
This commit is contained in:
@@ -32,7 +32,7 @@ class Solver:
|
||||
graph: Graph,
|
||||
strategies_constructor: StrategiesConstructor,
|
||||
cost_graph: CostGraph,
|
||||
graph_analyser: GraphAnalyser,
|
||||
graph_analyser: GraphAnalyser = None,
|
||||
memory_budget: float = -1.0,
|
||||
solution_numbers: int = 1,
|
||||
forward_only: bool = False,
|
||||
@@ -63,7 +63,10 @@ class Solver:
|
||||
self.memory_increasing_coefficient = memory_increasing_coefficient
|
||||
else:
|
||||
self.memory_increasing_coefficient = 1
|
||||
self.liveness_list = self.graph_analyser.liveness_analysis()
|
||||
# temporarily we use all nodes as liveness list, we count the backward memory cost together with
|
||||
# forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
|
||||
# self.liveness_list = self.graph_analyser.liveness_analysis()
|
||||
self.liveness_list = self.nodes
|
||||
self.node_index_dict = self._generate_node_index_dict()
|
||||
# The last solution vector of auto sharding.
|
||||
self.last_s_val = None
|
||||
@@ -140,7 +143,7 @@ class Solver:
|
||||
liveness_set = self.liveness_list
|
||||
|
||||
# omit alias_set now
|
||||
alias_set = None
|
||||
alias_set = self.strategies_constructor.alias_set
|
||||
alias_convert_costs = None
|
||||
|
||||
# prepare compute_costs, communication_costs and memory_costs
|
||||
@@ -230,6 +233,7 @@ class Solver:
|
||||
|
||||
# 0. Unpack flatten numpy arrays
|
||||
s_follow = following_nodes
|
||||
s_alias = alias_set
|
||||
|
||||
E = edge_pairs.reshape((-1, 2)) # noqa
|
||||
r = []
|
||||
@@ -294,8 +298,11 @@ class Solver:
|
||||
if strategies_len[i] == 1:
|
||||
s.append([1])
|
||||
else:
|
||||
num_nodes += 1
|
||||
s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
|
||||
if i not in s_alias:
|
||||
num_nodes += 1
|
||||
s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
|
||||
else:
|
||||
s.append(s[s_alias[i]])
|
||||
else:
|
||||
if s_follow[i] < len(s):
|
||||
s.append(s[s_follow[i]])
|
||||
@@ -311,15 +318,20 @@ class Solver:
|
||||
#############################
|
||||
e = []
|
||||
num_edges = 0
|
||||
map_edge_to_idx = {}
|
||||
for (idx, (i, j)) in enumerate(E):
|
||||
if len(s[i]) == 1:
|
||||
e.append(s[j])
|
||||
elif len(s[j]) == 1:
|
||||
e.append(s[i])
|
||||
else:
|
||||
num_edges += 1
|
||||
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
|
||||
if i in s_alias and j in s_alias and (s_alias[i], s_alias[j]) in map_edge_to_idx:
|
||||
e.append(e[map_edge_to_idx[(s_alias[i], s_alias[j])]])
|
||||
else:
|
||||
num_edges += 1
|
||||
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
|
||||
assert len(e[idx]) == len(r[idx])
|
||||
map_edge_to_idx[(i, j)] = idx
|
||||
for element in s:
|
||||
assert len(element) > 0
|
||||
# 2. Set initial value
|
||||
@@ -371,13 +383,12 @@ class Solver:
|
||||
# compute memory consumption with liveness set #
|
||||
#################################################
|
||||
if memory_budget > 0:
|
||||
for liveness_stage in liveness_set:
|
||||
mem = 0
|
||||
for live_variable in liveness_stage.unique_live_vars:
|
||||
if live_variable.node not in self.node_index_dict:
|
||||
continue
|
||||
node_index = self.node_index_dict[live_variable.node]
|
||||
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
|
||||
mem = 0
|
||||
for node in liveness_set:
|
||||
if node not in self.node_index_dict:
|
||||
continue
|
||||
node_index = self.node_index_dict[node]
|
||||
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
|
||||
prob += mem <= memory_budget
|
||||
|
||||
# (d). specified by `cat="Binary"`
|
||||
|
Reference in New Issue
Block a user