diff --git a/colossalai/auto_parallel/solver/op_handler/dot_handler.py b/colossalai/auto_parallel/solver/op_handler/dot_handler.py
index 26791df46..38b3be16f 100644
--- a/colossalai/auto_parallel/solver/op_handler/dot_handler.py
+++ b/colossalai/auto_parallel/solver/op_handler/dot_handler.py
@@ -1,15 +1,162 @@
 import operator
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
 from .operator_handler import OperatorHandler
+from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
 from functools import reduce
+from enum import Enum
+from .strategy_generator import StrategyGenerator, IntermediateStrategy
+from typing import List
 
 __all__ = ['DotHandler']
 
 
+class MatMulStrategyGenerator(StrategyGenerator):
+    # TODO: to be implemented
+    pass
+
+
+class BatchedMatMulStrategyGenerator(StrategyGenerator):
+    """
+    Generate sharding strategies for the batched matrix multiplication.
+
+    A batched matrix multiplication can be viewed as
+    [b, i, k] x [b, k, j] -> [b, i, j]
+    """
+
+    def __init__(self, is_torch_bmm: bool, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.is_torch_bmm = is_torch_bmm
+
+    def split_one_batch_dim(self):
+        if 1 in self.device_mesh.mesh_shape:
+            mesh_dim = self.device_mesh.mesh_shape.index(1)
+            name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
+            dim_partition_dict = {
+                "input": {
+                    0: [mesh_dim]
+                },
+                "other": {
+                    0: [mesh_dim]
+                },
+                "bias": {},
+                "output": {
+                    0: [mesh_dim]
+                }
+            }
+            return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
+        else:
+            return None
+
+    def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
+        name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
+        dim_partition_dict = {
+            "input": {
+                0: [mesh_dim_0, mesh_dim_1]
+            },
+            "other": {
+                0: [mesh_dim_0, mesh_dim_1]
+            },
+            "bias": {},
+            "output": {
+                0: [mesh_dim_0, mesh_dim_1]
+            }
+        }
+        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
+
+    def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
+        name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
+        dim_partition_dict = {
+            "input": {
+                0: [mesh_dim_0],
+                -2: [mesh_dim_1]
+            },
+            "other": {
+                0: [mesh_dim_0]
+            },
+            "bias": {},
+            "output": {
+                0: [mesh_dim_0],
+                -2: [mesh_dim_1]
+            }
+        }
+        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
+
+    def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
+        name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
+        dim_partition_dict = {
+            "input": {
+                0: [mesh_dim_0]
+            },
+            "other": {
+                0: [mesh_dim_0],
+                -1: [mesh_dim_1]
+            },
+            "bias": {
+                -1: [mesh_dim_1]
+            },
+            "output": {
+                0: [mesh_dim_0],
+                -1: [mesh_dim_1]
+            }
+        }
+        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
+
+    def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
+        name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
+        dim_partition_dict = {
+            "input": {
+                0: [mesh_dim_0],
+                -1: [mesh_dim_1]
+            },
+            "other": {
+                0: [mesh_dim_0],
+                -2: [mesh_dim_1]
+            },
+            "bias": {},
"output": { + 0: [mesh_dim_0], + -2: [mesh_dim_1] + } + } + return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1]) + + def generate(self) -> List[IntermediateStrategy]: + strategy_list = [] + + # split only the batch dimension + # Sb = Sb x Sb + # can be None as it is only for 1D device mesh + strategy = self.split_one_batch_dim() + if strategy: + strategy_list.append(strategy) + + # split batch dim of two inputs and the i dim of the first tensor + # SbSi = SbSi x Sb + strategy_list.append(self.split_batch_dim_lhs_space(0, 1)) + strategy_list.append(self.split_batch_dim_lhs_space(1, 0)) + + # split batch dim of two inputs and the j of the second tensor + # SbSj = Sb x SbSj + strategy_list.append(self.split_batch_dim_rhs_space(0, 1)) + strategy_list.append(self.split_batch_dim_rhs_space(1, 0)) + + # split batch dim of two inputs and the k dim of two inputs + # Sb = SbSk x SbSk, need to all-reduce by k dim + strategy_list.append(self.split_batch_dim_both_contract(0, 1)) + strategy_list.append(self.split_batch_dim_both_contract(1, 0)) + + # split two batch dim + strategy_list.append(self.split_two_batch_dim(0, 1)) + strategy_list.append(self.split_two_batch_dim(1, 0)) + + return strategy_list + + class DotHandler(OperatorHandler): """ - A OperatorHandler which deals with the sharding strategies of linear matrix multiplication. + A OperatorHandler which deals with the sharding strategies for nn.Linear and F.linear. """ def __init__(self, *args, **kwargs): @@ -297,7 +450,7 @@ class DotHandler(OperatorHandler): def register_strategy(self) -> StrategiesVector: ''' - Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector. + Generate every possible strategies for a linear node, and record all strategies into the strategies_vector. Output: diff --git a/colossalai/auto_parallel/solver/op_handler/operator_handler.py b/colossalai/auto_parallel/solver/op_handler/operator_handler.py index 8db91ffef..52899c742 100644 --- a/colossalai/auto_parallel/solver/op_handler/operator_handler.py +++ b/colossalai/auto_parallel/solver/op_handler/operator_handler.py @@ -5,7 +5,6 @@ from abc import ABC, abstractmethod from torch.fx.node import Node from typing import Dict, List from colossalai.device.device_mesh import DeviceMesh -from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec from .._utils import generate_resharding_costs, generate_sharding_spec from colossalai.auto_parallel.solver.constants import * diff --git a/colossalai/auto_parallel/solver/op_handler/strategy_generator.py b/colossalai/auto_parallel/solver/op_handler/strategy_generator.py new file mode 100644 index 000000000..fce47f5ce --- /dev/null +++ b/colossalai/auto_parallel/solver/op_handler/strategy_generator.py @@ -0,0 +1,35 @@ +from dataclasses import dataclass +from abc import ABC, abstractmethod +from typing import List, Dict +from colossalai.device.device_mesh import DeviceMesh + +__all__ = ['IntermediateStrategy', 'StrategyGenerator'] + + +@dataclass +class IntermediateStrategy: + """ + IntermediateStrategy contains the subset of meta information for ShardingStrategy. It is + to store the essential information regarding the tensor sharding and leave other meta information to OperatorHandler. + + Args: + name (str): name of the sharding strategy. + dim_partition_dict (Dict[Dict]): stores the tensor to dim partition dict mapping. 
+        all_reduce_axis (List[int]): stores the mesh dimensions which require an all-reduce operation.
+    """
+    name: str
+    dim_partition_dict: Dict[str, Dict[int, List[int]]]
+    all_reduce_axis: List[int] = None
+
+
+class StrategyGenerator(ABC):
+    """
+    StrategyGenerator is used to generate a group of sharding strategies of the same kind.
+    """
+
+    def __init__(self, device_mesh: DeviceMesh):
+        self.device_mesh = device_mesh
+
+    @abstractmethod
+    def generate(self) -> List[IntermediateStrategy]:
+        pass
\ No newline at end of file
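
Beyond the diff itself, a minimal usage sketch (not part of the patch) may help reviewers see how the new pieces fit together. It assumes the patch above is applied, that a DeviceMesh can be built directly from a tensor of device ids and a mesh shape as it is elsewhere in ColossalAI, and that no distributed process group is needed just to enumerate strategies.

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.op_handler.dot_handler import BatchedMatMulStrategyGenerator

# assumption: a 2 x 2 logical device mesh over 4 device ids is enough to
# exercise the generator without initializing any process group
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

# enumerate sharding strategies for [b, i, k] x [b, k, j] -> [b, i, j]
generator = BatchedMatMulStrategyGenerator(is_torch_bmm=True, device_mesh=device_mesh)
for strategy in generator.generate():
    # each IntermediateStrategy carries a name, the per-tensor dim partition dict,
    # and optionally the mesh axes that require an all-reduce
    print(strategy.name, strategy.dim_partition_dict, strategy.all_reduce_axis)

On a 2 x 2 mesh the split_one_batch_dim candidate is skipped (it only applies when one mesh dimension has size 1), so the printed list contains the lhs-space, rhs-space, contraction, and two-batch-dim splits for both mesh-dimension orderings.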