diff --git a/colossalai/auto_parallel/solver/op_handler/__init__.py b/colossalai/auto_parallel/solver/op_handler/__init__.py
index f51bfd739..9d7315dd1 100644
--- a/colossalai/auto_parallel/solver/op_handler/__init__.py
+++ b/colossalai/auto_parallel/solver/op_handler/__init__.py
@@ -4,5 +4,9 @@ from .conv_handler import ConvHandler
 from .batch_norm_handler import BatchNormHandler
 from .reshape_handler import ReshapeHandler
 from .bcast_op_handler import BcastOpHandler
+from .unary_elementwise_handler import UnaryElementwiseHandler
 
-__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler']
\ No newline at end of file
+__all__ = [
+    'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
+    'UnaryElementwiseHandler'
+]
diff --git a/colossalai/auto_parallel/solver/op_handler/operator_handler.py b/colossalai/auto_parallel/solver/op_handler/operator_handler.py
index 52899c742..a0b70bd6f 100644
--- a/colossalai/auto_parallel/solver/op_handler/operator_handler.py
+++ b/colossalai/auto_parallel/solver/op_handler/operator_handler.py
@@ -47,7 +47,10 @@ class OperatorHandler(ABC):
         elif self.node.op == 'call_function' and self.node.target not in NON_PARAM_FUNC_OP:
             module = None
             parameters = list(self.node.args)[1]
-            named_parameters = {'weight': parameters._meta_data}
+            if isinstance(parameters, Node):
+                named_parameters = {'weight': parameters._meta_data}
+            else:
+                named_parameters = {}
         else:
             module = None
             named_parameters = None
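Reviewer note on the operator_handler.py change above: in an FX graph, the second positional argument of a call_function node is not always a Node, so unconditionally reading parameters._meta_data can fail. A minimal sketch (not part of the patch; the toy module is an assumption) of the case the isinstance(parameters, Node) guard defends against:

import operator
import torch
from torch.fx import symbolic_trace


class ToyModel(torch.nn.Module):

    def forward(self, x):
        # traces to a call_function node for operator.add whose second
        # argument is the constant 2, not a torch.fx.Node
        return x + 2


gm = symbolic_trace(ToyModel())
for node in gm.graph.nodes:
    if node.op == 'call_function' and node.target is operator.add:
        # prints e.g. (x, 2) ['Node', 'int']: only the first arg is a Node
        print(node.args, [type(a).__name__ for a in node.args])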
diff --git a/colossalai/auto_parallel/solver/op_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/solver/op_handler/unary_elementwise_handler.py
new file mode 100644
index 000000000..c6fbc3d2c
--- /dev/null
+++ b/colossalai/auto_parallel/solver/op_handler/unary_elementwise_handler.py
@@ -0,0 +1,71 @@
+import warnings
+import torch
+from colossalai.auto_parallel.solver.constants import INFINITY_COST
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy
+from .operator_handler import OperatorHandler
+from colossalai.tensor.sharding_spec import ShardingSpec
+from copy import deepcopy
+from colossalai.auto_parallel.solver._utils import exception_handler
+
+__all__ = ['UnaryElementwiseHandler']
+
+
+class UnaryElementwiseHandler(OperatorHandler):
+    """
+    An OperatorHandler which deals with the sharding strategies of UnaryElementwiseOp.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        input_nodes_len = 0
+        for check_node in self.predecessor_node:
+            if isinstance(check_node._meta_data, torch.Tensor):
+                input_nodes_len += 1
+        assert input_nodes_len == 1, f'Temporarily, we only support single-input element-wise ops; node name is {self.node}, node args are {self.node.args}.'
+        self.input_data = self.predecessor_node[0]._meta_data
+        self.input_node = self.predecessor_node[0]
+        self.output_data = self.node._meta_data
+
+    def _generate_compute_cost(self, *args, **kwargs):
+        return super()._generate_compute_cost(*args, **kwargs)
+
+    @exception_handler
+    def register_strategy(self):
+        # TODO: integrate element-wise func and module together
+        # create sharding strategies for the element-wise op
+
+        # For an element-wise op, we keep the sharding spec of the output node the same
+        # as that of the input. Therefore, input-node strategies that share an output
+        # sharding spec produce the same strategy for the element-wise op.
+        sharding_spec_checklist = []
+        for strategy in self.input_node.strategies_vector:
+            # Note: the input of the node being processed is the output of input_node.
+            input_sharding_spec = strategy.output_sharding_spec
+            assert isinstance(input_sharding_spec, ShardingSpec), 'The input node should NOT output a tuple of tensors.'
+            if input_sharding_spec in sharding_spec_checklist:
+                continue
+            sharding_spec_checklist.append(input_sharding_spec)
+            dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
+            try:
+                output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict)
+            except AssertionError as e:
+                warnings.warn(f'{e}')
+                continue
+            name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
+            # TODO: use meta_info_prop to profile memory cost and compute cost
+            compute_cost = self.output_data.numel()
+            memory_cost = 0
+
+            resharding_costs = self._generate_resharding_costs([input_sharding_spec])
+
+            # To prevent resharding from happening, set every nonzero resharding cost to infinity.
+            resharding_costs[self.input_node] = [
+                0 if cost == 0 else INFINITY_COST for cost in resharding_costs[self.input_node]
+            ]
+            sharding_strategy = ShardingStrategy(name,
+                                                 output_sharding_spec,
+                                                 compute_cost=compute_cost,
+                                                 memory_cost=memory_cost,
+                                                 resharding_costs=resharding_costs,
+                                                 input_shardings=[input_sharding_spec])
+            self.strategies_vector.append(sharding_strategy)
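To make the register_strategy flow above concrete, here is a dependency-free model of its two key moves: deduplicating input strategies by their output sharding spec, and pinning nonzero resharding costs to infinity so the solver cannot insert a reshard before a unary op. Plain dicts and strings stand in for ShardingSpec/ShardingStrategy, and the helper name is made up; this is a sketch, not the repo's API:

import math

INFINITY_COST = math.inf


def build_unary_strategies(input_strategies):
    """input_strategies: list of dicts like {'output_spec': 'S0R', 'costs': [0.0, 3.2]}."""
    seen_specs = []
    strategies = []
    for strat in input_strategies:
        spec = strat['output_spec']    # the op's input spec = predecessor's output spec
        if spec in seen_specs:         # same spec => identical strategy, skip it
            continue
        seen_specs.append(spec)
        # pin every nonzero resharding cost to infinity
        pinned = [0 if c == 0 else INFINITY_COST for c in strat['costs']]
        strategies.append({'name': f'{spec} -> {spec}', 'resharding_costs': pinned})
    return strategies


print(build_unary_strategies([
    {'output_spec': 'S0R', 'costs': [0.0, 3.2]},
    {'output_spec': 'S0R', 'costs': [1.5, 0.0]},    # duplicate spec, skipped
    {'output_spec': 'RR', 'costs': [2.0, 0.0]},
]))
# -> two strategies, 'S0R -> S0R' and 'RR -> RR', with nonzero costs pinned to inf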
diff --git a/colossalai/auto_parallel/solver/strategies_constructor.py b/colossalai/auto_parallel/solver/strategies_constructor.py
index ab67d37e9..ed3da6c8c 100644
--- a/colossalai/auto_parallel/solver/strategies_constructor.py
+++ b/colossalai/auto_parallel/solver/strategies_constructor.py
@@ -14,6 +14,7 @@ import torch
 import operator
 from typing import Dict, List
 from ._utils import generate_sharding_spec, generate_resharding_costs
+import builtins
 
 
 class StrategiesConstructor:
@@ -63,7 +64,10 @@ def build_strategies_and_cost(self):
         for node in self.nodes:
             strategies_vector = StrategiesVector(node)
-            input_nodes_len = len(strategies_vector.predecessor_nodes)
+            input_nodes_len = 0
+            for check_node in strategies_vector.predecessor_nodes:
+                if isinstance(check_node._meta_data, torch.Tensor):
+                    input_nodes_len += 1
 
             # placeholder node
             if node.op == 'placeholder':
                 # For placeholder nodes, if solver_options.fast is True, we just let them in
@@ -122,53 +126,12 @@
             # element-wise module
             elif submod_type in ELEMENTWISE_MODULE_OP:
-                # create sharding strategy for element-wise module
-                assert len(strategies_vector.predecessor_nodes
-                           ) == 1, f'Temporally, we just support single input element-wise op.'
-                input_node = strategies_vector.predecessor_nodes[0]
-                # For element-wise module, we keep the sharding spec of output node same as
-                # the input. Therefore, the different strategies of input node with same
-                # output sharding spec will generate same strategy for element-wise module.
-                sharding_spec_checklist = []
-                for strategy in input_node.strategies_vector:
-                    # It looks a little bit confusing, the input of the processing node
-                    # is the output of the input_node.
-                    input_sharding_spec = strategy.output_sharding_spec
-                    assert isinstance(input_sharding_spec,
-                                      ShardingSpec), f'The input node should NOT be a tuple of tensor.'
-                    if input_sharding_spec in sharding_spec_checklist:
-                        continue
-
-                    sharding_spec_checklist.append(input_sharding_spec)
-                    dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
-                    output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
-
-                    name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
-
-                    # TODO: use meta_info_prop to profile memory cost and compute cost
-                    compute_cost = node._meta_data.numel()
-                    memory_cost = 0
-                    resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
-                                                                 [input_sharding_spec])
-
-                    # to prevent the resharding happening, set their resharding cost to inf.
-                    resharding_costs[input_node] = [
-                        cost if cost == 0 else math.inf for cost in resharding_costs[input_node]
-                    ]
-                    sharding_strategy = ShardingStrategy(name,
-                                                         output_sharding_spec,
-                                                         compute_cost=compute_cost,
-                                                         memory_cost=memory_cost,
-                                                         resharding_costs=resharding_costs,
-                                                         input_shardings=[input_sharding_spec])
-                    strategies_vector.append(sharding_strategy)
+                unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
+                unary_elementwise_handler.register_strategy()
 
             # BatchNormNd module
             elif submod_type in BATCHNORM_MODULE_OP:
-                # bn1 call_module bn1 (conv1,)
-                # print(node, node.op, node.target, node.args)
-                # create sharding strategy for element-wise module
-                # input_node = strategies_vector.predecessor_nodes[0]
+                # create sharding strategy for batch-norm module
                 norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector)
                 norm_handler.register_strategy()
                 # for strategy in norm_handler.strategies_vector:
@@ -181,8 +144,7 @@ class StrategiesConstructor:
                 # e.g.: for a 2D pooling input NCHW, we should promise no sharding happens on H and W dimension
 
                 # create sharding strategy for element-wise module
-                assert len(strategies_vector.predecessor_nodes
-                           ) == 1, f'Temporally, we just support single input element-wise op.'
+                assert input_nodes_len == 1, 'Temporarily, we only support single-input element-wise ops.'
                 input_node = strategies_vector.predecessor_nodes[0]
                 # For element-wise module, we keep the sharding spec of output node same as
                 # the input. Therefore, the different strategies of input node with same
@@ -255,50 +217,15 @@ class StrategiesConstructor:
 
             # element-wise function
             elif target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1):
-                # TODO: integrate element-wise func and module together
-                # create sharding strategy for element-wise function
-                assert len(strategies_vector.predecessor_nodes
-                           ) == 1, f'Temporally, we just support single input element-wise op, node name is {node}.'
-                input_node = strategies_vector.predecessor_nodes[0]
-                # For element-wise function, we keep the sharding spec of output node same as
-                # the input. Therefore, the different strategies of input node with same
-                # output sharding spec will generate same strategy for element-wise function.
-                sharding_spec_checklist = []
-                for strategy in input_node.strategies_vector:
-                    # It looks a little bit confusing, the input of the processing node
-                    # is the output of the input_node.
-                    input_sharding_spec = strategy.output_sharding_spec
-                    assert isinstance(input_sharding_spec,
-                                      ShardingSpec), f'The input node should NOT be a tuple of tensor.'
-                    if input_sharding_spec in sharding_spec_checklist:
-                        continue
-                    sharding_spec_checklist.append(input_sharding_spec)
-                    dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
-                    output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
-                    name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
-                    # TODO: use meta_info_prop to profile memory cost and compute cost
-                    compute_cost = node._meta_data.numel()
-                    memory_cost = 0
-
-                    resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
-                                                                 [input_sharding_spec])
-
-                    # to prevent the resharding happening, set their resharding cost to inf.
-                    resharding_costs[input_node] = [
-                        0 if cost == 0 else math.inf for cost in resharding_costs[input_node]
-                    ]
-                    sharding_strategy = ShardingStrategy(name,
-                                                         output_sharding_spec,
-                                                         compute_cost=compute_cost,
-                                                         memory_cost=memory_cost,
-                                                         resharding_costs=resharding_costs,
-                                                         input_shardings=[input_sharding_spec])
-                    strategies_vector.append(sharding_strategy)
+                if isinstance(node._meta_data, torch.Tensor):
+                    unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
+                    unary_elementwise_handler.register_strategy()
 
             # bcast op
             elif target in BCAST_FUNC_OP:
-                bcast_op_handler = BcastOpHandler(node, self.device_mesh, strategies_vector)
-                bcast_op_handler.register_strategy()
+                if isinstance(node._meta_data, torch.Tensor):
+                    bcast_op_handler = BcastOpHandler(node, self.device_mesh, strategies_vector)
+                    bcast_op_handler.register_strategy()
 
             # torch.var_mean
             elif target == torch.var_mean:
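Taken together, the strategies_constructor.py changes route broadcast functions with exactly one tensor input through the unary path and skip strategy registration when a node does not produce a tensor. A dependency-free restatement of that dispatch rule for review purposes; the two op sets below are illustrative stand-ins for the repo's ELEMENTWISE_FUNC_OP and BCAST_FUNC_OP constants, not their real contents:

import torch

ELEMENTWISE_FUNC_OP = {torch.relu, torch.tanh}    # illustrative subset only
BCAST_FUNC_OP = {torch.add, torch.mul}            # illustrative subset only


def route(target, meta_data, tensor_input_count):
    """Return which handler build_strategies_and_cost would pick for a call_function node."""
    if target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and tensor_input_count == 1):
        # unary path: only register strategies when the op yields a tensor
        return 'UnaryElementwiseHandler' if isinstance(meta_data, torch.Tensor) else None
    if target in BCAST_FUNC_OP:
        return 'BcastOpHandler' if isinstance(meta_data, torch.Tensor) else None
    return None


# torch.add with one tensor input (e.g. `x + 2`) is treated as unary:
print(route(torch.add, torch.empty(2, 2), tensor_input_count=1))    # UnaryElementwiseHandler
# torch.add with two tensor inputs takes the broadcast path:
print(route(torch.add, torch.empty(2, 2), tensor_input_count=2))    # BcastOpHandler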