From 3abf98a6337ae39f11b3c259a0af8d40477fe7f7 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Fri, 16 Sep 2022 10:47:32 +0800
Subject: [PATCH] [autoparallel] added all non-bcast matmul strategies (#1603)

---
 .../solver/op_handler/dot_handler.py          | 242 +++++++++++++++++-
 .../solver/op_handler/strategy_generator.py   |  12 +-
 2 files changed, 251 insertions(+), 3 deletions(-)

diff --git a/colossalai/auto_parallel/solver/op_handler/dot_handler.py b/colossalai/auto_parallel/solver/op_handler/dot_handler.py
index 38b3be16f..f29772705 100644
--- a/colossalai/auto_parallel/solver/op_handler/dot_handler.py
+++ b/colossalai/auto_parallel/solver/op_handler/dot_handler.py
@@ -13,9 +13,238 @@ from typing import List
 
 __all__ = ['DotHandler']
 
+
+class DotProductStrategyGenerator(StrategyGenerator):
+    """
+    DotProductStrategyGenerator is used to generate the sharding strategies for the dot product
+    of two 1D tensors. This is created for torch.matmul where both tensors are 1D. As torch.matmul
+    does not take a bias argument, we do not consider bias here.
+    """
+
+    def validate(self, input, other):
+        assert input.dim() == 1 and other.dim() == 1
+
+    def no_split(self):
+        name = 'R = R dot R'
+        dim_partition_dict = {"input": {}, "other": {}, "output": {}}
+        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
+
+    def split_one_dim(self, mesh_dim):
+        name = f'S{mesh_dim} = S{mesh_dim} dot S{mesh_dim}'
+        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}}
+        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])
+
+    def generate(self) -> List[IntermediateStrategy]:
+        strategy_list = []
+
+        # do not split any dimension
+        # R = R dot R
+        strategy_list.append(self.no_split())
+
+        # split both tensors along the same mesh dimension
+        # S = S dot S
+        strategy_list.append(self.split_one_dim(0))
+        strategy_list.append(self.split_one_dim(1))
+
+        return strategy_list
+
+
+class MatVecStrategyGenerator(StrategyGenerator):
+    """
+    MatVecStrategyGenerator is used to generate the sharding strategies for matrix-vector
+    multiplication, i.e. the first tensor is at least 2D and the second tensor is 1D.
+    """
+
+    def validate(self, input, other) -> bool:
+        assert input.dim() > 1 and other.dim() == 1
+
+    def no_split(self):
+        name = "R = R x R"
+        dim_partition_dict = {"input": {}, "other": {}, "output": {}}
+        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
+
+    def split_input_batch(self, mesh_dim):
+        name = f'S{mesh_dim}R = S{mesh_dim}R x R'
+        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}}
+        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
+
+    def generate(self) -> List[IntermediateStrategy]:
+        strategy_list = []
+
+        # no split
+        strategy_list.append(self.no_split())
+
+        # split the batch dim of the first tensor along either mesh dimension
+        strategy_list.append(self.split_input_batch(0))
+        strategy_list.append(self.split_input_batch(1))
+
+        return strategy_list
+
+
 class MatMulStrategyGenerator(StrategyGenerator):
-    # TODO: to be implmented
-    pass
+    """
+    MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is
+    a 2D tensor. This is used for nn.Linear, F.linear, torch.matmul and torch.addmm.
+
+    A matmul can be formulated as [n, p] x [p, q] = [n, q].
+
+    Args:
+        is_linear (bool): whether this generator is used for nn.Linear and F.linear.
+            This incurs an extra transformation of the dim partitioning as the weight is transposed.
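+
+    Example:
+        nn.Linear stores its weight as [q, p] (i.e. [out_features, in_features]), so the output
+        dim q is dim 0 of the weight, while for torch.matmul the weight is [p, q] and q is dim 1.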
+ """ + + def __init__(self, is_linear: bool, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_linear = is_linear + + # as the weight for the linear module is transposed, we can compute + # the correponding dimension indexfor convenience + if is_linear: + self.dim_q = 0 + self.dim_p = 1 + else: + self.dim_q = 1 + self.dim_p = 0 + + def validate(self, input, other, bias) -> bool: + # make sure the second tensor is a 2D tensor + assert input.dim() > 0 and other.dim() == 2 + + # make sure bias is of the same dimension + if self.is_linear: + assert bias is None or bias.shape[-1] == other.shape[0] + else: + assert bias is None or bias.shape[-1] == other.shape[1] + + def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1): + # handle case SS = SR x RS + name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' + + dim_partition_dict = { + "input": { + 0: [mesh_dim_0] + }, + "other": { + self.dim_q: [mesh_dim_1] + }, + "bias": { + -1: [mesh_dim_1] + }, + "output": { + 0: [mesh_dim_0], + -1: [mesh_dim_1] + }, + } + return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) + + def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): + # handle the case SR = SS x SR + name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' + dim_partition_dict = { + "input": { + 0: [mesh_dim_0], + -1: [mesh_dim_1] + }, + "other": { + self.dim_p: [mesh_dim_1] + }, + "bias": {}, + "output": { + 0: [mesh_dim_0] + }, + } + return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1]) + + def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' + dim_partition_dict = { + "input": { + -1: [mesh_dim_0] + }, + "other": { + self.dim_p: [mesh_dim_0], + self.dim_q: [mesh_dim_1] + }, + "bias": { + -1: [mesh_dim_1] + }, + "output": { + -1: [mesh_dim_1] + }, + } + return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) + + def recompute_split_both_contract(self, mesh_dim): + name = f'RR = RS{mesh_dim} x S{mesh_dim}R' + dim_partition_dict = { + "input": { + -1: [mesh_dim] + }, + "other": { + self.dim_p: [mesh_dim] + }, + "bias": {}, + "output": {}, + } + return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim]) + + def split_rhs_space_only(self, mesh_dim): + name = f'RS{mesh_dim} = RR x RS{mesh_dim}' + dim_partition_dict = { + "input": {}, + "other": { + self.dim_q: [mesh_dim] + }, + "bias": { + -1: [mesh_dim] + }, + "output": { + -1: [mesh_dim] + }, + } + return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim]) + + def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' + dim_partition_dict = { + "input": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "other": {}, + "bias": {}, + "output": { + 0: [mesh_dim_0, mesh_dim_1] + }, + } + return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) + + def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): + name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' + dim_partition_dict = { + "input": { + -1: [mesh_dim_0, mesh_dim_1] + }, + "other": { + self.dim_p: [mesh_dim_0, mesh_dim_1] + }, + "bias": {}, + "output": {}, + } + return IntermediateStrategy(name=name, + dim_partition_dict=dim_partition_dict, + all_reduce_axis=[mesh_dim_0, mesh_dim_1]) + + def 
+    def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
+        name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
+        dim_partition_dict = {
+            "input": {},
+            "other": {
+                self.dim_q: [mesh_dim_0, mesh_dim_1]
+            },
+            "bias": {
+                -1: [mesh_dim_0, mesh_dim_1]
+            },
+            "output": {
+                -1: [mesh_dim_0, mesh_dim_1]
+            },
+        }
+        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
 
 
 class BatchedMatMulStrategyGenerator(StrategyGenerator):
@@ -30,6 +259,15 @@ class BatchedMatMulStrategyGenerator(StrategyGenerator):
         super().__init__(*args, **kwargs)
         self.is_torch_bmm = is_torch_bmm
 
+    def validate(self, input, other, bias) -> bool:
+        if self.is_torch_bmm:
+            # torch.bmm expects 3D tensors whose batch and contraction dims match,
+            # and it does not take a bias argument
+            assert input.dim() == 3 and other.dim() == 3
+            assert input.shape[0] == other.shape[0] and input.shape[-1] == other.shape[-2]
+            assert bias is None
+        else:
+            # TODO: validate that these inputs are broadcastable
+            pass
+
     def split_one_batch_dim(self):
         if 1 in self.device_mesh.mesh_shape:
             mesh_dim = self.device_mesh.mesh_shape.index(1)
diff --git a/colossalai/auto_parallel/solver/op_handler/strategy_generator.py b/colossalai/auto_parallel/solver/op_handler/strategy_generator.py
index fce47f5ce..4e39fcd8e 100644
--- a/colossalai/auto_parallel/solver/op_handler/strategy_generator.py
+++ b/colossalai/auto_parallel/solver/op_handler/strategy_generator.py
@@ -32,4 +32,14 @@ class StrategyGenerator(ABC):
 
     @abstractmethod
     def generate(self) -> List[IntermediateStrategy]:
-        pass
\ No newline at end of file
+        """
+        Generate all intermediate sharding strategies for the current operation.
+        """
+        pass
+
+    @abstractmethod
+    def validate(self, *args, **kwargs) -> bool:
+        """
+        Validate whether the operands are of the desired shape.
+        If validation passes, this generator can be used for the current operation.
+        """
+        pass
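-- 

Note: the sketch below is illustrative only and not part of the patch. It reproduces, in
plain PyTorch with slicing standing in for device shards, the sharding algebra behind two
of the strategies above; the 2x2 mesh and tensor sizes are assumptions for the example.

    import torch

    n, p, q = 4, 6, 8
    x, w = torch.rand(n, p), torch.rand(p, q)

    # 'S0S1 = S0R x RS1': rank (i, j) holds row-shard i of x and column-shard j of w,
    # so it computes one complete tile of the output and no reduction is needed
    tiles = [[x[i * n // 2:(i + 1) * n // 2] @ w[:, j * q // 2:(j + 1) * q // 2]
              for j in range(2)] for i in range(2)]
    out = torch.cat([torch.cat(row, dim=1) for row in tiles], dim=0)
    assert torch.allclose(out, x @ w)

    # 'RR = RS0 x S0R': the contraction dim p is sharded instead, so each rank holds a
    # partial [n, q] product that must be summed across ranks (hence all_reduce_axis)
    partials = [x[:, k * p // 2:(k + 1) * p // 2] @ w[k * p // 2:(k + 1) * p // 2, :]
                for k in range(2)]
    assert torch.allclose(sum(partials), x @ w)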