[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions


@@ -3,4 +3,4 @@ from .graph_analysis import GraphAnalyser
from .solver import Solver
from .strategies_constructor import StrategiesConstructor
__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph']
__all__ = ["GraphAnalyser", "Solver", "StrategiesConstructor", "CostGraph"]


@@ -4,7 +4,7 @@ from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
class CostGraph:
'''
"""
A graph data structure to simplify the edge cost graph. It has two main functions:
1. To feed the quadratic resharding costs into the solver, we need to linearize them. We build edge_cost in
CostGraph, and it stores every combination of strategies for a src-dst node pair in a 1D list.
@@ -15,7 +15,7 @@ class CostGraph:
Argument:
leaf_strategies(List[StrategiesVector]): It stores the StrategiesVector of every node on the graph.
simplify(bool, optional): The generated cost graph will be simplified if it is true. (defaults to True)
'''
"""
def __init__(self, leaf_strategies, simplify=True, forward_only=False):
self.leaf_strategies = leaf_strategies
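
To make the linearization described above concrete, here is a minimal, self-contained sketch of flattening a pairwise resharding-cost matrix into per-pair edge costs; the names are illustrative assumptions, not CostGraph's actual internals.

    # Hypothetical sketch: linearize a src-dst resharding-cost matrix into
    # per-(src strategy, dst strategy) edge costs, as the docstring describes.
    resharding_cost = [
        [0.0, 2.5, 4.0],  # src strategy 0 -> dst strategies 0..2
        [2.5, 0.0, 1.5],  # src strategy 1 -> dst strategies 0..2
    ]

    edge_cost = {}
    for i, row in enumerate(resharding_cost):
        for j, cost in enumerate(row):
            edge_cost[(i, j)] = cost  # one linear coefficient per strategy pair

    assert edge_cost[(1, 2)] == 1.5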
@@ -39,10 +39,10 @@ class CostGraph:
target_node_list.remove(element)
def _build_cost_graph(self):
'''
"""
This method will generate edge_cost for each adjacent node pair. Additionally, the 'parents' and 'children' attributes will be
set on each node.
'''
"""
self.edge_costs = {}
if self.simplify:
self.merge_pair = []
@@ -84,8 +84,8 @@ class CostGraph:
if _check_tensor_in_node(node._meta_data):
children_nodes.append(node)
setattr(dst_node, 'parents', parent_nodes)
setattr(dst_node, 'children', children_nodes)
setattr(dst_node, "parents", parent_nodes)
setattr(dst_node, "children", children_nodes)
if self.simplify and strategies_vector.check_merge():
for followed_node in strategies_vector.predecessor_nodes:
@@ -99,7 +99,7 @@ class CostGraph:
return self.edge_costs[(src_node, dst_node)]
def merge_node(self, src_node, dst_node):
'''
"""
To merge dst_node into src_node, we need to do it in the following steps:
1. For each strategy in dst_node, we need to pick an appropriate strategy
@@ -119,7 +119,7 @@ class CostGraph:
Argument:
src_node(Node): The node that will be merged into dst_node.
dst_node(Node): The node that integrates src_node.
'''
"""
# build merge_map
merge_map = {}
for src_index, _ in enumerate(src_node.strategies_vector):
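
The merge step described in the docstring pairs each strategy of the merged node with a low-cost strategy of its neighbour. A standalone sketch of that selection, with made-up costs and names rather than the real merge_map construction:

    # For each strategy of one node, pick the neighbouring strategy with the
    # smallest resharding cost. edge_cost keys are (src_index, dst_index).
    edge_cost = {(0, 0): 0.0, (0, 1): 2.5, (1, 0): 3.0, (1, 1): 0.5}

    merge_map = {}
    for src_index in range(2):
        dst_costs = {dst: edge_cost[(src_index, dst)] for dst in range(2)}
        merge_map[src_index] = min(dst_costs, key=dst_costs.get)

    print(merge_map)  # {0: 0, 1: 1}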
@@ -196,7 +196,7 @@ class CostGraph:
if not self.simplify:
return
self.merge_pair.reverse()
for (src_node, dst_node) in self.merge_pair:
for src_node, dst_node in self.merge_pair:
self.merge_node(src_node, dst_node)
self.merge_pair.reverse()
reindexing_following_dict = {}


@@ -7,7 +7,7 @@ from torch.fx.node import Node
from colossalai.fx.passes.utils import get_node_module
__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser']
__all__ = ["LiveVariable", "LiveVariableVector", "LiveStage", "GraphAnalyser"]
@dataclass
@@ -15,6 +15,7 @@ class LiveVariable:
"""
LiveVariable is a data structure to store the meta information of a variable for liveness analysis.
"""
name: str
node: Node
is_inplace: bool
@@ -55,6 +56,7 @@ class LiveStage:
"""
LiveStage is a data structure to record the live variables at the current node.
"""
name: str
node: Node
all_live_vars: LiveVariableVector
@@ -62,7 +64,6 @@ class LiveStage:
class GraphAnalyser:
def __init__(self, gm: GraphModule):
self._gm = gm
self._graph = gm.graph
@@ -105,18 +106,18 @@ class GraphAnalyser:
# detect whether the current op is an in-place op
# if it is an in-place op, we would deem it as a duplicate var
is_inplace = False
if node.op == 'call_function':
if node.op == "call_function":
# check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True)
if node.kwargs.get('inplace', False):
if node.kwargs.get("inplace", False):
is_inplace = True
elif node.op == 'call_module':
elif node.op == "call_module":
# to check if this is an inplace op such as torch.nn.ReLU(inplace=True)
module = get_node_module(node)
if getattr(module, 'inplace', False):
if getattr(module, "inplace", False):
is_inplace = True
# add the output var
meta = getattr(node, '_meta_data', None)
getattr(node, "_meta_data", None)
live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace)
if not is_inplace:
unique_live_vars.append(live_var)
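
The in-place detection above can be reproduced on a plain torch.fx trace. This hedged sketch uses only standard PyTorch APIs, not GraphAnalyser itself:

    import torch.nn as nn
    import torch.nn.functional as F
    import torch.fx as fx

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.act = nn.ReLU(inplace=True)  # module-level inplace flag

        def forward(self, x):
            x = self.act(x)
            return F.relu(x, inplace=True)  # kwarg-level inplace flag

    gm = fx.symbolic_trace(Net())
    for node in gm.graph.nodes:
        is_inplace = False
        if node.op == "call_function" and node.kwargs.get("inplace", False):
            is_inplace = True
        elif node.op == "call_module":
            module = gm.get_submodule(node.target)
            if getattr(module, "inplace", False):
                is_inplace = True
        print(node.name, is_inplace)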
@@ -138,10 +139,12 @@ class GraphAnalyser:
# this should be completed if we are able to trace the backward compute graph
# add this stage to liveness dict
stage = LiveStage(name=node.name,
node=node,
all_live_vars=all_live_variables.copy(),
unique_live_vars=unique_live_vars.copy())
stage = LiveStage(
name=node.name,
node=node,
all_live_vars=all_live_variables.copy(),
unique_live_vars=unique_live_vars.copy(),
)
# if a LiveStage is covered by another LiveStage, we just keep the larger one.
replace = False
for index, prev_stage in enumerate(liveness_list):


@@ -21,24 +21,25 @@ try:
import pulp
from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum
except:
warnings.warn(f'please install the pulp')
warnings.warn(f"please install the pulp")
__all___ = ['Solver']
__all___ = ["Solver"]
class Solver:
def __init__(self,
graph: Graph,
strategies_constructor: StrategiesConstructor,
cost_graph: CostGraph,
graph_analyser: GraphAnalyser = None,
memory_budget: float = -1.0,
solution_numbers: int = 1,
forward_only: bool = False,
memory_increasing_coefficient: float = 1.3,
verbose=False):
'''
def __init__(
self,
graph: Graph,
strategies_constructor: StrategiesConstructor,
cost_graph: CostGraph,
graph_analyser: GraphAnalyser = None,
memory_budget: float = -1.0,
solution_numbers: int = 1,
forward_only: bool = False,
memory_increasing_coefficient: float = 1.3,
verbose=False,
):
"""
The Solver class will integrate the information provided by the components and use an ILP solver to find a possibly optimal strategy combination for the target computing graph.
Argument:
graph: The computing graph to be optimized.
@@ -48,7 +49,7 @@ class Solver:
memory_budget: Memory constraint for the solution.
solution_numbers: If solution_numbers is larger than one, the solver will use a series of solutions based on different memory budgets.
memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budgets.
'''
"""
self.graph = graph
self.strategies_constructor = strategies_constructor
self.cost_graph = cost_graph
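
The diff does not show how the series of budgets is generated, but given solution_numbers and memory_increasing_coefficient as documented above, a plausible sketch (an assumption, not the actual implementation) is:

    # Assumed scheme: scale the base budget by the coefficient for each
    # additional solution the solver is asked to produce.
    def budget_series(memory_budget, solution_numbers, coefficient=1.3):
        return [memory_budget * coefficient**i for i in range(solution_numbers)]

    print(budget_series(8e9, 3))  # roughly [8.0e9, 1.04e10, 1.35e10]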
@@ -75,11 +76,11 @@ class Solver:
self.verbose = verbose
def _recover_merged_node_strategy(self):
'''
"""
During cost graph construction, some nodes, such as unary element-wise nodes or ReshapeOps, were merged into the previous node.
Therefore, the strategy indices of those nodes are copied from the previous node. This method is used to recover the strategy indices of those
merged nodes.
'''
"""
for node_index, node in enumerate(self.nodes):
if node.strategies_vector.check_merge():
# the merged node has only one input, and its strategies follow the input sharding strategy
@@ -98,9 +99,9 @@ class Solver:
return node_index_dict
def _prepare_data_for_solver(self):
'''
"""
Extract information from components for solver.
'''
"""
node_nums = len(self.leaf_strategies)
memory_budget = self.memory_budget
@@ -190,23 +191,40 @@ class Solver:
# omit initial value for nodes
s_init_np = None
return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np, self.verbose
return (
node_nums,
memory_budget,
strategies_len,
following_nodes,
edge_pairs,
alias_set,
liveness_set,
compute_costs,
communication_costs,
memory_costs,
resharding_costs,
alias_convert_costs,
s_init_np,
self.verbose,
)
def _call_solver_serialized_args(self,
node_nums,
memory_budget,
strategies_len,
following_nodes,
edge_pairs,
alias_set,
liveness_set,
compute_costs,
communication_costs,
memory_costs,
resharding_costs,
alias_convert_costs,
s_init_np=None,
verbose=True):
def _call_solver_serialized_args(
self,
node_nums,
memory_budget,
strategies_len,
following_nodes,
edge_pairs,
alias_set,
liveness_set,
compute_costs,
communication_costs,
memory_costs,
resharding_costs,
alias_convert_costs,
s_init_np=None,
verbose=True,
):
"""
Call the solver with serialized arguments.
"""
@@ -235,18 +253,18 @@ class Solver:
s_follow = following_nodes
s_alias = alias_set
E = edge_pairs.reshape((-1, 2)) # noqa
E = edge_pairs.reshape((-1, 2)) # noqa
r = []
pt = 0
edge_set = set()
for (i, j) in E:
for i, j in E:
prod_length = strategies_len[i] * strategies_len[j]
if (i, j) in edge_set:
raise ValueError(f"Duplicated edges: {(i, j)}")
edge_set.add((i, j))
r.append(resharding_costs[pt:pt + prod_length])
r.append(resharding_costs[pt : pt + prod_length])
pt += prod_length
assert pt == len(resharding_costs)
@@ -268,7 +286,6 @@ class Solver:
# L.append(liveness_set[pt:pt + length])
# pt += length
# assert pt == len(liveness_set)
v = []
pt = 0
c = []
@@ -277,9 +294,9 @@ class Solver:
pt = 0
for i in range(node_nums):
length = strategies_len[i]
c.append(compute_costs[pt:pt + length])
d.append(communication_costs[pt:pt + length])
m.append(memory_costs[pt:pt + length])
c.append(compute_costs[pt : pt + length])
d.append(communication_costs[pt : pt + length])
m.append(memory_costs[pt : pt + length])
pt += length
assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}"
@@ -319,7 +336,7 @@ class Solver:
e = []
num_edges = 0
map_edge_to_idx = {}
for (idx, (i, j)) in enumerate(E):
for idx, (i, j) in enumerate(E):
if len(s[i]) == 1:
e.append(s[j])
elif len(s[j]) == 1:
@@ -340,7 +357,7 @@ class Solver:
######################################
if s_init_np is not None:
s_init = s_init_np.reshape((-1, 3))
for (idx, value, fix) in s_init:
for idx, value, fix in s_init:
for i in range(len(s[idx])):
s[idx][i].setInitialValue(i == value)
if fix:
@@ -393,7 +410,7 @@ class Solver:
# (d). specified by `cat="Binary"`
for (idx, (i, j)) in enumerate(E):
for idx, (i, j) in enumerate(E):
if strategies_len[i] == 1 or strategies_len[j] == 1:
continue
@@ -402,13 +419,13 @@ class Solver:
# (f)
for row in range(len(s[i])):
C = len(s[j]) # noqa
C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row]
# (g)
for col in range(len(s[j])):
R = len(s[i]) # noqa
C = len(s[j]) # noqa
R = len(s[i]) # noqa
C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col]
# (h)
@@ -434,7 +451,8 @@ class Solver:
msg = verbose
time_limit = 600
assert "COIN_CMD" in pulp.listSolvers(
onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'")
onlyAvailable=True
), "Please install ILP solvers by 'sudo apt install coinor-cbc'"
solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count())
# solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit)
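
Constraints (f) and (g) above are the standard linking pattern between per-node strategy variables and flattened edge variables. A minimal, self-contained pulp sketch of the same pattern, with toy sizes and costs rather than the solver's real data:

    from pulp import LpMinimize, LpProblem, LpVariable, lpSum

    R, C = 2, 3  # strategy counts of nodes i and j
    prob = LpProblem("tiny_ilp", LpMinimize)
    s_i = [LpVariable(f"s_i_{r}", cat="Binary") for r in range(R)]
    s_j = [LpVariable(f"s_j_{c}", cat="Binary") for c in range(C)]
    e = [LpVariable(f"e_{k}", cat="Binary") for k in range(R * C)]

    prob += lpSum(s_i) == 1  # pick exactly one strategy per node
    prob += lpSum(s_j) == 1
    prob += lpSum(e) == 1    # pick exactly one edge assignment
    for row in range(R):     # (f): edge rows respect the choice of s_i
        prob += lpSum(e[row * C + col] for col in range(C)) <= s_i[row]
    for col in range(C):     # (g): edge columns respect the choice of s_j
        prob += lpSum(e[row * C + col] for row in range(R)) <= s_j[col]

    cost = [1.0, 3.0, 2.0, 0.5, 4.0, 2.5]  # flattened toy resharding costs
    prob += lpSum(cost[k] * e[k] for k in range(R * C))
    prob.solve()  # uses pulp's bundled CBC solver if available
    print([v.value() for v in s_i], [v.value() for v in s_j])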
@@ -444,13 +462,13 @@ class Solver:
objective = pulp.value(prob.objective)
objective = float(objective) if objective is not None else -1.0
if verbose:
print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t"
f"Time: {time.time() - tic}")
print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t" f"Time: {time.time() - tic}")
print(f"#nodes: {num_nodes}, #edges: {num_edges}")
if prob.status in [pulp.LpStatusInfeasible]:
raise RuntimeError("Cannot run the function under the given memory budget. "
"Please increase the memory budget.")
raise RuntimeError(
"Cannot run the function under the given memory budget. " "Please increase the memory budget."
)
# Get and check results
s_val = np.full((node_nums,), -1, dtype=np.int32)
@@ -458,7 +476,7 @@ class Solver:
s_val[i] = get_non_zero_index(s[i])
e_val = np.full((len(E),), -1, dtype=np.int32)
for (idx, (i, j)) in enumerate(E):
for idx, (i, j) in enumerate(E):
e_val[idx] = get_non_zero_index(e[idx])
i_spec_index = e_val[idx] // len(s[j])
j_spec_index = e_val[idx] % len(s[j])
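
The last two lines decode a flattened edge choice back into a (src, dst) strategy pair; as a standalone illustration of the same arithmetic:

    len_s_j = 4  # number of strategies of node j
    e_val = 9    # chosen index in the flattened len(s[i]) * len(s[j]) vector
    i_spec_index, j_spec_index = divmod(e_val, len_s_j)
    assert (i_spec_index, j_spec_index) == (2, 1)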


@@ -1,11 +1,5 @@
import builtins
import math
import operator
from copy import deepcopy
from typing import Dict, List
import torch
from torch.fx import Graph, Node
from torch.fx import Graph
from colossalai.auto_parallel.tensor_shard.node_handler import (
GetattrHandler,
@@ -14,13 +8,12 @@ from colossalai.auto_parallel.tensor_shard.node_handler import (
operator_registry,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
from colossalai.device.device_mesh import DeviceMesh
from ..options import DataloaderOption, SolverOptions
__all__ = ['StrategiesConstructor']
__all__ = ["StrategiesConstructor"]
class StrategiesConstructor:
@@ -35,7 +28,7 @@ class StrategiesConstructor:
def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
self.graph = graph
assert graph.owning_module is not None, 'The given graph is not associated with an owning_module'
assert graph.owning_module is not None, "The given graph is not associated with an owning_module"
self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes)
self.device_mesh = device_mesh
@@ -46,11 +39,11 @@ class StrategiesConstructor:
self.alias_set = None
def remove_duplicated_strategy(self, strategies_vector):
'''
"""
In the build_strategies_and_cost method, we may produce some duplicated strategies.
In this method, we remove the duplicated strategies based on the strategy name.
Note that this operation is in-place.
'''
"""
name_checklist = []
remove_list = []
for strategy in strategies_vector:
@@ -62,7 +55,6 @@ class StrategiesConstructor:
strategies_vector.remove(strategy)
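
A short standalone sketch of the name-based, in-place deduplication performed above; the strings stand in for strategy names:

    strategies = ["S0R_RR", "RS1_RR", "S0R_RR", "RR_RR"]

    name_checklist, remove_idx = [], []
    for i, name in enumerate(strategies):
        if name in name_checklist:
            remove_idx.append(i)  # duplicate name: mark for removal
        else:
            name_checklist.append(name)
    for i in reversed(remove_idx):  # delete from the back to keep indices valid
        del strategies[i]

    print(strategies)  # ['S0R_RR', 'RS1_RR', 'RR_RR']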
def generate_alias_set(self):
node_list = [strategy_vector.node for strategy_vector in self.leaf_strategies]
common_blocks = find_repeat_blocks(node_list, self.root_module, common_length_threshold=10)
@@ -83,7 +75,7 @@ class StrategiesConstructor:
"""
def _check_no_strategy_for_node(node):
if node.op in ('placeholder', 'get_attr', 'output'):
if node.op in ("placeholder", "get_attr", "output"):
return False
def _check_no_strategy_for_data(data):
@@ -102,83 +94,93 @@ class StrategiesConstructor:
if _check_no_strategy_for_node(node):
self.no_strategy_nodes.append(node)
pass
# placeholder node
elif node.op == 'placeholder':
elif node.op == "placeholder":
if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
placeholder_option = 'distributed'
placeholder_option = "distributed"
else:
assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
placeholder_option = 'replicated'
placeholder_handler = PlaceholderHandler(node,
self.device_mesh,
strategies_vector,
placeholder_option=placeholder_option)
assert (
self.solver_options.dataloader_option == DataloaderOption.REPLICATED
), f"placeholder_option {self.solver_options.dataloader_option} is not supported"
placeholder_option = "replicated"
placeholder_handler = PlaceholderHandler(
node, self.device_mesh, strategies_vector, placeholder_option=placeholder_option
)
placeholder_handler.register_strategy()
# get_attr node
elif node.op == 'get_attr':
getattr_handler = GetattrHandler(node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference)
elif node.op == "get_attr":
getattr_handler = GetattrHandler(
node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference,
)
getattr_handler.register_strategy()
# call_module node
elif node.op == 'call_module':
elif node.op == "call_module":
target = node.target
submod = self.root_module.get_submodule(target)
submod_type = type(submod)
handler = operator_registry.get(submod_type)(node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference)
handler = operator_registry.get(submod_type)(
node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference,
)
handler.register_strategy()
# attach strategies_info to node
if hasattr(handler, 'strategies_info'):
setattr(node, 'strategies_info', handler.strategies_info)
if hasattr(handler, "strategies_info"):
setattr(node, "strategies_info", handler.strategies_info)
# call_function node
elif node.op == 'call_function':
elif node.op == "call_function":
target = node.target
handler = operator_registry.get(target)(node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference)
handler = operator_registry.get(target)(
node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference,
)
handler.register_strategy()
# attach strategies_info to node
if hasattr(handler, 'strategies_info'):
setattr(node, 'strategies_info', handler.strategies_info)
if hasattr(handler, "strategies_info"):
setattr(node, "strategies_info", handler.strategies_info)
# call_method node
elif node.op == 'call_method':
elif node.op == "call_method":
method = getattr(node.args[0]._meta_data.__class__, node.target)
handler = operator_registry.get(method)(node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference)
handler = operator_registry.get(method)(
node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference,
)
handler.register_strategy()
# attach strategies_info to node
if hasattr(handler, 'strategies_info'):
setattr(node, 'strategies_info', handler.strategies_info)
if hasattr(handler, "strategies_info"):
setattr(node, "strategies_info", handler.strategies_info)
# output node
elif node.op == 'output':
elif node.op == "output":
if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
output_option = 'distributed'
output_option = "distributed"
else:
assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
output_option = 'replicated'
assert (
self.solver_options.dataloader_option == DataloaderOption.REPLICATED
), f"placeholder_option {self.solver_options.dataloader_option} is not supported"
output_option = "replicated"
output_handler = OutputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
output_handler.register_strategy()
self.remove_duplicated_strategy(strategies_vector)
setattr(node, 'strategies_vector', strategies_vector)
setattr(node, "strategies_vector", strategies_vector)
self.leaf_strategies.append(strategies_vector)
self.strategy_map[node] = strategies_vector
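
The long branch above routes every torch.fx node op to a handler that registers sharding strategies. A condensed, runnable sketch of that dispatch pattern; the handler class and registry here are illustrative stand-ins, not ColossalAI's real signatures:

    class DummyHandler:
        def __init__(self, node, **options):
            self.node, self.options = node, options

        def register_strategy(self):
            print(f"registered strategies for {self.node!r} with {self.options}")

    operator_registry = {"relu": DummyHandler}  # call target -> handler class

    def build_handler(node_op, target, **options):
        if node_op in ("placeholder", "get_attr", "output"):
            return DummyHandler(target)  # dedicated handlers in the real code
        if node_op in ("call_module", "call_function", "call_method"):
            return operator_registry[target](target, **options)
        raise ValueError(f"unsupported op: {node_op}")

    build_handler("call_function", "relu", shard_option=None).register_strategy()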