diff --git a/colossalai/auto_parallel/solver/op_handler/__init__.py b/colossalai/auto_parallel/solver/op_handler/__init__.py
index a0d570325..486a8fe88 100644
--- a/colossalai/auto_parallel/solver/op_handler/__init__.py
+++ b/colossalai/auto_parallel/solver/op_handler/__init__.py
@@ -6,8 +6,9 @@ from .reshape_handler import ReshapeHandler
 from .bcast_op_handler import BcastOpHandler
 from .embedding_handler import EmbeddingHandler
 from .unary_elementwise_handler import UnaryElementwiseHandler
+from .dot_handler_v2 import LinearFunctionHandler, LinearModuleHandler
 
 __all__ = [
     'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
-    'UnaryElementwiseHandler', 'EmbeddingHandler'
+    'UnaryElementwiseHandler', 'EmbeddingHandler', 'LinearFunctionHandler', 'LinearModuleHandler'
 ]
diff --git a/colossalai/auto_parallel/solver/op_handler/registry.py b/colossalai/auto_parallel/solver/op_handler/registry.py
index 51855e4bf..6bed842d4 100644
--- a/colossalai/auto_parallel/solver/op_handler/registry.py
+++ b/colossalai/auto_parallel/solver/op_handler/registry.py
@@ -14,7 +14,7 @@ class Registry:
         return wrapper
 
     def get(self, source):
-        assert source in self.store
+        assert source in self.store, f'{source} not found in the {self.name} registry'
         target = self.store[source]
         return target
 
diff --git a/colossalai/auto_parallel/solver/sharding_strategy.py b/colossalai/auto_parallel/solver/sharding_strategy.py
index 4df256568..4c1a390ce 100644
--- a/colossalai/auto_parallel/solver/sharding_strategy.py
+++ b/colossalai/auto_parallel/solver/sharding_strategy.py
@@ -49,9 +49,10 @@ class OperationDataType(Enum):
     """
     An operation can come from the argument list of an operator or the parameter list of a module. 
""" - ARG = 0 - PARAM = 1 - OUTPUT = 2 + INPUT = 0 + ARG = 1 + PARAM = 2 + OUTPUT = 3 @dataclass diff --git a/colossalai/auto_parallel/solver/strategies_constructor.py b/colossalai/auto_parallel/solver/strategies_constructor.py index ed3da6c8c..6eb843eba 100644 --- a/colossalai/auto_parallel/solver/strategies_constructor.py +++ b/colossalai/auto_parallel/solver/strategies_constructor.py @@ -4,6 +4,7 @@ from colossalai.auto_parallel.solver.op_handler.layer_norm_handler import LayerN from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.auto_parallel.solver.op_handler.registry import operator_registry from .options import SolverOptions from . import ShardingStrategy, StrategiesVector from .op_handler import * @@ -16,6 +17,8 @@ from typing import Dict, List from ._utils import generate_sharding_spec, generate_resharding_costs import builtins +__all__ = ['StrategiesConstructor', 'StrategiesConstructor_V2'] + class StrategiesConstructor: """ @@ -49,6 +52,7 @@ class StrategiesConstructor: name_checklist.append(strategy.name) else: remove_list.append(strategy) + for strategy in remove_list: strategies_vector.remove(strategy) @@ -394,3 +398,87 @@ class StrategiesConstructor: setattr(node, 'strategies_vector', strategies_vector) self.leaf_strategies.append(strategies_vector) self.strategy_map[node] = strategies_vector + + +class StrategiesConstructor_V2: + """ + StrategiesConstructor is used to construct the parallelization plan for the model execution. + + Args: + graph (Graph): a Graph object used for analysis and strategy generation. + device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster. + solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching. 
+ """ + + def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions): + self.graph = graph + assert graph.owning_module is not None, 'The given graph is not associated with a owning_module' + self.root_module = self.graph.owning_module + self.nodes = list(graph.nodes) + self.device_mesh = device_mesh + self.leaf_strategies = [] + self.strategy_map = {} + self.solver_options = solver_options + + def remove_duplicated_strategy(self, strategies_vector): + ''' + In build_strategies_and_cost method, we may produce some duplicated strategies. + In this method, we will remove the duplicated strategies depending on the strategies name. + Note that this operation is in-place. + ''' + name_checklist = [] + remove_list = [] + for strategy in strategies_vector: + if strategy.name not in name_checklist: + name_checklist.append(strategy.name) + else: + remove_list.append(strategy) + for strategy in remove_list: + strategies_vector.remove(strategy) + + def build_strategies_and_cost(self): + """ + This method is to build the strategy vector for each node in the computation graph. 
+ """ + for node in self.nodes: + strategies_vector = StrategiesVector(node) + + # placeholder node + if node.op == 'placeholder': + # TODO: implement placeholder node handler + pass + + # get_attr node + elif node.op == 'get_attr': + # TODO: implement getattr node handler + pass + + # call_module node + elif node.op == 'call_module': + target = node.target + submod = self.root_module.get_submodule(target) + submod_type = type(submod) + handler = operator_registry.get(submod_type)(node, self.device_mesh, strategies_vector) + handler.register_strategy() + + # call_function node + elif node.op == 'call_function': + target = node.target + handler = operator_registry.get(target)(node, self.device_mesh, strategies_vector) + handler.register_strategy() + + # call_method node + elif node.op == 'call_method': + method = getattr(node.args[0]._meta_data.__class__, node.target) + handler = operator_registry.get(method)(node, self.device_mesh, strategies_vector) + handler.register_strategy() + + # output node + elif node.op == 'output': + # TODO: implement output node handler + pass + + self.remove_duplicated_strategy(strategies_vector) + setattr(node, 'strategies_vector', strategies_vector) + self.leaf_strategies.append(strategies_vector) + self.strategy_map[node] = strategies_vector