mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 01:55:12 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -21,24 +21,25 @@ try:
|
||||
import pulp
|
||||
from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum
|
||||
except:
|
||||
warnings.warn(f'please install the pulp')
|
||||
warnings.warn(f"please install the pulp")
|
||||
|
||||
__all___ = ['Solver']
|
||||
__all___ = ["Solver"]
|
||||
|
||||
|
||||
class Solver:
|
||||
|
||||
def __init__(self,
|
||||
graph: Graph,
|
||||
strategies_constructor: StrategiesConstructor,
|
||||
cost_graph: CostGraph,
|
||||
graph_analyser: GraphAnalyser = None,
|
||||
memory_budget: float = -1.0,
|
||||
solution_numbers: int = 1,
|
||||
forward_only: bool = False,
|
||||
memory_increasing_coefficient: float = 1.3,
|
||||
verbose=False):
|
||||
'''
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
strategies_constructor: StrategiesConstructor,
|
||||
cost_graph: CostGraph,
|
||||
graph_analyser: GraphAnalyser = None,
|
||||
memory_budget: float = -1.0,
|
||||
solution_numbers: int = 1,
|
||||
forward_only: bool = False,
|
||||
memory_increasing_coefficient: float = 1.3,
|
||||
verbose=False,
|
||||
):
|
||||
"""
|
||||
The Solver class integrates the information provided by the components and uses an ILP solver to find a possibly optimal combination of strategies for the target computing graph.
|
||||
Argument:
|
||||
graph: The computing graph to be optimized.
|
||||
@@ -48,7 +49,7 @@ class Solver:
|
||||
memory_budget: Memory constraint for the solution.
|
||||
solution_numbers: If solution_numbers is larger than one, the solver will use a series of solutions based on different memory budgets.
|
||||
memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate a new memory budget.
|
||||
'''
|
||||
"""
|
||||
self.graph = graph
|
||||
self.strategies_constructor = strategies_constructor
|
||||
self.cost_graph = cost_graph
|
||||
@@ -75,11 +76,11 @@ class Solver:
|
||||
self.verbose = verbose
|
||||
|
||||
def _recover_merged_node_strategy(self):
|
||||
'''
|
||||
"""
|
||||
During cost graph construction, some nodes, such as unary element-wise nodes or ReshapeOp, were merged into the previous node.
|
||||
Therefore, the indices of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged
|
||||
node.
|
||||
'''
|
||||
"""
|
||||
for node_index, node in enumerate(self.nodes):
|
||||
if node.strategies_vector.check_merge():
|
||||
# the merged node has only one input, and its strategies follow the input sharding strategy
|
||||
@@ -98,9 +99,9 @@ class Solver:
|
||||
return node_index_dict
|
||||
|
||||
def _prepare_data_for_solver(self):
|
||||
'''
|
||||
"""
|
||||
Extract information from components for solver.
|
||||
'''
|
||||
"""
|
||||
node_nums = len(self.leaf_strategies)
|
||||
memory_budget = self.memory_budget
|
||||
|
||||
@@ -190,23 +191,40 @@ class Solver:
|
||||
# omit initial value for nodes
|
||||
s_init_np = None
|
||||
|
||||
return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np, self.verbose
|
||||
return (
|
||||
node_nums,
|
||||
memory_budget,
|
||||
strategies_len,
|
||||
following_nodes,
|
||||
edge_pairs,
|
||||
alias_set,
|
||||
liveness_set,
|
||||
compute_costs,
|
||||
communication_costs,
|
||||
memory_costs,
|
||||
resharding_costs,
|
||||
alias_convert_costs,
|
||||
s_init_np,
|
||||
self.verbose,
|
||||
)
|
||||
|
||||
def _call_solver_serialized_args(self,
|
||||
node_nums,
|
||||
memory_budget,
|
||||
strategies_len,
|
||||
following_nodes,
|
||||
edge_pairs,
|
||||
alias_set,
|
||||
liveness_set,
|
||||
compute_costs,
|
||||
communication_costs,
|
||||
memory_costs,
|
||||
resharding_costs,
|
||||
alias_convert_costs,
|
||||
s_init_np=None,
|
||||
verbose=True):
|
||||
def _call_solver_serialized_args(
|
||||
self,
|
||||
node_nums,
|
||||
memory_budget,
|
||||
strategies_len,
|
||||
following_nodes,
|
||||
edge_pairs,
|
||||
alias_set,
|
||||
liveness_set,
|
||||
compute_costs,
|
||||
communication_costs,
|
||||
memory_costs,
|
||||
resharding_costs,
|
||||
alias_convert_costs,
|
||||
s_init_np=None,
|
||||
verbose=True,
|
||||
):
|
||||
"""
|
||||
Call the solver with serialized arguments.
|
||||
"""
|
||||
@@ -235,18 +253,18 @@ class Solver:
|
||||
s_follow = following_nodes
|
||||
s_alias = alias_set
|
||||
|
||||
E = edge_pairs.reshape((-1, 2)) # noqa
|
||||
E = edge_pairs.reshape((-1, 2)) # noqa
|
||||
r = []
|
||||
pt = 0
|
||||
edge_set = set()
|
||||
for (i, j) in E:
|
||||
for i, j in E:
|
||||
prod_length = strategies_len[i] * strategies_len[j]
|
||||
|
||||
if (i, j) in edge_set:
|
||||
raise ValueError(f"Duplicated edges: {(i, j)}")
|
||||
|
||||
edge_set.add((i, j))
|
||||
r.append(resharding_costs[pt:pt + prod_length])
|
||||
r.append(resharding_costs[pt : pt + prod_length])
|
||||
pt += prod_length
|
||||
assert pt == len(resharding_costs)
|
||||
|
||||
@@ -268,7 +286,6 @@ class Solver:
|
||||
# L.append(liveness_set[pt:pt + length])
|
||||
# pt += length
|
||||
# assert pt == len(liveness_set)
|
||||
v = []
|
||||
pt = 0
|
||||
|
||||
c = []
|
||||
@@ -277,9 +294,9 @@ class Solver:
|
||||
pt = 0
|
||||
for i in range(node_nums):
|
||||
length = strategies_len[i]
|
||||
c.append(compute_costs[pt:pt + length])
|
||||
d.append(communication_costs[pt:pt + length])
|
||||
m.append(memory_costs[pt:pt + length])
|
||||
c.append(compute_costs[pt : pt + length])
|
||||
d.append(communication_costs[pt : pt + length])
|
||||
m.append(memory_costs[pt : pt + length])
|
||||
pt += length
|
||||
assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
|
||||
assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}"
|
||||
@@ -319,7 +336,7 @@ class Solver:
|
||||
e = []
|
||||
num_edges = 0
|
||||
map_edge_to_idx = {}
|
||||
for (idx, (i, j)) in enumerate(E):
|
||||
for idx, (i, j) in enumerate(E):
|
||||
if len(s[i]) == 1:
|
||||
e.append(s[j])
|
||||
elif len(s[j]) == 1:
|
||||
@@ -340,7 +357,7 @@ class Solver:
|
||||
######################################
|
||||
if s_init_np is not None:
|
||||
s_init = s_init_np.reshape((-1, 3))
|
||||
for (idx, value, fix) in s_init:
|
||||
for idx, value, fix in s_init:
|
||||
for i in range(len(s[idx])):
|
||||
s[idx][i].setInitialValue(i == value)
|
||||
if fix:
|
||||
@@ -393,7 +410,7 @@ class Solver:
|
||||
|
||||
# (d). specified by `cat="Binary"`
|
||||
|
||||
for (idx, (i, j)) in enumerate(E):
|
||||
for idx, (i, j) in enumerate(E):
|
||||
if strategies_len[i] == 1 or strategies_len[j] == 1:
|
||||
continue
|
||||
|
||||
@@ -402,13 +419,13 @@ class Solver:
|
||||
|
||||
# (f)
|
||||
for row in range(len(s[i])):
|
||||
C = len(s[j]) # noqa
|
||||
C = len(s[j]) # noqa
|
||||
prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row]
|
||||
|
||||
# (g)
|
||||
for col in range(len(s[j])):
|
||||
R = len(s[i]) # noqa
|
||||
C = len(s[j]) # noqa
|
||||
R = len(s[i]) # noqa
|
||||
C = len(s[j]) # noqa
|
||||
prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col]
|
||||
|
||||
# (h)
|
||||
@@ -434,7 +451,8 @@ class Solver:
|
||||
msg = verbose
|
||||
time_limit = 600
|
||||
assert "COIN_CMD" in pulp.listSolvers(
|
||||
onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'")
|
||||
onlyAvailable=True
|
||||
), "Please install ILP solvers by 'sudo apt install coinor-cbc'"
|
||||
|
||||
solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count())
|
||||
# solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit)
|
||||
@@ -444,13 +462,13 @@ class Solver:
|
||||
objective = pulp.value(prob.objective)
|
||||
objective = float(objective) if objective is not None else -1.0
|
||||
if verbose:
|
||||
print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t"
|
||||
f"Time: {time.time() - tic}")
|
||||
print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t" f"Time: {time.time() - tic}")
|
||||
print(f"#nodes: {num_nodes}, #edges: {num_edges}")
|
||||
|
||||
if prob.status in [pulp.LpStatusInfeasible]:
|
||||
raise RuntimeError("Cannot run the function under the given memory budget. "
|
||||
"Please increase the memory budget.")
|
||||
raise RuntimeError(
|
||||
"Cannot run the function under the given memory budget. " "Please increase the memory budget."
|
||||
)
|
||||
|
||||
# Get and check results
|
||||
s_val = np.full((node_nums,), -1, dtype=np.int32)
|
||||
@@ -458,7 +476,7 @@ class Solver:
|
||||
s_val[i] = get_non_zero_index(s[i])
|
||||
|
||||
e_val = np.full((len(E),), -1, dtype=np.int32)
|
||||
for (idx, (i, j)) in enumerate(E):
|
||||
for idx, (i, j) in enumerate(E):
|
||||
e_val[idx] = get_non_zero_index(e[idx])
|
||||
i_spec_index = e_val[idx] // len(s[j])
|
||||
j_spec_index = e_val[idx] % len(s[j])
|
||||
|
Reference in New Issue
Block a user