From f9a613d66071a2962688fcec5f9e19164b23ce26 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Tue, 25 Oct 2022 14:32:01 +0800
Subject: [PATCH] [autoparallel] added binary elementwise node handler (#1758)

* [autoparallel] added binary elementwise node handler

* polish code
---
 .../auto_parallel/tensor_shard/constants.py   |   5 +-
 .../tensor_shard/node_handler/__init__.py     |   3 +-
 .../binary_elementwise_handler.py             |  86 +++++++++
 .../tensor_shard/node_handler/registry.py     |   7 +-
 .../node_handler/strategy/__init__.py         |  13 +-
 .../strategy/binary_elementwise_generator.py  | 111 +++++++++++
 .../tensor_shard/utils/broadcast.py           |   5 +
 .../test_binary_elementwise_handler.py        | 173 ++++++++++++++++++
 8 files changed, 395 insertions(+), 8 deletions(-)
 create mode 100644 colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
 create mode 100644 colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py
 create mode 100644 tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py

diff --git a/colossalai/auto_parallel/tensor_shard/constants.py b/colossalai/auto_parallel/tensor_shard/constants.py
index 91c20d343..9143ad9db 100644
--- a/colossalai/auto_parallel/tensor_shard/constants.py
+++ b/colossalai/auto_parallel/tensor_shard/constants.py
@@ -1,6 +1,7 @@
-import torch
 import operator
 
+import torch
+
 __all__ = [
     'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
     'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP',
@@ -35,7 +36,7 @@ RESHAPE_METHOD_OP = [
 ]
 BCAST_FUNC_OP = [
     torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
-    operator.mul, operator.floordiv, operator.truediv, torch.matmul, torch.where, operator.pow, torch.pow, torch.tanh
+    operator.mul, operator.floordiv, operator.truediv, torch.matmul, operator.pow, torch.pow
 ]
 CONV_MODULE_OP = [
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
index b9227e2ec..64b89346a 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
@@ -1,4 +1,5 @@
 from .batch_norm_handler import BatchNormModuleHandler
+from .binary_elementwise_handler import BinaryElementwiseHandler
 from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
 from .conv_handler import ConvFunctionHandler, ConvModuleHandler
 from .layer_norm_handler import LayerNormModuleHandler
@@ -15,5 +16,5 @@ __all__ = [
     'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
     'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
     'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
-    'NormPoolingHandler', 'operator_registry'
+    'NormPoolingHandler', 'BinaryElementwiseHandler', 'operator_registry'
 ]
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
new file mode 100644
index 000000000..798e677eb
--- /dev/null
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
@@ -0,0 +1,86 @@
+from typing import Dict, List, Union
+
+import torch
+from torch.fx.node import Node
+
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy
+
+from ..constants import BCAST_FUNC_OP
+from ..utils import recover_sharding_spec_for_broadcast_shape
+from .node_handler import NodeHandler
+from .registry import operator_registry
+from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator
+
+__all__ = ['BinaryElementwiseHandler']
+
+
+@operator_registry.register(BCAST_FUNC_OP)
+class BinaryElementwiseHandler(NodeHandler):
+    """
+    A BinaryElementwiseHandler is a node handler which deals with operations that have two
+    operands and involve broadcasting, such as torch.add.
+    """
+
+    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
+        bcast_shape = self.node._meta_data.shape
+
+        def _get_op_data_type(tensor):
+            if isinstance(tensor, torch.nn.parameter.Parameter):
+                return OperationDataType.PARAM
+            else:
+                return OperationDataType.ARG
+
+        def _get_arg_value(idx):
+            if isinstance(self.node.args[idx], Node):
+                meta_data = self.node.args[idx]._meta_data
+            else:
+                # this argument is in fact concrete data such as the int 1,
+                # but we can treat it as meta data
+                # as it won't affect the strategy generation
+                assert isinstance(self.node.args[idx], (int, float))
+                meta_data = torch.Tensor([self.node.args[idx]]).to('meta')
+            return meta_data
+
+        input_meta_data = _get_arg_value(0)
+        other_meta_data = _get_arg_value(1)
+        output_meta_data = self.node._meta_data
+
+        input_op_data = OperationData(name=str(self.node.args[0]),
+                                      type=_get_op_data_type(input_meta_data),
+                                      data=input_meta_data,
+                                      logical_shape=bcast_shape)
+        other_op_data = OperationData(name=str(self.node.args[1]),
+                                      type=_get_op_data_type(other_meta_data),
+                                      data=other_meta_data,
+                                      logical_shape=bcast_shape)
+        output_op_data = OperationData(name=str(self.node),
+                                       type=OperationDataType.OUTPUT,
+                                       data=output_meta_data,
+                                       logical_shape=bcast_shape)
+
+        mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
+        return mapping
+
+    def get_strategy_generator(self) -> List[StrategyGenerator]:
+        op_data_mapping = self.get_operation_data_mapping()
+        generators = []
+        generators.append(BinaryElementwiseStrategyGenerator(op_data_mapping, self.device_mesh))
+        return generators
+
+    def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
+        # convert each operand from its logical sharding spec to its physical sharding spec
+        op_data_mapping = self.get_operation_data_mapping()
+
+        for op_name, op_data in op_data_mapping.items():
+            if not isinstance(op_data.data, torch.Tensor):
+                # remove the sharding spec if the op_data is not a tensor, e.g. torch.pow(tensor, 2)
+                strategy.sharding_specs.pop(op_data)
+            else:
+                # convert the logical sharding spec to a physical sharding spec if broadcasting occurred,
+                # e.g. torch.rand(4, 4) + torch.rand(4)
+                physical_shape = op_data.data.shape
+                logical_shape = op_data.logical_shape
+                sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
+                sharding_spec = recover_sharding_spec_for_broadcast_shape(sharding_spec, logical_shape, physical_shape)
+                strategy.sharding_specs[op_data] = sharding_spec
+        return strategy
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/registry.py b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py
index 6bed842d4..8e06cec4f 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/registry.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py
@@ -8,7 +8,12 @@ class Registry:
     def register(self, source):
 
         def wrapper(func):
-            self.store[source] = func
+            if isinstance(source, (list, tuple)):
+                # support registering a list of items for this func
+                for element in source:
+                    self.store[element] = func
+            else:
+                self.store[source] = func
             return func
 
         return wrapper
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py
index f137f09db..28ee05c0e 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py
@@ -1,9 +1,14 @@
 from .batch_norm_generator import BatchNormStrategyGenerator
+from .binary_elementwise_generator import BinaryElementwiseStrategyGenerator
 from .conv_strategy_generator import ConvStrategyGenerator
-from .getitem_generator import (GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator)
+from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator
 from .layer_norm_generator import LayerNormGenerator
-from .matmul_strategy_generator import (BatchedMatMulStrategyGenerator, DotProductStrategyGenerator,
-                                        LinearProjectionStrategyGenerator, MatVecStrategyGenerator)
+from .matmul_strategy_generator import (
+    BatchedMatMulStrategyGenerator,
+    DotProductStrategyGenerator,
+    LinearProjectionStrategyGenerator,
+    MatVecStrategyGenerator,
+)
 from .normal_pooling_generator import NormalPoolStrategyGenerator
 from .output_generator import OutputGenerator
 from .placeholder_generator import PlaceholderGenerator
@@ -17,5 +22,5 @@ __all__ = [
     'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator',
     'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator',
     'LayerNormGenerator', 'ReshapeGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator',
-    'ReshapeGenerator', 'NormalPoolStrategyGenerator'
+    'ReshapeGenerator', 'NormalPoolStrategyGenerator', 'BinaryElementwiseStrategyGenerator'
 ]
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py
new file mode 100644
index 000000000..fd7f811c8
--- /dev/null
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py
@@ -0,0 +1,111 @@
+import operator
+from functools import reduce
+from typing import List
+
+import torch
+
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
+from colossalai.auto_parallel.tensor_shard.utils import (
+    enumerate_all_possible_1d_sharding,
+    enumerate_all_possible_2d_sharding,
+    ignore_sharding_exception,
+)
+from colossalai.tensor.sharding_spec import ShardingSpecException
+
+from .strategy_generator import StrategyGenerator
+
+__all__ = ['BinaryElementwiseStrategyGenerator']
+
+
+class BinaryElementwiseStrategyGenerator(StrategyGenerator):
+    """
+    A BinaryElementwiseStrategyGenerator is a strategy generator which deals with elementwise operations
+    that have two operands and involve broadcasting, such as torch.add.
+
+    The logical shape for this operation will be `input <op> other`, i.e. the broadcast shape of the two operands.
+    """
+
+    def validate(self) -> bool:
+        assert len(self.op_data) == 3, \
+            f'BinaryElementwiseStrategyGenerator only accepts three operation data (input, other and output), but got {len(self.op_data)}'
+        for name, op_data in self.op_data.items():
+            if not isinstance(op_data.data, (torch.Tensor, int, float)):
+                raise TypeError(f'The operation data {name} is not a torch.Tensor/int/float.')
+
+    def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
+        shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+
+        # since elementwise ops are not compute-intensive,
+        # we approximate the backward compute cost
+        # to be twice the fwd compute cost
+        fwd_compute_cost = reduce(operator.mul, shape)
+        bwd_compute_cost = fwd_compute_cost * 2
+        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
+                                      bwd=bwd_compute_cost,
+                                      total=fwd_compute_cost + bwd_compute_cost)
+        strategy.compute_cost = compute_cost
+
+    def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
+        # input, other and output all have the same logical shape
+        shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+
+        # compute the memory cost in bytes
+        # as elementwise ops are not memory-intensive,
+        # we approximate the fwd memory cost to be the size of the output
+        # and the bwd memory cost to be the size of the gradients of input and other
+        input_bytes = self._compute_size_in_bytes(strategy, 'input')
+        other_bytes = self._compute_size_in_bytes(strategy, 'other')
+        output_bytes = self._compute_size_in_bytes(strategy, 'output')
+        fwd_memory_cost = MemoryCost(activation=output_bytes)
+        bwd_memory_cost = MemoryCost(activation=input_bytes + other_bytes)
+        total_memory_cost = MemoryCost(activation=input_bytes + other_bytes + output_bytes)
+        memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_memory_cost)
+        strategy.memory_cost = memory_cost
+
+    @ignore_sharding_exception
+    def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
+        # we check the output logical shape to get the number of dimensions
+        dim_partition_list = []
+        dim_size = len(self.op_data['output'].logical_shape)
+
+        # enumerate all the 2D sharding cases
+        sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
+        dim_partition_list.extend(sharding_list_2d)
+
+        # enumerate all the 1D sharding cases
+        sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
+        dim_partition_list.extend(sharding_list_1d_on_dim_0)
+        sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
+        dim_partition_list.extend(sharding_list_1d_on_dim_1)
+
+        # add an empty dict for the fully replicated case
+        dim_partition_list.append({})
+
+        # sharding strategy bookkeeping
+        strategy_list = []
+
+        # convert these dim partition dicts to sharding strategies
+        for dim_partition_dict in dim_partition_list:
+            dim_partition_dict_mapping = dict(input=dim_partition_dict,
+                                              other=dim_partition_dict,
+                                              output=dim_partition_dict)
+            try:
+                sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+                communication_action_mapping = {}
+
+                # get name
+                sharding_seq = sharding_spec_mapping['input'].sharding_sequence
+                name = f'{sharding_seq} = {sharding_seq} {sharding_seq}'
+                sharding_strategy = self.get_sharding_strategy(
+                    name=name,
+                    sharding_spec_mapping=sharding_spec_mapping,
+                    communication_action_mapping=communication_action_mapping)
+                strategy_list.append(sharding_strategy)
+            except ShardingSpecException:
+                continue
+        return strategy_list
+
+    def collate_strategies(self) -> List[ShardingStrategy]:
+        strategy_list = self.enumerate_all_possible_output(0, 1)
+        return strategy_list
diff --git a/colossalai/auto_parallel/tensor_shard/utils/broadcast.py b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py
index a0edce9b9..d452cff0c 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/broadcast.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py
@@ -54,6 +54,11 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
         logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
         physical_shape (torch.Size): the shape of the tensor before broadcasting
     """
+    # if the two shapes are the same, no broadcast occurs
+    # we directly return the current sharding spec
+    if list(logical_shape) == list(physical_shape):
+        return logical_sharding_spec
+
     # get the number of dimensions
     logical_num_dims = len(logical_shape)
     physical_num_dims = len(physical_shape)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py
new file mode 100644
index 000000000..6cc49cb6e
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py
@@ -0,0 +1,173 @@
+import torch
+import torch.nn as nn
+
+from colossalai.auto_parallel.tensor_shard.node_handler import BinaryElementwiseHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.testing import parameterize
+
+
+@parameterize('op', [torch.add])
+@parameterize('other_dim', [1, 2])
+def test_binary_elementwise_handler_with_tensor(op, other_dim):
+
+    class BinaryElementwiseOpModel(nn.Module):
+
+        def __init__(self, op):
+            super().__init__()
+            self.op = op
+
+        def forward(self, x1, x2):
+            out = self.op(x1, x2)
+            return out
+
+    model = BinaryElementwiseOpModel(op)
+    tracer = ColoTracer()
+
+    meta_args = {'x1': torch.rand(4, 4).to('meta'), 'x2': torch.rand([4] * other_dim).to('meta')}
+    graph = tracer.trace(model, meta_args=meta_args)
+    print(graph)
+    gm = ColoGraphModule(model, graph)
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    op_node = list(graph.nodes)[2]
+    strategies_vector = StrategiesVector(op_node)
+
+    # build handler
+    handler = BinaryElementwiseHandler(node=op_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
+
+    # check operation data mapping
+    mapping = handler.get_operation_data_mapping()
+
+    for name, op_data in mapping.items():
+        op_data: OperationData
+        # make sure they have valid values
+        assert op_data.logical_shape is not None
+        assert op_data.data is not None
+
+    assert mapping['input'].name == "x1"
+    assert mapping['input'].data.is_meta
+    assert mapping['input'].data.shape == torch.Size([4, 4])
+    assert mapping['input'].type == OperationDataType.ARG
+    assert mapping['input'].logical_shape == torch.Size([4, 4])
+
+    assert mapping['other'].name == "x2"
+    assert mapping['other'].data.is_meta
+    assert mapping['other'].data.shape == torch.Size([4] * other_dim)
+    assert mapping['other'].type == OperationDataType.ARG
+    assert mapping['other'].logical_shape == torch.Size([4, 4])
+
+    assert mapping['output'].name == str(op_node)
+    assert mapping['output'].data.is_meta
+    assert mapping['output'].data.shape == torch.Size([4, 4])
+    assert mapping['output'].type == OperationDataType.OUTPUT
+    assert mapping['output'].logical_shape == torch.Size([4, 4])
+
+    strategies_vector = handler.register_strategy(compute_resharding_cost=False)
+    strategy_name_list = [val.name for val in strategies_vector]
+
+    # 9 strategies are expected: 8 sharded cases plus the fully replicated case
+    assert len(strategy_name_list) == 9
+
+    # check if the sharding strategy is correct
+    assert '[S0, S1] = [S0, S1] [S0, S1]' in strategy_name_list
+    assert '[S1, S0] = [S1, S0] [S1, S0]' in strategy_name_list
+    assert '[S01, R] = [S01, R] [S01, R]' in strategy_name_list
+    assert '[R, S01] = [R, S01] [R, S01]' in strategy_name_list
+    assert '[S0, R] = [S0, R] [S0, R]' in strategy_name_list
+    assert '[R, S0] = [R, S0] [R, S0]' in strategy_name_list
+    assert '[S1, R] = [S1, R] [S1, R]' in strategy_name_list
+    assert '[R, S1] = [R, S1] [R, S1]' in strategy_name_list
+    assert '[R, R] = [R, R] [R, R]' in strategy_name_list
+
+    for strategy in strategies_vector:
+        input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
+        other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
+        output_sharding_spec = strategy.get_sharding_spec_by_name(str(op_node))
+
+        # make sure the sharding spec is the same for input and output
+        assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence
+
+        # since the number of dims of the other operand can vary, we only check that its last dim sharding matches
+        if len(other_sharding_spec.sharding_sequence) == 2:
+            assert input_sharding_spec.sharding_sequence == other_sharding_spec.sharding_sequence
+        elif len(other_sharding_spec.sharding_sequence) == 1:
+            assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
+
+
+@parameterize('op', [torch.add])
+@parameterize('other', [1, 2])
+def test_binary_elementwise_handler_with_int(op, other):
+
+    class BinaryElementwiseOpModel(nn.Module):
+
+        def __init__(self, op, const):
+            super().__init__()
+            self.op = op
+            self.const = const
+
+        def forward(self, x1):
+            out = self.op(x1, self.const)
+            return out
+
+    model = BinaryElementwiseOpModel(op, other)
+    tracer = ColoTracer()
+
+    meta_args = {'x1': torch.rand(4, 4).to('meta')}
+    graph = tracer.trace(model, meta_args=meta_args)
+    print(graph)
+    gm = ColoGraphModule(model, graph)
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    op_node = list(graph.nodes)[1]
+    strategies_vector = StrategiesVector(op_node)
+
+    # build handler
+    handler = BinaryElementwiseHandler(node=op_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
+
+    # check operation data mapping
+    mapping = handler.get_operation_data_mapping()
+
+    assert mapping['input'].name == "x1"
+    assert mapping['input'].data.is_meta
+    assert mapping['input'].data.shape == torch.Size([4, 4])
+    assert mapping['input'].type == OperationDataType.ARG
+    assert mapping['input'].logical_shape == torch.Size([4, 4])
+
+    assert mapping['output'].name == str(op_node)
+    assert mapping['output'].data.is_meta
+    assert mapping['output'].data.shape == torch.Size([4, 4])
+    assert mapping['output'].type == OperationDataType.OUTPUT
+    assert mapping['output'].logical_shape == torch.Size([4, 4])
+
+    strategies_vector = handler.register_strategy(compute_resharding_cost=False)
+    strategy_name_list = [val.name for val in strategies_vector]
+
+    # 9 strategies are expected: 8 sharded cases plus the fully replicated case
+    assert len(strategy_name_list) == 9
+
+    # check if the sharding strategy is correct
+    assert '[S0, S1] = [S0, S1] [S0, S1]' in strategy_name_list
+    assert '[S1, S0] = [S1, S0] [S1, S0]' in strategy_name_list
+    assert '[S01, R] = [S01, R] [S01, R]' in strategy_name_list
+    assert '[R, S01] = [R, S01] [R, S01]' in strategy_name_list
+    assert '[S0, R] = [S0, R] [S0, R]' in strategy_name_list
+    assert '[R, S0] = [R, S0] [R, S0]' in strategy_name_list
+    assert '[S1, R] = [S1, R] [S1, R]' in strategy_name_list
+    assert '[R, S1] = [R, S1] [R, S1]' in strategy_name_list
+    assert '[R, R] = [R, R] [R, R]' in strategy_name_list
+
+    for strategy in strategies_vector:
+        input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
+        output_sharding_spec = strategy.get_sharding_spec_by_name(str(op_node))
+
+        # make sure the sharding spec is the same for input and output
+        assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence
+
+
+if __name__ == '__main__':
+    test_binary_elementwise_handler_with_tensor()
+    test_binary_elementwise_handler_with_int()
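
For reference, the snippet below is a minimal, self-contained sketch of how the list-registration behaviour added to `Registry.register` above is meant to be used, in the spirit of `@operator_registry.register(BCAST_FUNC_OP)` on `BinaryElementwiseHandler`. The `Registry` class is re-declared locally (its constructor and `get` helper are assumptions made only so the snippet runs on its own), and `BCAST_FUNC_OP` is trimmed to a few ops; only the body of `register` mirrors the patch.

import operator

import torch


class Registry:
    # local stand-in for colossalai's Registry; constructor and `get` are assumed here
    def __init__(self, name):
        self.name = name
        self.store = {}

    def register(self, source):

        def wrapper(func):
            if isinstance(source, (list, tuple)):
                # register every element of the list against the same handler
                for element in source:
                    self.store[element] = func
            else:
                self.store[source] = func
            return func

        return wrapper

    def get(self, source):
        return self.store[source]


operator_registry = Registry('operator')

# trimmed-down stand-in for BCAST_FUNC_OP from constants.py
BCAST_FUNC_OP = [torch.add, torch.sub, operator.add, operator.sub]


@operator_registry.register(BCAST_FUNC_OP)
class DummyBinaryElementwiseHandler:
    pass


# every op in the list now maps to the same handler class
assert operator_registry.get(torch.add) is DummyBinaryElementwiseHandler
assert operator_registry.get(operator.sub) is DummyBinaryElementwiseHandler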