mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -3,4 +3,4 @@ from .graph_analysis import GraphAnalyser
 from .solver import Solver
 from .strategies_constructor import StrategiesConstructor
 
-__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph']
+__all__ = ["GraphAnalyser", "Solver", "StrategiesConstructor", "CostGraph"]
@@ -4,7 +4,7 @@ from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
 
 
 class CostGraph:
-    '''
+    """
     A graph data structure to simplify the edge cost graph. It has two main functions:
     1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in
     CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list.
@@ -15,7 +15,7 @@ class CostGraph:
     Argument:
         leaf_strategies(List[StrategiesVector]): It stores StrategiesVector of every nodes on the graph.
         simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True)
-    '''
+    """
 
     def __init__(self, leaf_strategies, simplify=True, forward_only=False):
         self.leaf_strategies = leaf_strategies
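
The linearization mentioned in this docstring follows a simple indexing convention. A minimal sketch, with hypothetical helper names, of how the quadratic table of resharding costs between two adjacent nodes is flattened into the 1D list the solver consumes:

    # For a src node with m strategies and a dst node with n strategies,
    # the cost of the strategy pair (i, j) lives at flat index i * n + j.
    def linearize_edge_cost(src_strategies, dst_strategies, resharding_cost_fn):
        n = len(dst_strategies)
        edge_cost = {}
        for i, src_strategy in enumerate(src_strategies):
            for j, dst_strategy in enumerate(dst_strategies):
                edge_cost[i * n + j] = resharding_cost_fn(src_strategy, dst_strategy)
        return edge_cost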
@@ -39,10 +39,10 @@ class CostGraph:
                     target_node_list.remove(element)
 
     def _build_cost_graph(self):
-        '''
+        """
        This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be
        set to node.
-        '''
+        """
        self.edge_costs = {}
        if self.simplify:
            self.merge_pair = []
@@ -84,8 +84,8 @@ class CostGraph:
                 if _check_tensor_in_node(node._meta_data):
                     children_nodes.append(node)
 
-        setattr(dst_node, 'parents', parent_nodes)
-        setattr(dst_node, 'children', children_nodes)
+        setattr(dst_node, "parents", parent_nodes)
+        setattr(dst_node, "children", children_nodes)
 
         if self.simplify and strategies_vector.check_merge():
             for followed_node in strategies_vector.predecessor_nodes:
@@ -99,7 +99,7 @@ class CostGraph:
         return self.edge_costs[(src_node, dst_node)]
 
     def merge_node(self, src_node, dst_node):
-        '''
+        """
        To merge dst_node into src_node, we need to do it in following steps:
 
        1. For each strategy in dst_node, we need to pick an appropriate strategy
@@ -119,7 +119,7 @@ class CostGraph:
         Argument:
             src_node(Node): The node will be merged into dst_node.
             dst_node(Node): The node to integrate src_node.
-        '''
+        """
         # build merge_map
         merge_map = {}
         for src_index, _ in enumerate(src_node.strategies_vector):
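
As a rough illustration of the merge described in this docstring (a hedged sketch, not the actual implementation), each strategy of src_node is mapped to the dst_node strategy that it reshards into most cheaply, and the mapping is kept so the choice can be recovered after the solver runs:

    # Hypothetical sketch: edge_cost is the flattened cost list described earlier.
    merge_map = {}
    n_dst = len(dst_node.strategies_vector)
    for src_index, _ in enumerate(src_node.strategies_vector):
        # pick the cheapest dst strategy for this src strategy
        merge_map[src_index] = min(range(n_dst), key=lambda j: edge_cost[src_index * n_dst + j])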
@@ -196,7 +196,7 @@ class CostGraph:
         if not self.simplify:
             return
         self.merge_pair.reverse()
-        for (src_node, dst_node) in self.merge_pair:
+        for src_node, dst_node in self.merge_pair:
             self.merge_node(src_node, dst_node)
         self.merge_pair.reverse()
         reindexing_following_dict = {}
@@ -7,7 +7,7 @@ from torch.fx.node import Node
 
 from colossalai.fx.passes.utils import get_node_module
 
-__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser']
+__all__ = ["LiveVariable", "LiveVariableVector", "LiveStage", "GraphAnalyser"]
 
 
 @dataclass
@@ -15,6 +15,7 @@ class LiveVariable:
     """
     LiveVariable is a data structure to store the meta information of a variable for liveness analysis.
     """
+
     name: str
     node: Node
     is_inplace: bool
@@ -55,6 +56,7 @@ class LiveStage:
     """
     LiveStage is a data structure to record the living variables at this current node.
    """
+
     name: str
     node: Node
     all_live_vars: LiveVariableVector
@@ -62,7 +64,6 @@ class LiveStage:
 
 
 class GraphAnalyser:
-
     def __init__(self, gm: GraphModule):
         self._gm = gm
         self._graph = gm.graph
@@ -105,18 +106,18 @@ class GraphAnalyser:
             # detect whether the current op is an in-place op
             # if it is an in-place op, we would deem it as a duplicate var
             is_inplace = False
-            if node.op == 'call_function':
+            if node.op == "call_function":
                 # check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True)
-                if node.kwargs.get('inplace', False):
+                if node.kwargs.get("inplace", False):
                     is_inplace = True
-            elif node.op == 'call_module':
+            elif node.op == "call_module":
                 # to check if this is an inplace op such as torch.nn.Relu(inplace=True)
                 module = get_node_module(node)
-                if getattr(module, 'inplace', False):
+                if getattr(module, "inplace", False):
                     is_inplace = True
 
             # add the output var
-            meta = getattr(node, '_meta_data', None)
+            getattr(node, "_meta_data", None)
             live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace)
             if not is_inplace:
                 unique_live_vars.append(live_var)
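
The in-place detection above can be exercised directly with torch.fx. A small self-contained example (the module name is made up) that flags F.relu(x, inplace=True) the same way this pass does:

    import torch
    import torch.nn.functional as F
    from torch.fx import symbolic_trace

    class Net(torch.nn.Module):
        def forward(self, x):
            return F.relu(x, inplace=True)

    gm = symbolic_trace(Net())
    for node in gm.graph.nodes:
        # same check as the pass: call_function nodes carry inplace in kwargs
        if node.op == "call_function" and node.kwargs.get("inplace", False):
            print(f"{node.name} is an in-place op")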
@@ -138,10 +139,12 @@ class GraphAnalyser:
             # this should be completed if we are able to trace the backward compute graph
 
             # add this stage to liveness dict
-            stage = LiveStage(name=node.name,
-                              node=node,
-                              all_live_vars=all_live_variables.copy(),
-                              unique_live_vars=unique_live_vars.copy())
+            stage = LiveStage(
+                name=node.name,
+                node=node,
+                all_live_vars=all_live_variables.copy(),
+                unique_live_vars=unique_live_vars.copy(),
+            )
             # if a LiveStage is covered by another LiveStage, we just keep the larger one.
             replace = False
             for index, prev_stage in enumerate(liveness_list):
@@ -21,24 +21,25 @@ try:
     import pulp
     from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum
 except:
-    warnings.warn(f'please install the pulp')
+    warnings.warn(f"please install the pulp")
 
-__all___ = ['Solver']
+__all___ = ["Solver"]
 
 
 class Solver:
-
-    def __init__(self,
-                 graph: Graph,
-                 strategies_constructor: StrategiesConstructor,
-                 cost_graph: CostGraph,
-                 graph_analyser: GraphAnalyser = None,
-                 memory_budget: float = -1.0,
-                 solution_numbers: int = 1,
-                 forward_only: bool = False,
-                 memory_increasing_coefficient: float = 1.3,
-                 verbose=False):
-        '''
+    def __init__(
+        self,
+        graph: Graph,
+        strategies_constructor: StrategiesConstructor,
+        cost_graph: CostGraph,
+        graph_analyser: GraphAnalyser = None,
+        memory_budget: float = -1.0,
+        solution_numbers: int = 1,
+        forward_only: bool = False,
+        memory_increasing_coefficient: float = 1.3,
+        verbose=False,
+    ):
+        """
         Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph.
         Argument:
             graph: The computing graph to be optimized.
@@ -48,7 +49,7 @@ class Solver:
             memory_budget: Memory constraint for the solution.
             solution_numbers: If solution_numbers is larger than one, solver will us a serious of solutions based on different memory budget.
             memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget.
-        '''
+        """
         self.graph = graph
         self.strategies_constructor = strategies_constructor
         self.cost_graph = cost_graph
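
Putting the classes touched by this commit together, a hedged end-to-end sketch (the traced GraphModule gm and the device_mesh are assumed to exist already; only the class names and constructor signatures come from these files):

    solver_options = SolverOptions()
    constructor = StrategiesConstructor(gm.graph, device_mesh, solver_options)
    constructor.build_strategies_and_cost()
    cost_graph = CostGraph(constructor.leaf_strategies, simplify=True)
    graph_analyser = GraphAnalyser(gm)
    solver = Solver(gm.graph, constructor, cost_graph, graph_analyser, memory_budget=-1.0, verbose=True)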
@@ -75,11 +76,11 @@ class Solver:
         self.verbose = verbose
 
     def _recover_merged_node_strategy(self):
-        '''
+        """
         During cost graph constructing, some nodes, such as unary element-wise node or ReshapeOp, were merged into the previous node.
         Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged
         node.
-        '''
+        """
         for node_index, node in enumerate(self.nodes):
             if node.strategies_vector.check_merge():
                 # the merged node has only one input, and its strategies follow the input sharding strategy
@@ -98,9 +99,9 @@ class Solver:
         return node_index_dict
 
     def _prepare_data_for_solver(self):
-        '''
+        """
         Extract information from components for solver.
-        '''
+        """
         node_nums = len(self.leaf_strategies)
         memory_budget = self.memory_budget
 
@@ -190,23 +191,40 @@ class Solver:
         # omit initial value for nodes
         s_init_np = None
 
-        return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np, self.verbose
+        return (
+            node_nums,
+            memory_budget,
+            strategies_len,
+            following_nodes,
+            edge_pairs,
+            alias_set,
+            liveness_set,
+            compute_costs,
+            communication_costs,
+            memory_costs,
+            resharding_costs,
+            alias_convert_costs,
+            s_init_np,
+            self.verbose,
+        )
 
-    def _call_solver_serialized_args(self,
-                                     node_nums,
-                                     memory_budget,
-                                     strategies_len,
-                                     following_nodes,
-                                     edge_pairs,
-                                     alias_set,
-                                     liveness_set,
-                                     compute_costs,
-                                     communication_costs,
-                                     memory_costs,
-                                     resharding_costs,
-                                     alias_convert_costs,
-                                     s_init_np=None,
-                                     verbose=True):
+    def _call_solver_serialized_args(
+        self,
+        node_nums,
+        memory_budget,
+        strategies_len,
+        following_nodes,
+        edge_pairs,
+        alias_set,
+        liveness_set,
+        compute_costs,
+        communication_costs,
+        memory_costs,
+        resharding_costs,
+        alias_convert_costs,
+        s_init_np=None,
+        verbose=True,
+    ):
         """
         Call the solver with serialized arguments.
         """
@@ -235,18 +253,18 @@ class Solver:
         s_follow = following_nodes
         s_alias = alias_set
 
-        E = edge_pairs.reshape((-1, 2))    # noqa
+        E = edge_pairs.reshape((-1, 2))  # noqa
         r = []
         pt = 0
         edge_set = set()
-        for (i, j) in E:
+        for i, j in E:
             prod_length = strategies_len[i] * strategies_len[j]
 
             if (i, j) in edge_set:
                 raise ValueError(f"Duplicated edges: {(i, j)}")
 
             edge_set.add((i, j))
-            r.append(resharding_costs[pt:pt + prod_length])
+            r.append(resharding_costs[pt : pt + prod_length])
             pt += prod_length
         assert pt == len(resharding_costs)
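
For readers decoding the serialized inputs: edge_pairs arrives as a flat array of alternating src/dst node indices, and the reshape above recovers the (src, dst) pairs. A toy example with invented values:

    import numpy as np

    edge_pairs = np.array([0, 1, 1, 2])  # flat layout [src0, dst0, src1, dst1]
    E = edge_pairs.reshape((-1, 2))      # array([[0, 1], [1, 2]])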
@@ -268,7 +286,6 @@ class Solver:
         # L.append(liveness_set[pt:pt + length])
         # pt += length
         # assert pt == len(liveness_set)
         v = []
         pt = 0
-
         c = []
@@ -277,9 +294,9 @@ class Solver:
         pt = 0
         for i in range(node_nums):
             length = strategies_len[i]
-            c.append(compute_costs[pt:pt + length])
-            d.append(communication_costs[pt:pt + length])
-            m.append(memory_costs[pt:pt + length])
+            c.append(compute_costs[pt : pt + length])
+            d.append(communication_costs[pt : pt + length])
+            m.append(memory_costs[pt : pt + length])
             pt += length
         assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
         assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}"
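
This loop also fixes the deserialization convention for the cost arrays: per-node cost vectors are concatenated into one flat list and sliced back out using strategies_len. A worked toy example with invented numbers:

    strategies_len = [2, 3]                    # node 0 has 2 strategies, node 1 has 3
    compute_costs = [1.0, 2.0, 0.5, 0.7, 0.9]  # flat concatenation, node by node
    c, pt = [], 0
    for length in strategies_len:
        c.append(compute_costs[pt : pt + length])
        pt += length
    assert c == [[1.0, 2.0], [0.5, 0.7, 0.9]]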
@@ -319,7 +336,7 @@ class Solver:
         e = []
         num_edges = 0
         map_edge_to_idx = {}
-        for (idx, (i, j)) in enumerate(E):
+        for idx, (i, j) in enumerate(E):
             if len(s[i]) == 1:
                 e.append(s[j])
             elif len(s[j]) == 1:
@@ -340,7 +357,7 @@ class Solver:
         ######################################
         if s_init_np is not None:
             s_init = s_init_np.reshape((-1, 3))
-            for (idx, value, fix) in s_init:
+            for idx, value, fix in s_init:
                 for i in range(len(s[idx])):
                     s[idx][i].setInitialValue(i == value)
                     if fix:
@@ -393,7 +410,7 @@ class Solver:
 
         # (d). specified by `cat="Binary"`
 
-        for (idx, (i, j)) in enumerate(E):
+        for idx, (i, j) in enumerate(E):
             if strategies_len[i] == 1 or strategies_len[j] == 1:
                 continue
@@ -402,13 +419,13 @@ class Solver:
 
             # (f)
             for row in range(len(s[i])):
-                C = len(s[j])    # noqa
+                C = len(s[j])  # noqa
                 prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row]
 
             # (g)
             for col in range(len(s[j])):
-                R = len(s[i])    # noqa
-                C = len(s[j])    # noqa
+                R = len(s[i])  # noqa
+                C = len(s[j])  # noqa
                 prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col]
 
             # (h)
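
Constraints (f) and (g) are what keep the flattened edge variables consistent with the per-node choice variables: e[idx] holds one binary per (row, col) strategy pair, and a pair can only be selected if both endpoint strategies are. A self-contained toy with two strategies per node, using the same pulp calls this file imports:

    import pulp
    from pulp import LpMinimize, LpProblem, LpVariable, lpSum

    prob = LpProblem("toy", LpMinimize)
    s_i = [LpVariable(f"s_i_{r}", cat="Binary") for r in range(2)]
    s_j = [LpVariable(f"s_j_{c}", cat="Binary") for c in range(2)]
    e = [LpVariable(f"e_{k}", cat="Binary") for k in range(4)]  # flattened 2x2
    prob += lpSum(s_i) == 1  # exactly one strategy per node
    prob += lpSum(s_j) == 1
    prob += lpSum(e) == 1    # exactly one strategy pair per edge
    C = 2
    for row in range(2):     # (f): picking a pair in this row implies s_i[row]
        prob += lpSum(e[row * C + col] for col in range(C)) <= s_i[row]
    for col in range(2):     # (g): picking a pair in this column implies s_j[col]
        prob += lpSum(e[row * C + col] for row in range(2)) <= s_j[col]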
@@ -434,7 +451,8 @@ class Solver:
         msg = verbose
         time_limit = 600
         assert "COIN_CMD" in pulp.listSolvers(
-            onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'")
+            onlyAvailable=True
+        ), "Please install ILP solvers by 'sudo apt install coinor-cbc'"
 
         solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count())
         # solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit)
@@ -444,13 +462,13 @@ class Solver:
         objective = pulp.value(prob.objective)
         objective = float(objective) if objective is not None else -1.0
         if verbose:
-            print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t"
-                  f"Time: {time.time() - tic}")
+            print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t" f"Time: {time.time() - tic}")
             print(f"#nodes: {num_nodes}, #edges: {num_edges}")
 
         if prob.status in [pulp.LpStatusInfeasible]:
-            raise RuntimeError("Cannot run the function under the given memory budget. "
-                               "Please increase the memory budget.")
+            raise RuntimeError(
+                "Cannot run the function under the given memory budget. " "Please increase the memory budget."
+            )
 
         # Get and check results
         s_val = np.full((node_nums,), -1, dtype=np.int32)
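
When the problem is infeasible under the given budget, the constructor arguments seen earlier suggest the retry policy: with solution_numbers > 1, a series of growing budgets is derived from memory_increasing_coefficient. A sketch of that series, assuming geometric scaling (the exact formula is not shown in this diff):

    budgets = [memory_budget * memory_increasing_coefficient**k for k in range(solution_numbers)]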
@@ -458,7 +476,7 @@ class Solver:
             s_val[i] = get_non_zero_index(s[i])
 
         e_val = np.full((len(E),), -1, dtype=np.int32)
-        for (idx, (i, j)) in enumerate(E):
+        for idx, (i, j) in enumerate(E):
             e_val[idx] = get_non_zero_index(e[idx])
             i_spec_index = e_val[idx] // len(s[j])
             j_spec_index = e_val[idx] % len(s[j])
@@ -1,11 +1,5 @@
-import builtins
-import math
-import operator
-from copy import deepcopy
 from typing import Dict, List
 
 import torch
-from torch.fx import Graph, Node
+from torch.fx import Graph
 
 from colossalai.auto_parallel.tensor_shard.node_handler import (
     GetattrHandler,
@@ -14,13 +8,12 @@ from colossalai.auto_parallel.tensor_shard.node_handler import (
     operator_registry,
 )
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
-from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
 from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
 from colossalai.device.device_mesh import DeviceMesh
 
 from ..options import DataloaderOption, SolverOptions
 
-__all__ = ['StrategiesConstructor']
+__all__ = ["StrategiesConstructor"]
 
 
 class StrategiesConstructor:
@@ -35,7 +28,7 @@ class StrategiesConstructor:
 
     def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
         self.graph = graph
-        assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
+        assert graph.owning_module is not None, "The given graph is not associated with a owning_module"
         self.root_module = self.graph.owning_module
         self.nodes = list(graph.nodes)
         self.device_mesh = device_mesh
@@ -46,11 +39,11 @@ class StrategiesConstructor:
         self.alias_set = None
 
     def remove_duplicated_strategy(self, strategies_vector):
-        '''
+        """
         In build_strategies_and_cost method, we may produce some duplicated strategies.
         In this method, we will remove the duplicated strategies depending on the strategies name.
         Note that this operation is in-place.
-        '''
+        """
         name_checklist = []
         remove_list = []
         for strategy in strategies_vector:
@@ -62,7 +55,6 @@ class StrategiesConstructor:
             strategies_vector.remove(strategy)
 
     def generate_alias_set(self):
-
         node_list = [strategy_vector.node for strategy_vector in self.leaf_strategies]
         common_blocks = find_repeat_blocks(node_list, self.root_module, common_length_threshold=10)
||||
@@ -83,7 +75,7 @@ class StrategiesConstructor:
|
||||
"""
|
||||
|
||||
def _check_no_strategy_for_node(node):
|
||||
if node.op in ('placeholder', 'get_attr', 'output'):
|
||||
if node.op in ("placeholder", "get_attr", "output"):
|
||||
return False
|
||||
|
||||
def _check_no_strategy_for_data(data):
|
||||
@@ -102,83 +94,93 @@ class StrategiesConstructor:
|
||||
|
||||
if _check_no_strategy_for_node(node):
|
||||
self.no_strategy_nodes.append(node)
|
||||
pass
|
||||
|
||||
# placeholder node
|
||||
elif node.op == 'placeholder':
|
||||
elif node.op == "placeholder":
|
||||
if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
|
||||
placeholder_option = 'distributed'
|
||||
placeholder_option = "distributed"
|
||||
else:
|
||||
assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
|
||||
placeholder_option = 'replicated'
|
||||
placeholder_handler = PlaceholderHandler(node,
|
||||
self.device_mesh,
|
||||
strategies_vector,
|
||||
placeholder_option=placeholder_option)
|
||||
assert (
|
||||
self.solver_options.dataloader_option == DataloaderOption.REPLICATED
|
||||
), f"placeholder_option {self.solver_options.dataloader_option} is not supported"
|
||||
placeholder_option = "replicated"
|
||||
placeholder_handler = PlaceholderHandler(
|
||||
node, self.device_mesh, strategies_vector, placeholder_option=placeholder_option
|
||||
)
|
||||
placeholder_handler.register_strategy()
|
||||
|
||||
# get_attr node
|
||||
elif node.op == 'get_attr':
|
||||
getattr_handler = GetattrHandler(node,
|
||||
self.device_mesh,
|
||||
strategies_vector,
|
||||
shard_option=self.solver_options.shard_option,
|
||||
solver_perference=self.solver_options.solver_perference)
|
||||
elif node.op == "get_attr":
|
||||
getattr_handler = GetattrHandler(
|
||||
node,
|
||||
self.device_mesh,
|
||||
strategies_vector,
|
||||
shard_option=self.solver_options.shard_option,
|
||||
solver_perference=self.solver_options.solver_perference,
|
||||
)
|
||||
getattr_handler.register_strategy()
|
||||
|
||||
# call_module node
|
||||
elif node.op == 'call_module':
|
||||
elif node.op == "call_module":
|
||||
target = node.target
|
||||
submod = self.root_module.get_submodule(target)
|
||||
submod_type = type(submod)
|
||||
handler = operator_registry.get(submod_type)(node,
|
||||
self.device_mesh,
|
||||
strategies_vector,
|
||||
shard_option=self.solver_options.shard_option,
|
||||
solver_perference=self.solver_options.solver_perference)
|
||||
handler = operator_registry.get(submod_type)(
|
||||
node,
|
||||
self.device_mesh,
|
||||
strategies_vector,
|
||||
shard_option=self.solver_options.shard_option,
|
||||
solver_perference=self.solver_options.solver_perference,
|
||||
)
|
||||
handler.register_strategy()
|
||||
# attach strategies_info to node
|
||||
if hasattr(handler, 'strategies_info'):
|
||||
setattr(node, 'strategies_info', handler.strategies_info)
|
||||
if hasattr(handler, "strategies_info"):
|
||||
setattr(node, "strategies_info", handler.strategies_info)
|
||||
|
||||
# call_function node
|
||||
elif node.op == 'call_function':
|
||||
elif node.op == "call_function":
|
||||
target = node.target
|
||||
handler = operator_registry.get(target)(node,
|
||||
self.device_mesh,
|
||||
strategies_vector,
|
||||
shard_option=self.solver_options.shard_option,
|
||||
solver_perference=self.solver_options.solver_perference)
|
||||
handler = operator_registry.get(target)(
|
||||
node,
|
||||
self.device_mesh,
|
||||
strategies_vector,
|
||||
shard_option=self.solver_options.shard_option,
|
||||
solver_perference=self.solver_options.solver_perference,
|
||||
)
|
||||
handler.register_strategy()
|
||||
# attach strategies_info to node
|
||||
if hasattr(handler, 'strategies_info'):
|
||||
setattr(node, 'strategies_info', handler.strategies_info)
|
||||
if hasattr(handler, "strategies_info"):
|
||||
setattr(node, "strategies_info", handler.strategies_info)
|
||||
|
||||
# call_method node
|
||||
elif node.op == 'call_method':
|
||||
elif node.op == "call_method":
|
||||
method = getattr(node.args[0]._meta_data.__class__, node.target)
|
||||
handler = operator_registry.get(method)(node,
|
||||
self.device_mesh,
|
||||
strategies_vector,
|
||||
shard_option=self.solver_options.shard_option,
|
||||
solver_perference=self.solver_options.solver_perference)
|
||||
handler = operator_registry.get(method)(
|
||||
node,
|
||||
self.device_mesh,
|
||||
strategies_vector,
|
||||
shard_option=self.solver_options.shard_option,
|
||||
solver_perference=self.solver_options.solver_perference,
|
||||
)
|
||||
handler.register_strategy()
|
||||
# attach strategies_info to node
|
||||
if hasattr(handler, 'strategies_info'):
|
||||
setattr(node, 'strategies_info', handler.strategies_info)
|
||||
if hasattr(handler, "strategies_info"):
|
||||
setattr(node, "strategies_info", handler.strategies_info)
|
||||
|
||||
# output node
|
||||
elif node.op == 'output':
|
||||
elif node.op == "output":
|
||||
if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
|
||||
output_option = 'distributed'
|
||||
output_option = "distributed"
|
||||
else:
|
||||
assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
|
||||
output_option = 'replicated'
|
||||
assert (
|
||||
self.solver_options.dataloader_option == DataloaderOption.REPLICATED
|
||||
), f"placeholder_option {self.solver_options.dataloader_option} is not supported"
|
||||
output_option = "replicated"
|
||||
output_handler = OutputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
|
||||
output_handler.register_strategy()
|
||||
|
||||
self.remove_duplicated_strategy(strategies_vector)
|
||||
setattr(node, 'strategies_vector', strategies_vector)
|
||||
setattr(node, "strategies_vector", strategies_vector)
|
||||
self.leaf_strategies.append(strategies_vector)
|
||||
self.strategy_map[node] = strategies_vector
|
||||
|
||||
|
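
For orientation, the dispatch implemented in this hunk condenses to the following mapping (a summary only; all names are taken from the code above):

    # node.op         -> handler that registers sharding strategies
    # "placeholder"   -> PlaceholderHandler(..., placeholder_option=...)
    # "get_attr"      -> GetattrHandler(..., shard_option=..., solver_perference=...)
    # "call_module"   -> operator_registry.get(type(submod))(...)
    # "call_function" -> operator_registry.get(node.target)(...)
    # "call_method"   -> operator_registry.get(method)(...)
    # "output"        -> OutputHandler(..., output_option=...)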