diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py index 64b89346a..b1ec540d6 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py @@ -4,6 +4,7 @@ from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler from .conv_handler import ConvFunctionHandler, ConvModuleHandler from .layer_norm_handler import LayerNormModuleHandler from .linear_handler import LinearFunctionHandler, LinearModuleHandler +from .matmul_handler import MatMulHandler from .normal_pooling_handler import NormPoolingHandler from .output_handler import OuputHandler from .placeholder_handler import PlacehodlerHandler @@ -16,5 +17,5 @@ __all__ = [ 'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler', 'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler', 'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler', - 'NormPoolingHandler', 'BinaryElementwiseHandler', 'operator_registry' + 'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry' ] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py new file mode 100644 index 000000000..400c69693 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py @@ -0,0 +1,482 @@ +import operator +from abc import ABC, abstractmethod +from copy import deepcopy +from enum import Enum +from functools import reduce +from typing import Dict, List, Union + +import torch + +from colossalai.auto_parallel.tensor_shard.utils.broadcast import ( + BroadcastType, + get_broadcast_dim_info, + get_broadcast_shape, +) +from colossalai.tensor.sharding_spec import ShardingSpecException + +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy +from ..utils import recover_sharding_spec_for_broadcast_shape +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import ( + BatchedMatMulStrategyGenerator, + DotProductStrategyGenerator, + LinearProjectionStrategyGenerator, + MatVecStrategyGenerator, + StrategyGenerator, +) + + +class MatMulType(Enum): + """ + The MatMulType is categorized into 4 types based on the reference of torch.matmul + in https://pytorch.org/docs/stable/generated/torch.matmul.html. + + DOT: dot product, both tensors are 1D, these two tensors need to have the same number of elements + MM: matrix-matrix product, both tensors are 2D or the 1st tensor is 1D and the 2nd tensor is 2D + MV: matrix-vector product: the 1st tensor is 2D and the 2nd tensor is 1D + BMM: batched matrix-matrix multiplication, one tensor is at least 1D and the other is at least 3D + """ + DOT = 0 + MM = 1 + MV = 2 + BMM = 3 + + +def get_matmul_type(input_dim: int, other_dim: int): + """ + Determine which type of matmul operation should be executed for the given tensor dimensions. 
+
+    Args:
+        input_dim (int): the number of dimensions for the input tensor
+        other_dim (int): the number of dimensions for the other tensor
+    """
+    if input_dim == 1 and other_dim == 1:
+        matmul_type = MatMulType.DOT
+    elif input_dim in [1, 2] and other_dim == 2:
+        matmul_type = MatMulType.MM
+    elif input_dim == 2 and other_dim == 1:
+        matmul_type = MatMulType.MV
+    elif input_dim >= 1 and other_dim >= 1 and (input_dim > 2 or other_dim > 2):
+        matmul_type = MatMulType.BMM
+    else:
+        raise ValueError(
+            f"The input and other tensors are of {input_dim} and {other_dim} dimensions which cannot be used to execute the matmul operation"
+        )
+    return matmul_type
+
+
+class BmmTransform(ABC):
+    """
+    BmmTransform is an abstraction of the shape conversion between logical and physical operation data
+    during the strategy generation.
+    """
+
+    @abstractmethod
+    def apply(self, shape_mapping: Dict[str, List[int]]):
+        pass
+
+    @abstractmethod
+    def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
+        pass
+
+
+class Padder(BmmTransform):
+    """
+    Add padding to the matrix dimensions for batched matrix multiplication.
+    """
+
+    def __init__(self) -> None:
+        # keep the padding dim, op_name -> padded_dim
+        self.padded_dim_mapping = {}
+
+    def apply(self, shape_mapping: Dict[str, List[int]]):
+        mapping_copy = deepcopy(shape_mapping)
+        input_shape = mapping_copy['input']
+        other_shape = mapping_copy['other']
+
+        if len(input_shape) == 1:
+            # if the input is a 1D tensor, 1 is prepended to its shape
+            # and it will be removed afterwards
+            input_shape.insert(0, 1)
+            self.padded_dim_mapping['input'] = -2
+            self.padded_dim_mapping['output'] = -2
+        elif len(other_shape) == 1:
+            # if the other is a 1D tensor, 1 is appended to its shape
+            # and it will be removed afterwards
+            other_shape.append(1)
+            self.padded_dim_mapping['other'] = -1
+            self.padded_dim_mapping['output'] = -1
+        return mapping_copy
+
+    def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
+        input_op_data = op_data_mapping['input']
+        other_op_data = op_data_mapping['other']
+
+        def _remove_padded_dim(key, strategy):
+            op_data = op_data_mapping[key]
+            sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
+            tensor_shape = list(sharding_spec.entire_shape)
+            dim_partition_list = [None] * len(tensor_shape)
+
+            # padded dim is a negative number as the padded dim must be a matrix dim
+            padded_dim = self.padded_dim_mapping[key]
+
+            # compute the new dim partition
+            for tensor_dim, mesh_dims in sharding_spec.dim_partition_dict.items():
+                dim_partition_list[tensor_dim] = mesh_dims
+            dim_partition_list.pop(padded_dim)
+            unpadded_dim_partition_list = {k: v for k, v in enumerate(dim_partition_list) if v is not None}
+
+            # compute unpadded tensor shape
+            tensor_shape.pop(padded_dim)
+
+            assert tensor_shape == list(op_data.data.shape), f'{tensor_shape} vs {list(op_data.data.shape)}'
+
+            # update sharding spec
+            sharding_spec.__init__(sharding_spec.device_mesh, tensor_shape, unpadded_dim_partition_list)
+
+        # enumerate all sharding strategies
+        strategies = []
+        try:
+            strategy_copy = strategy.clone()
+
+            # only one of input and other will be padded
+            if 'input' in self.padded_dim_mapping:
+                _remove_padded_dim('input', strategy_copy)
+                _remove_padded_dim('output', strategy_copy)
+            elif 'other' in self.padded_dim_mapping:
+                _remove_padded_dim('other', strategy_copy)
+                _remove_padded_dim('output', strategy_copy)
+
+            strategies.append(strategy_copy)
+        except ShardingSpecException as e:
+            pass
+        
return strategies + + +class Broadcaster(BmmTransform): + """ + Broadcast the non-matrix dimensions for batched matrix multiplication. + """ + + def __init__(self) -> None: + self.broadcast_dim_info = {} + + def apply(self, shape_mapping: Dict[str, List[int]]): + mapping_copy = shape_mapping.copy() + + # get shapes + input_shape = mapping_copy['input'] + other_shape = mapping_copy['other'] + + # sanity check + assert len(input_shape) > 1 and len(other_shape) > 1 + + # broadcast the batch dim and record + bcast_non_matrix_dims = get_broadcast_shape(input_shape[:-2], other_shape[:-2]) + + # store the broadcast dim info + input_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, input_shape[:-2]) + other_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, other_shape[:-2]) + self.broadcast_dim_info['input'] = input_broadcast_dim_info + self.broadcast_dim_info['other'] = other_broadcast_dim_info + + # create the full logical shape + input_shape = bcast_non_matrix_dims + input_shape[-2:] + other_shape = bcast_non_matrix_dims + other_shape[-2:] + assert len(input_shape) == len(other_shape) + + mapping_copy['input'] = input_shape + mapping_copy['other'] = other_shape + + return mapping_copy + + def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy): + # remove sharding on the broadcast dim + def _remove_sharding_on_broadcast_dim(key, strategy): + op_data = op_data_mapping[key] + sharding_spec = strategy.get_sharding_spec_by_name(op_data.name) + tensor_shape = list(sharding_spec.entire_shape) + + for dim_idx, broadcast_type in self.broadcast_dim_info[key].items(): + if broadcast_type == BroadcastType.MULTIPLE: + # if the dim is originally 1 and multiplied during broadcast + # we set its sharding to R + # e.g. 
[1, 2, 4] x [4, 4, 8] -> [4, 2, 8]
+                    # the dim 0 of [1, 2, 4] is multiplied to 4
+                    tensor_shape[dim_idx] = 1
+                elif broadcast_type == BroadcastType.PADDDING:
+                    # if the dim is padded
+                    # we remove its sharding
+                    tensor_shape[dim_idx] = None
+
+            tensor_shape_before_broadcast = [dim for dim in tensor_shape if dim is not None]
+
+            physical_sharding_spec = recover_sharding_spec_for_broadcast_shape(
+                logical_sharding_spec=sharding_spec,
+                logical_shape=sharding_spec.entire_shape,
+                physical_shape=tensor_shape_before_broadcast)
+            strategy.sharding_specs[op_data] = physical_sharding_spec
+
+        # enumerate all sharding strategies
+        strategies = []
+        try:
+            strategy_copy = strategy.clone()
+            _remove_sharding_on_broadcast_dim('input', strategy_copy)
+            _remove_sharding_on_broadcast_dim('other', strategy_copy)
+            strategies.append(strategy_copy)
+        except ShardingSpecException as e:
+            pass
+        return strategies
+
+
+class Viewer(BmmTransform):
+    """
+    Change the shape of the tensor from N-D to 3D
+    """
+
+    def __init__(self) -> None:
+        self.batch_dims_before_view = None
+
+    def apply(self, shape_mapping: Dict[str, List[int]]):
+        mapping_copy = shape_mapping.copy()
+        self.batch_dims_before_view = list(mapping_copy['input'][:-2])
+
+        # get shapes
+        input_shape = shape_mapping['input']
+        other_shape = shape_mapping['other']
+
+        # view to 3d tensor
+        assert len(input_shape) >= 3 and len(other_shape) >= 3
+        input_shape = [reduce(operator.mul, input_shape[:-2])] + input_shape[-2:]
+        other_shape = [reduce(operator.mul, other_shape[:-2])] + other_shape[-2:]
+        output_shape = input_shape[:2] + other_shape[2:]
+        mapping_copy['input'] = input_shape
+        mapping_copy['other'] = other_shape
+        mapping_copy['output'] = output_shape
+        return mapping_copy
+
+    def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
+        # get operation data
+        def _update_sharding_spec(key, strategy, physical_batch_dim):
+            """
+            Map the logical batch dim to the physical batch dim
+            """
+            op_data = op_data_mapping[key]
+            sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
+            dim_partition_dict = sharding_spec.dim_partition_dict
+            entire_shape = sharding_spec.entire_shape
+
+            # update the dimension index for the matrix dimensions
+            if 2 in dim_partition_dict:
+                dim_partition_dict[len(self.batch_dims_before_view) + 1] = dim_partition_dict.pop(2)
+            if 1 in dim_partition_dict:
+                dim_partition_dict[len(self.batch_dims_before_view)] = dim_partition_dict.pop(1)
+
+            # map the logical batch dim to the physical batch dim
+            if 0 in dim_partition_dict:
+                batch_dim_shard = dim_partition_dict.pop(0)
+                dim_partition_dict[physical_batch_dim] = batch_dim_shard
+
+            # the new shape will be the batch dims + the last 2 matrix dims
+            shape_before_view = self.batch_dims_before_view + list(entire_shape[-2:])
+            sharding_spec.__init__(sharding_spec.device_mesh, shape_before_view, dim_partition_dict)
+
+        num_batch_dim_before_view = len(self.batch_dims_before_view)
+
+        # enumerate all sharding strategies
+        strategies = []
+        for i in range(num_batch_dim_before_view):
+            # create a new strategy
+            strategy_copy = strategy.clone()
+            try:
+                _update_sharding_spec('input', strategy_copy, i)
+                _update_sharding_spec('other', strategy_copy, i)
+                _update_sharding_spec('output', strategy_copy, i)
+                strategies.append(strategy_copy)
+            except ShardingSpecException as e:
+                continue
+        return strategies
+
+
+def _get_bmm_logical_shape(input_shape, other_shape, transforms):
+    """
+    Compute the logical shapes for BMM operation. 
BMM has a general representation + [b, i, k] = [b, i, j] x [b, j, k] + + The dimension b is called non-matrix (batch) dimension and the remaining dimensions are called matrix dimensions + The logical shape for the bmm operands will undergo three stages + 1. append/prepend the 1 to the 1D tensor if there is any + 2. broadcast the non-matrix dimensions + 3. reshape to 3 dimensions + + """ + shape_mapping = {'input': input_shape, 'other': other_shape} + + for transform in transforms: + shape_mapping = transform.apply(shape_mapping) + + input_shape = shape_mapping.get('input', None) + other_shape = shape_mapping.get('other', None) + output_shape = shape_mapping.get('output', None) + + return input_shape, other_shape, output_shape + + +@operator_registry.register(torch.matmul) +@operator_registry.register(torch.Tensor.matmul) +class MatMulHandler(NodeHandler): + """ + The MatMulHandler is a node handler which handles the sharding strategy generation for the matmul operation. + According to https://pytorch.org/docs/stable/generated/torch.matmul.html, the operations will vary depending on + the operands. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + # check which type of operation this matmul will call + self.input_meta_data = self.node.args[0]._meta_data + self.other_meta_data = self.node.args[1]._meta_data + self.output_meta_data = self.node._meta_data + + input_dim = self.input_meta_data.dim() + other_dim = self.other_meta_data.dim() + self.matmul_type = get_matmul_type(input_dim, other_dim) + + if self.matmul_type == MatMulType.BMM: + # bmm operation can possibly involve padding, broadcasting and view + # these transforms will be used to create logical shape and + # recover physical sharding spec + self.transforms = [Padder(), Broadcaster(), Viewer()] + else: + self.transforms = None + + def get_strategy_generator(self) -> List[StrategyGenerator]: + generators = [] + op_data_mapping = self.get_operation_data_mapping() + if self.matmul_type == MatMulType.BMM: + generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh)) + elif self.matmul_type == MatMulType.DOT: + generators.append(DotProductStrategyGenerator(op_data_mapping, self.device_mesh)) + elif self.matmul_type == MatMulType.MV: + generators.append(MatVecStrategyGenerator(op_data_mapping, self.device_mesh)) + elif self.matmul_type == MatMulType.MM: + generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + logical_shape_func = { + MatMulType.DOT: self._get_logical_shape_for_dot, + MatMulType.MM: self._get_logical_shape_for_mm, + MatMulType.MV: self._get_logical_shape_for_mv, + MatMulType.BMM: self._get_logical_shape_for_bmm + } + logical_shapes = logical_shape_func[self.matmul_type]() + op_data_mapping = self._get_op_data_mapping(*logical_shapes) + return op_data_mapping + + def _get_op_data_mapping(self, input_logical_shape, other_logical_shape, output_logical_shape): + # convert list to torch.Size + if input_logical_shape: + input_logical_shape = torch.Size(input_logical_shape) + + if other_logical_shape: + other_logical_shape = torch.Size(other_logical_shape) + + if output_logical_shape: + output_logical_shape = torch.Size(output_logical_shape) + + # create op data + input_op_data = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.input_meta_data, + logical_shape=input_logical_shape) + other_op_data = 
OperationData(name=str(self.node.args[1]),
+                                      type=OperationDataType.ARG,
+                                      data=self.other_meta_data,
+                                      logical_shape=other_logical_shape)
+        output_op_data = OperationData(name=str(self.node),
+                                       type=OperationDataType.OUTPUT,
+                                       data=self.output_meta_data,
+                                       logical_shape=output_logical_shape)
+
+        mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
+        return mapping
+
+    def _get_logical_shape_for_dot(self):
+        """
+        The operands for the dot operation have the same logical shape as the physical shape
+        """
+        return None, None, None
+
+    def _get_logical_shape_for_mm(self):
+        """
+        We need to handle the input tensor for a matrix-matrix multiplication as the input
+        tensor can be a 1D or 2D tensor. If it is a 1D tensor, 1 will be prepended to its shape
+        (e.g. [4] -> [1, 4]).
+        """
+        if self.input_meta_data.dim() == 1:
+            input_logical_shape = [1] + list(self.input_meta_data.shape)
+            input_logical_shape = torch.Size(input_logical_shape)
+        else:
+            input_logical_shape = None
+        return input_logical_shape, None, None
+
+    def _get_logical_shape_for_mv(self):
+        """
+        No broadcasting or dim insertion occurs for the matrix-vector operation.
+        """
+        return None, None, None
+
+    def _get_logical_shape_for_bmm(self):
+        input_physical_shape = list(self.input_meta_data.shape)
+        other_physical_shape = list(self.other_meta_data.shape)
+        return _get_bmm_logical_shape(input_physical_shape, other_physical_shape, self.transforms)
+
+    def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
+        if self.matmul_type in [MatMulType.DOT, MatMulType.MV]:
+            return strategy
+        elif self.matmul_type == MatMulType.MM:
+            if self.input_meta_data.dim() == 1:
+                # if a 1 is prepended to the input shape (this occurs when input is a 1D tensor)
+                # we need to remove that dim
+                input_sharding_spec = strategy.get_sharding_spec_by_name(str(self.node.args[0]))
+                input_physical_shape = self.node.args[0]._meta_data.shape
+                dim_partition_dict = input_sharding_spec.dim_partition_dict
+
+                # remove the partitioning in the dim 0
+                if 0 in dim_partition_dict:
+                    dim_partition_dict.pop(0, None)
+
+                # move the partitioning in dim 1 to dim 0
+                if -1 in dim_partition_dict:
+                    shard = dim_partition_dict.pop(-1)
+                    dim_partition_dict[0] = shard
+
+                # re-init the sharding spec
+                input_sharding_spec.__init__(input_sharding_spec.device_mesh,
+                                             entire_shape=input_physical_shape,
+                                             dim_partition_dict=dim_partition_dict)
+                return strategy
+            else:
+                return strategy
+        elif self.matmul_type == MatMulType.BMM:
+            op_data_mapping = self.get_operation_data_mapping()
+
+            strategies = [strategy]
+            # recover the physical sharding spec
+            for transform in self.transforms[::-1]:
+                recovered_strategies = []
+                for strategy_ in strategies:
+                    output = transform.recover(op_data_mapping, strategy_)
+                    if isinstance(output, ShardingStrategy):
+                        recovered_strategies.append(output)
+                    elif isinstance(output, (list, tuple)):
+                        recovered_strategies.extend(output)
+                    else:
+                        raise TypeError(
+                            f"Found unexpected output type {type(output)} from the recover method of BmmTransform")
+                strategies = recovered_strategies
+            return strategies
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
index 11b883873..b12e9c08d 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
+++ 
b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py @@ -60,12 +60,13 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator): def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() fwd_compute_cost = sharded_input_shape[0] - bwd_compute_cost = sharded_input_shape * 2 + bwd_compute_cost = fwd_compute_cost * 2 compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) return compute_cost + @ignore_sharding_exception def no_split(self): name = f'R = R dot R' dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}} @@ -75,6 +76,7 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @ignore_sharding_exception def split_one_dim(self, mesh_dim): name = f'R = S{mesh_dim} dot S{mesh_dim}' @@ -93,7 +95,7 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) - def generate(self) -> List[ShardingStrategy]: + def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] # do not split dimensions for dot product @@ -113,24 +115,50 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator): def validate(self) -> bool: input_op_data = self.op_data['input'] other_op_data = self.op_data['other'] - assert input_op_data.data.dim() > 1 and other_op_data.data.dim() == 1 + assert input_op_data.data.dim() == 2 and other_op_data.data.dim() == 1 + def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + fwd_compute_cost = sharded_input_shape[0] + bwd_compute_cost = fwd_compute_cost * 2 + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, + bwd=bwd_compute_cost, + total=fwd_compute_cost + bwd_compute_cost) + return compute_cost + + @ignore_sharding_exception def no_split(self): name = "R = R x R" - dim_partition_dict = {"input": {}, "other": {}, "output": {}, "bias": {}} + dim_partition_dict = {"input": {}, "other": {}, "output": {}} + + if self.has_bias: + dim_partition_dict['bias'] = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={}) + @ignore_sharding_exception def split_input_batch(self, mesh_dim): name = f'S{mesh_dim}R = S{mesh_dim}R x R' # get sharding spec - dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}, "bias": {}} + dim_partition_dict = { + "input": { + 0: [mesh_dim] + }, + "other": {}, + "output": { + 0: [mesh_dim] + }, + } + + if self.has_bias: + dim_partition_dict['bias'] = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) # get communication action + communication_action_mapping = {} if self.is_param('other'): other_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['other'], @@ -144,6 +172,8 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator): logical_process_axis=mesh_dim, comm_type=CommType.BEFORE, arg_index=1) + communication_action_mapping['other'] = other_comm_action + if self.has_bias: if self.is_param('bias'): bias_comm_action = 
self.get_communication_action( @@ -158,13 +188,13 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator): logical_process_axis=mesh_dim, comm_type=CommType.BEFORE, arg_index=2) - communication_action_mapping = {'other': other_comm_action, 'bias': bias_comm_action} + communication_action_mapping['bias'] = bias_comm_action return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) - def generate(self) -> List[ShardingStrategy]: + def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] # no split @@ -638,7 +668,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): def validate(self) -> bool: input_op_data = self.op_data['input'] other_op_data = self.op_data['other'] - assert input_op_data.data.dim() == 3 or other_op_data.data.dim() == 3 + assert len(input_op_data.logical_shape) == 3 or len(other_op_data.logical_shape) == 3 if 'bias' in self.op_data: bias_op_data = self.op_data['bias'] @@ -816,11 +846,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): dim_partition_dict = { "input": { 0: [mesh_dim_0], - -1: [mesh_dim_1] + 2: [mesh_dim_1] }, "other": { 0: [mesh_dim_0], - -2: [mesh_dim_1] + 1: [mesh_dim_1] }, "bias": {}, "output": { diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py index b3903b9d7..096bda619 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py @@ -186,9 +186,14 @@ class StrategyGenerator(ABC): """ op_data = self.op_data[key] sharded_shape = strategy.sharding_specs[op_data].get_sharded_shape_per_device() + + if len(sharded_shape) == 0: + num_elements = 1 + else: + num_elements = reduce(operator.mul, sharded_shape) dtype = self.op_data[key].data.dtype size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() - return reduce(operator.mul, sharded_shape) * size_per_elem_bytes + return num_elements * size_per_elem_bytes def generate(self) -> List[ShardingStrategy]: """ diff --git a/colossalai/auto_parallel/tensor_shard/utils/broadcast.py b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py index d452cff0c..3a3753b00 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/broadcast.py +++ b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py @@ -44,21 +44,7 @@ def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]: return dims[::-1] -def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size, - physical_shape: torch.Size) -> ShardingSpec: - """ - This function computes the sharding spec for the physical shape of a broadcast tensor. 
- - Args: - logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor - logical_shape (torch.Size): logical shape is the broadcast shape of a tensor - physical_shape (torch.Size): the shape of the tensor before broadcasting - """ - # if the two shapes are the same, no broadcast occurs - # we directly return the current sharding spec - if list(logical_shape) == list(physical_shape): - return logical_sharding_spec - +def get_broadcast_dim_info(logical_shape, physical_shape): # get the number of dimensions logical_num_dims = len(logical_shape) physical_num_dims = len(physical_shape) @@ -85,6 +71,31 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe else: logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDDING + return logical_dim_broadcast_info + + +def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size, + physical_shape: torch.Size) -> ShardingSpec: + """ + This function computes the sharding spec for the physical shape of a broadcast tensor. + + Args: + logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor + logical_shape (torch.Size): logical shape is the broadcast shape of a tensor + physical_shape (torch.Size): the shape of the tensor before broadcasting + """ + # if the two shapes are the same, no broadcast occurs + # we directly return the current sharding spec + if list(logical_shape) == list(physical_shape): + return logical_sharding_spec + + # get the number of dimensions + logical_num_dims = len(logical_shape) + physical_num_dims = len(physical_shape) + + # get the broadcast info + logical_dim_broadcast_info = get_broadcast_dim_info(logical_shape, physical_shape) + # generate the sharding spec for the physical shape physical_dim_partition = {} logical_dim_partition = logical_sharding_spec.dim_partition_dict diff --git a/colossalai/tensor/sharding_spec.py b/colossalai/tensor/sharding_spec.py index 37d397885..c8bce731e 100644 --- a/colossalai/tensor/sharding_spec.py +++ b/colossalai/tensor/sharding_spec.py @@ -1,6 +1,5 @@ import operator from copy import deepcopy -from enum import Enum from functools import reduce import torch @@ -175,6 +174,9 @@ class ShardingSpec: dim_partition_dict=None, sharding_sequence=None): self.device_mesh = device_mesh + + if isinstance(entire_shape, (list, tuple)): + entire_shape = torch.Size(entire_shape) self.entire_shape = entire_shape self.dim_partition_dict = dim_partition_dict self.sharding_sequence = sharding_sequence diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py new file mode 100644 index 000000000..306c45f56 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py @@ -0,0 +1,166 @@ +import torch +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.matmul_handler import ( + MatMulHandler, + MatMulType, + _get_bmm_logical_shape, + get_matmul_type, +) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing.utils import parameterize + + +class MatMulModule(nn.Module): + + def forward(self, x1, x2): + return torch.matmul(x1, x2) + + +@parameterize( + 'tensor_shapes', + [ 
+ [[8], [8]], # dot product + [[4, 8], [8]], # mat-vec product + [[4, 8], [8, 16]], # mat-mat product + [[8], [8, 16]], # mat-mat product + [[8], [4, 8, 16]], # batched mat-mat product with padding + broadcasting + [[4, 8, 16], [16]], # batched mat-mat product with padding + broadcasting + [[4, 8, 16], [16, 32]], # batched mat-mat product with broadcasting + [[4, 8, 16], [1, 16, 32]], # batched mat-mat product with broadcasting + [[8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[1, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[2, 1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[2, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product without broadcasting + ]) +def test_matmul_node_handler(tensor_shapes): + input_shape, other_shape = tensor_shapes + + # get output shape + x1 = torch.rand(*input_shape) + x2 = torch.rand(*other_shape) + output_shape = list(torch.matmul(x1, x2).shape) + + # get matmul type + matmul_type = get_matmul_type(x1.dim(), x2.dim()) + + model = MatMulModule() + + tracer = ColoTracer() + graph = tracer.trace(model, meta_args={"x1": x1.to('meta'), 'x2': x2.to('meta')}) + gm = ColoGraphModule(model, graph) + physical_mesh_id = torch.arange(0, 4) + + print(graph) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + mod_node = list(graph.nodes)[2] + strategies_vector = StrategiesVector(mod_node) + + # build handler + handler = MatMulHandler(node=mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + logical_input_shape = input_shape + logical_other_shape = other_shape + logical_output_shape = output_shape + if matmul_type == MatMulType.MM and len(input_shape) == 1: + logical_input_shape = [1] + input_shape + elif matmul_type == MatMulType.BMM: + logical_input_shape, logical_other_shape, logical_output_shape = _get_bmm_logical_shape( + input_shape, other_shape, handler.transforms) + else: + logical_input_shape = input_shape + + # check input operation data + assert mapping['input'].name == "x1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size(input_shape) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size(logical_input_shape) + + # check other operation data + assert mapping['other'].name == "x2" + assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size(other_shape) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size(logical_other_shape) + + # check output + assert mapping['output'].name == "matmul" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size(output_shape) + assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping['output'].logical_shape == torch.Size(logical_output_shape) + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # ensure there is no duplicate strategy + if matmul_type != 
MatMulType.BMM:
+        assert len(set(strategy_name_list)) == len(strategy_name_list), strategy_name_list
+
+    for strategy in strategies_vector:
+        strategy: ShardingStrategy
+        input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
+        other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('matmul')
+
+        if matmul_type == MatMulType.DOT:
+            # dot product will produce a scalar
+            # results should fulfill:
+            # 1. the input and other operands have the same sharding spec
+            # 2. the output has no sharding
+            assert input_sharding_spec.sharding_sequence == other_sharding_spec.sharding_sequence
+            assert len(output_sharding_spec.sharding_sequence) == 0
+        elif matmul_type == MatMulType.MV:
+            # matrix-vector product should fulfill
+            # 1. the last dim of the input and other operands should have the same sharding
+            # 2. the first dim of the input and output should have the same sharding
+            # 3. the output should have only 1 dim
+            assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
+            assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
+            assert len(output_sharding_spec.sharding_sequence) == 1
+        elif matmul_type == MatMulType.MM:
+            # matrix-matrix multiplication should fulfill
+            # 1. if input is a 2D tensor, the 1st dim of input and output should have the same sharding
+            # 2. the input's last dim and the first dim of the other should have the same sharding
+            # 3. the last dim of the output and other should have the same sharding
+            # 4. the input and output should have the same number of dims
+            if len(input_shape) == 2:
+                assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
+            assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[0]
+            assert output_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
+            assert len(input_sharding_spec.sharding_sequence) == len(output_sharding_spec.sharding_sequence)
+        elif matmul_type == MatMulType.BMM:
+            # bmm should fulfill
+            # 1. if the other tensor is not a 1D tensor, the last dim of other and output have the same sharding
+            # 2. if the input has more than 1 dim, the second last dim of input and output have the same sharding
+            # 3. if the other has more than 2 dims, the second last dim of other and the last dim of input should have the same sharding
+            if len(other_shape) > 1:
+                assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
+            if len(input_shape) > 1:
+                assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-2]
+            if len(other_shape) > 2:
+                assert other_sharding_spec.sharding_sequence[-2] == input_sharding_spec.sharding_sequence[-1]
+
+
+if __name__ == '__main__':
+    test_matmul_node_handler()
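
The snippet below is a minimal standalone sketch of the pad -> broadcast -> view pipeline that MatMulHandler applies to BMM operands before strategy generation. It is not part of the diff: it uses plain Python lists and only the public torch API, and the helper names pad_1d, broadcast_batch_dims and view_to_3d are illustrative rather than functions in this PR.

from functools import reduce
import operator

import torch


def pad_1d(input_shape, other_shape):
    # torch.matmul prepends a 1 to a 1D input and appends a 1 to a 1D other
    if len(input_shape) == 1:
        input_shape = [1] + input_shape
    if len(other_shape) == 1:
        other_shape = other_shape + [1]
    return input_shape, other_shape


def broadcast_batch_dims(input_shape, other_shape):
    # broadcast every dim except the trailing two matrix dims
    batch = list(torch.broadcast_shapes(torch.Size(input_shape[:-2]), torch.Size(other_shape[:-2])))
    return batch + input_shape[-2:], batch + other_shape[-2:]


def view_to_3d(input_shape, other_shape):
    # collapse the broadcast batch dims into a single leading dim
    batch_size = reduce(operator.mul, input_shape[:-2], 1)
    input_shape = [batch_size] + input_shape[-2:]
    other_shape = [batch_size] + other_shape[-2:]
    output_shape = input_shape[:2] + other_shape[2:]
    return input_shape, other_shape, output_shape


if __name__ == '__main__':
    # e.g. [8] x [4, 8, 16]: pad -> [1, 8], broadcast -> [4, 1, 8], the view step keeps 3 dims
    inp, oth = pad_1d([8], [4, 8, 16])
    inp, oth = broadcast_batch_dims(inp, oth)
    inp, oth, out = view_to_3d(inp, oth)
    assert (inp, oth, out) == ([4, 1, 8], [4, 8, 16], [4, 1, 16])
    # the real op drops the padded dim again, which is what Padder.recover undoes on the sharding-spec side
    assert list(torch.matmul(torch.rand(8), torch.rand(4, 8, 16)).shape) == [4, 16]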