diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/_utils.py b/colossalai/auto_parallel/tensor_shard/deprecated/_utils.py index 378a14d03..a72d97554 100644 --- a/colossalai/auto_parallel/tensor_shard/deprecated/_utils.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/_utils.py @@ -1,13 +1,15 @@ -from colossalai.tensor.shape_consistency import ShapeConsistencyManager -import torch -from torch.fx.node import Node -from colossalai.tensor.sharding_spec import ShardingSpec -from colossalai.device.device_mesh import DeviceMesh -from typing import Union, Dict, List, Optional -import warnings -from functools import reduce import functools import operator +import warnings +from functools import reduce +from typing import Dict, List, Optional, Union + +import torch +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec +from torch.fx.node import Node + from .constants import INFINITY_COST @@ -87,7 +89,7 @@ def generate_resharding_costs(nodes: List[Node], return resharding_costs -def exception_handler(func): +def ignore_sharding_exception(func): """ A function wrapper which executes the function with a specified seed. """ diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/batch_norm_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/batch_norm_handler.py index 76de2d149..519436270 100644 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/batch_norm_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/batch_norm_handler.py @@ -1,9 +1,12 @@ import operator from functools import reduce + import torch -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated._utils import \ + ignore_sharding_exception +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) + from .operator_handler import OperatorHandler -from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler __all__ = ['BatchNormHandler'] @@ -110,7 +113,7 @@ class BatchNormHandler(OperatorHandler): return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation - @exception_handler + @ignore_sharding_exception def split_input_channel(self, mesh_dim_0, mesh_dim_1): name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}' @@ -185,7 +188,7 @@ class BatchNormHandler(OperatorHandler): self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1): name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}' @@ -226,7 +229,7 @@ class BatchNormHandler(OperatorHandler): self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def non_split(self, mesh_dim_0, mesh_dim_1): name = f'RR = RR x R' @@ -322,7 +325,7 @@ class BatchNormHandler(OperatorHandler): new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name) self.strategies_vector.append(new_sharding_strategy) - @exception_handler + @ignore_sharding_exception def split_input_batch(self, mesh_dim_0): name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN' @@ -363,7 +366,7 @@ class BatchNormHandler(OperatorHandler): self.strategies_vector.append(sharding_strategies) 
- @exception_handler + @ignore_sharding_exception def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN' @@ -404,7 +407,7 @@ class BatchNormHandler(OperatorHandler): self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def split_input_both_dim(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN' diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/bcast_op_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/bcast_op_handler.py index eca6bed42..6ac6dce76 100644 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/bcast_op_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/bcast_op_handler.py @@ -1,14 +1,18 @@ import operator -from functools import reduce import warnings +from copy import deepcopy +from functools import reduce +from typing import Dict, List + import torch -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector -from .operator_handler import OperatorHandler +from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + ignore_sharding_exception) +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec -from copy import deepcopy -from typing import Dict, List -from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding + +from .operator_handler import OperatorHandler __all__ = ['BcastOpHandler'] @@ -136,7 +140,7 @@ class BcastOpHandler(OperatorHandler): return output_sharding_spec_list - @exception_handler + @ignore_sharding_exception def _register_strategy(self, output_sharding_spec): dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_input) @@ -171,7 +175,7 @@ class BcastOpHandler(OperatorHandler): ############################################## #used to generate strategies for torch.matmul# ############################################## - @exception_handler + @ignore_sharding_exception def _registry_no_split_strategies_for_matmul(self, dim_partition_dict_for_batch_dim): # this dim partition dict only describes the batch dimensions, but in this scenario, # matrix dimensions are fully replicated, so it do not need extra process. @@ -210,7 +214,7 @@ class BcastOpHandler(OperatorHandler): self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def _split_dim_i(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix): # A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j] # this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it. 
@@ -268,7 +272,7 @@ class BcastOpHandler(OperatorHandler): self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def _split_dim_k(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix): # A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j] # this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it. @@ -332,7 +336,7 @@ class BcastOpHandler(OperatorHandler): self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def _split_dim_j(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix): # A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j] # this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it. @@ -398,7 +402,7 @@ class BcastOpHandler(OperatorHandler): self._split_dim_k(dim_partition_dict, mesh_dim_list) self._split_dim_j(dim_partition_dict, mesh_dim_list) - @exception_handler + @ignore_sharding_exception def _split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): dim_partition_dict_for_lhs = {-2: [mesh_dim_0], -1: [mesh_dim_1]} sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs) @@ -435,7 +439,7 @@ class BcastOpHandler(OperatorHandler): self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def _split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): dim_partition_dict_for_lhs = {-1: [mesh_dim_0]} sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs) @@ -474,7 +478,7 @@ class BcastOpHandler(OperatorHandler): self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def _split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1): dim_partition_dict_for_lhs = {-2: [mesh_dim_0]} sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs) diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py index 1208f86d3..c41ca6370 100644 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py @@ -1,10 +1,13 @@ import operator -from functools import reduce import warnings +from functools import reduce + import torch -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated._utils import \ + ignore_sharding_exception +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) + from .operator_handler import OperatorHandler -from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler __all__ = ['ConvHandler'] @@ -105,7 +108,7 @@ class ConvHandler(OperatorHandler): return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight - @exception_handler + @ignore_sharding_exception def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' @@ -153,7 +156,7 @@ class ConvHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) 
self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def split_input_batch(self, mesh_dim_0): name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR' @@ -199,7 +202,7 @@ class ConvHandler(OperatorHandler): self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' @@ -245,7 +248,7 @@ class ConvHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1): name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' @@ -288,7 +291,7 @@ class ConvHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def split_input_in_channel_weight_in_channel(self, mesh_dim_0): name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R' @@ -331,7 +334,7 @@ class ConvHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def split_weight_out_channel(self, mesh_dim_0): name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}' @@ -374,7 +377,7 @@ class ConvHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def non_split(self): name = f'RR = RR x RR' @@ -415,7 +418,7 @@ class ConvHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' @@ -463,7 +466,7 @@ class ConvHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1): name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py index 549b58df8..4feeacd98 100644 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py @@ -1,15 +1,18 @@ import operator +from enum import Enum +from functools import reduce +from typing import List + import torch import torch.nn as nn import torch.nn.functional as F -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector -from .operator_handler import OperatorHandler +from colossalai.auto_parallel.tensor_shard.deprecated._utils import \ + ignore_sharding_exception +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) + from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP -from functools import reduce -from 
colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler -from enum import Enum -from .strategy_generator import StrategyGenerator, IntermediateStrategy -from typing import List +from .operator_handler import OperatorHandler +from .strategy_generator import IntermediateStrategy, StrategyGenerator __all__ = ['DotHandler'] @@ -415,7 +418,7 @@ class DotHandler(OperatorHandler): compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2 // total_sharding_size return compute_cost - @exception_handler + @ignore_sharding_exception def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1): # handle case SS = SR x RS name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' @@ -456,7 +459,7 @@ class DotHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): # handle the case SR = SS x SR name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' @@ -496,7 +499,7 @@ class DotHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' @@ -534,7 +537,7 @@ class DotHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def recompute_split_both_contract(self, mesh_dim): name = f'RR = RS{mesh_dim} x S{mesh_dim}R' @@ -569,7 +572,7 @@ class DotHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def split_rhs_space_only(self, mesh_dim): name = f'RS{mesh_dim} = RR x RS{mesh_dim}' @@ -605,7 +608,7 @@ class DotHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' @@ -641,7 +644,7 @@ class DotHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' @@ -678,7 +681,7 @@ class DotHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/embedding_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/embedding_handler.py index 45c001b60..d01a487ad 100644 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/embedding_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/embedding_handler.py @@ 
-1,14 +1,17 @@ import operator -from functools import reduce import warnings +from copy import deepcopy +from functools import reduce +from typing import Dict, List + import torch -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector -from .operator_handler import OperatorHandler +from colossalai.auto_parallel.tensor_shard.deprecated._utils import \ + ignore_sharding_exception +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec -from copy import deepcopy -from typing import Dict, List -from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler + +from .operator_handler import OperatorHandler __all__ = ['EmbeddingHandler'] @@ -76,7 +79,7 @@ class EmbeddingHandler(OperatorHandler): return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight - @exception_handler + @ignore_sharding_exception def split_weight_both_dim(self, mesh_dim_0, mesh_dim_1): name = f'RRS{mesh_dim_1} = RR x S{mesh_dim_0}S{mesh_dim_1}' @@ -117,7 +120,7 @@ class EmbeddingHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def split_input_both_dim(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}S{mesh_dim_1}R = S{mesh_dim_0}S{mesh_dim_1} x RR' diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py index 0d28875c7..c75fdbbb6 100644 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py @@ -1,9 +1,13 @@ import operator from functools import reduce + import torch -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + generate_sharding_size, ignore_sharding_exception) +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) + from .operator_handler import OperatorHandler -from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler, enumerate_all_possible_2d_sharding, enumerate_all_possible_1d_sharding, generate_sharding_size __all__ = ['LayerNormHandler'] @@ -149,21 +153,21 @@ class LayerNormHandler(OperatorHandler): self.strategies_vector.append(sharding_strategies) - @exception_handler + @ignore_sharding_exception def split_input_batch_single_mesh_dim(self, mesh_dim_0): batch_dimension_length = self.input_data.dim() - self.weight.dim() dim_partition_list = enumerate_all_possible_1d_sharding(mesh_dim_0, batch_dimension_length) for dim_partition in dim_partition_list: self._generate_strategy_with_dim_partition(dim_partition) - @exception_handler + @ignore_sharding_exception def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1): batch_dimension_length = self.input_data.dim() - self.weight.dim() dim_partition_list = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, batch_dimension_length) for dim_partition in 
dim_partition_list: self._generate_strategy_with_dim_partition(dim_partition) - @exception_handler + @ignore_sharding_exception def non_split(self): name = f'RR = RR x R' diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py index 2fc619c52..2d3967025 100644 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py @@ -1,14 +1,17 @@ import colorsys -from .operator_handler import OperatorHandler -from colossalai.tensor.sharding_spec import ShardingSpec -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector -from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from copy import deepcopy import math -from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler import warnings +from copy import deepcopy + import torch +from colossalai.auto_parallel.tensor_shard.deprecated._utils import \ + ignore_sharding_exception +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec + from ..constants import INFINITY_COST +from .operator_handler import OperatorHandler class ReshapeHandler(OperatorHandler): @@ -24,7 +27,7 @@ class ReshapeHandler(OperatorHandler): def _generate_compute_cost(self, *args, **kwargs): return super()._generate_compute_cost(*args, **kwargs) - @exception_handler + @ignore_sharding_exception def register_strategy(self): # TODO: add strategies with more output sharding specs other than only fully replicated. 
input_node = self.strategies_vector.predecessor_nodes[0] diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/unary_elementwise_handler.py index 57ad9e262..c929d2fad 100644 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/unary_elementwise_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/unary_elementwise_handler.py @@ -1,16 +1,20 @@ +import math import operator -from functools import reduce import warnings +from copy import deepcopy +from functools import reduce +from typing import Dict, List + import torch -from colossalai.auto_parallel.tensor_shard.deprecated.constants import INFINITY_COST -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector -from .operator_handler import OperatorHandler +from colossalai.auto_parallel.tensor_shard.deprecated._utils import \ + ignore_sharding_exception +from colossalai.auto_parallel.tensor_shard.deprecated.constants import \ + INFINITY_COST +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec -from copy import deepcopy -from typing import Dict, List -import math -from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler + +from .operator_handler import OperatorHandler __all__ = ['UnaryElementwiseHandler'] @@ -40,7 +44,7 @@ class UnaryElementwiseHandler(OperatorHandler): def _generate_compute_cost(self, *args, **kwargs): return super()._generate_compute_cost(*args, **kwargs) - @exception_handler + @ignore_sharding_exception def register_strategy(self): # TODO: integrate element-wise func and module together # create sharding strategy for element-wise function diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py index bd97e2736..6991e913d 100644 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py @@ -6,12 +6,10 @@ from typing import Dict, List import torch -from colossalai.auto_parallel.tensor_shard.deprecated._utils import ( - enumerate_all_possible_1d_sharding, - enumerate_all_possible_2d_sharding, - exception_handler, -) -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + ignore_sharding_exception) +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec @@ -146,7 +144,7 @@ class WhereHandler(OperatorHandler): return output_sharding_spec_list - @exception_handler + @ignore_sharding_exception def _register_strategy(self, output_sharding_spec): dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict sharding_spec_for_condition = self._generate_sharding_spec(self.condition_data, dim_partition_dict_for_input) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/dot_handler.py 
b/colossalai/auto_parallel/tensor_shard/node_handler/dot_handler.py
index ad90e4e5b..3de03f440 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/dot_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/dot_handler.py
@@ -5,7 +5,8 @@ import torch
 import torch.nn.functional as F
 
 from colossalai.auto_parallel.tensor_shard.utils import (switch_partition_dim, update_partition_dim)
-from colossalai.tensor.sharding_spec import ShardingException
+from colossalai.logging import get_dist_logger
+from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
 
 from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy)
 from .node_handler import ModuleHandler, NodeHandler
@@ -15,6 +16,100 @@ from .strategy import (BatchedMatMulStrategyGenerator, LinearProjectionStrategyG
 __all__ = ['LinearModuleHandler', 'LinearFunctionHandler', 'BMMFunctionHandler']
 
 
+def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStrategy,
+                                                           weight_name: str) -> ShardingStrategy:
+    """
+    This is a helper function used by both the module node handler and the function node handler. It converts
+    the sharding spec for the transposed weight to the correct partition spec.
+
+    Args:
+        strategy (ShardingStrategy): the strategy generated by the strategy generator.
+        weight_name (str): the name of the OperationData object for the weight.
+    """
+    # switch the dimensions of the transposed weight
+    sharding_spec = strategy.get_sharding_spec_by_name(weight_name)
+    op_data = strategy.get_op_data_by_name(weight_name)
+    assert op_data.logical_shape != op_data.data.shape, \
+        "Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same"
+    switch_partition_dim(sharding_spec, 0, -1)
+    return strategy
+
+
+def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: ShardingStrategy, input_name: str,
+                                                                   output_name: str) -> List[ShardingStrategy]:
+    """
+    This function converts the logical sharding spec to the physical sharding spec for both the input and output
+    of the linear operation. The input and output should have the same sharding spec.
+
+    Args:
+        strategy (ShardingStrategy): the logical strategy generated by the strategy generator.
+        input_name (str): the name of the OperationData object for the input.
+        output_name (str): the name of the OperationData object for the output.
+    """
+    # the result will be a list of strategies
+    sharding_strategies = []
+
+    # get operation data
+    input_op_data = strategy.get_op_data_by_name(input_name)
+    output_op_data = strategy.get_op_data_by_name(output_name)
+    input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
+
+    # get logger for debug message
+    logger = get_dist_logger()
+
+    # the input of the linear operation can be multi-dimensional, but the sharding spec generated is only
+    # 2D, where the first dimension is the non-matrix dimension and the last dimension is the matrix dimension.
+    # the logical non-matrix dimension can belong to the 0th to (N-1)th dimension of the physical input shape.
+    # Thus, we enumerate to get all possible cases.
+    if 0 in input_sharding_spec.dim_partition_dict:
+        # if 0 is in the dim_partition_dict, it means that the generated sharding strategy
+        # does shard the non-matrix dimension, so in this case we need to do enumeration
+        num_input_dims = input_op_data.data.dim()
+        for i in range(num_input_dims - 1):
+            strategy_copy = strategy.clone()
+            input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name)
+            output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
+            try:
+                # replace the 0th dimension in the logical sharding with the ith dimension in the physical sharding
+                update_partition_dim(sharding_spec=input_sharding_spec,
+                                     dim_mapping={0: i},
+                                     physical_shape=input_op_data.data.shape,
+                                     inplace=True)
+                update_partition_dim(sharding_spec=output_sharding_spec,
+                                     dim_mapping={0: i},
+                                     physical_shape=output_op_data.data.shape,
+                                     inplace=True)
+                sharding_strategies.append(strategy_copy)
+            except ShardingNotDivisibleError as e:
+                logger.debug(
+                    f'Error occurred when converting the logical sharding spec to the physical one. Error details: {e}'
+                )
+    else:
+        # the generated sharding strategy does not shard the non-matrix dimension,
+        # in this case, we don't need to do enumeration
+        # but instead, we still need to convert the logical shape to the physical shape
+        strategy_copy = strategy.clone()
+        input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name)
+        output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
+
+        # after updating, the logical shape will be replaced by the physical shape
+        update_partition_dim(sharding_spec=input_sharding_spec,
+                             dim_mapping={},
+                             physical_shape=input_op_data.data.shape,
+                             inplace=True)
+        update_partition_dim(sharding_spec=output_sharding_spec,
+                             dim_mapping={},
+                             physical_shape=output_op_data.data.shape,
+                             inplace=True)
+        sharding_strategies.append(strategy_copy)
+    return sharding_strategies
+
+
 @operator_registry.register(torch.nn.Linear)
 class LinearModuleHandler(ModuleHandler):
     """
@@ -58,44 +153,20 @@ class LinearModuleHandler(ModuleHandler):
 
     def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
         """
-        Convert the sharding spec from the logical shape to the physical shape.
+        Convert the sharding spec from the logical shape to the physical shape. In this function, two tasks are completed:
+        1. the sharding spec is updated for the transposed weight
+        2. the input and output sharding specs are updated to physical shape.
""" # switch the dimensions of the transposed weight - for op_data, sharding_spec in strategy.input_sharding_specs.items(): - if op_data.name == "weight": - assert op_data.logical_shape != op_data.data.shape - switch_partition_dim(sharding_spec, 0, -1) + strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name='weight') # create multiple sharding strategies for the inputs # as input can be multi-dimensinal and the partition dim is only 2D, # we need to map the partition at dim 0 to one of the first few dimensions of the input - sharding_strategies = [] - input_op_data = strategy.get_op_data_by_name(str(self.node.args[0])) - output_op_data = strategy.get_op_data_by_name(str(self.node)) - num_input_dims = input_op_data.data.dim() - input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name) - - if 0 in input_sharding_spec.dim_partition_dict: - for i in range(num_input_dims - 1): - new_strategy = strategy.clone() - input_sharding_spec = new_strategy.get_sharding_spec_by_name(input_op_data.name) - output_sharding_spec = new_strategy.get_sharding_spec_by_name(output_op_data.name) - try: - update_partition_dim(sharding_spec=input_sharding_spec, - dim_mapping={0: i}, - physical_shape=input_op_data.data.shape, - inplace=True) - update_partition_dim(sharding_spec=output_sharding_spec, - dim_mapping={0: i}, - physical_shape=output_op_data.data.shape, - inplace=True) - sharding_strategies.append(new_strategy) - except ShardingException: - pass - else: - sharding_strategies.append(strategy) - - return sharding_strategies + strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy, + input_name=str(self.node.args[0]), + output_name=str(self.node)) + return strategies @operator_registry.register(F.linear) @@ -113,9 +184,12 @@ class LinearFunctionHandler(NodeHandler): def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process + input_meta_data = self.node.args[0]._meta_data + input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape physical_input_operand = OperationData(name=str(self.node.args[0]), type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) + data=self.node.args[0]._meta_data, + logical_shape=input_logical_shape) # check if the other operand is a parameter if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter): @@ -144,44 +218,17 @@ class LinearFunctionHandler(NodeHandler): return mapping def post_process(self, strategy: ShardingStrategy): - """ - Convert the sharding spec of the weight parameter back to its original shape. 
- """ - for op_data, sharding_spec in strategy.input_sharding_specs.items(): - if op_data.name == str(self.node.args[1]): - assert op_data.logical_shape != op_data.data.shape - switch_partition_dim(sharding_spec, 0, -1) + # switch the dimensions of the transposed weight + strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, + weight_name=str(self.node.args[1])) # create multiple sharding strategies for the inputs # as input can be multi-dimensinal and the partition dim is only 2D, # we need to map the partition at dim 0 to one of the first few dimensions of the input - sharding_strategies = [] - input_op_data = strategy.get_op_data_by_name(str(self.node.args[0])) - output_op_data = strategy.get_op_data_by_name(str(self.node)) - num_input_dims = input_op_data.data.dim() - input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name) - - if 0 in input_sharding_spec.dim_partition_dict: - for i in range(num_input_dims - 1): - new_strategy = strategy.clone() - input_sharding_spec = new_strategy.get_sharding_spec_by_name(input_op_data.name) - output_sharding_spec = new_strategy.get_sharding_spec_by_name(output_op_data.name) - try: - update_partition_dim(sharding_spec=input_sharding_spec, - dim_mapping={0: i}, - physical_shape=input_op_data.data.shape, - inplace=True) - update_partition_dim(sharding_spec=output_sharding_spec, - dim_mapping={0: i}, - physical_shape=output_op_data.data.shape, - inplace=True) - sharding_strategies.append(new_strategy) - except ShardingException: - pass - else: - sharding_strategies.append(strategy) - - return strategy + strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy, + input_name=str(self.node.args[0]), + output_name=str(self.node)) + return strategies @operator_registry.register(torch.bmm) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py index 4c4a0c3ea..716ffc917 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py @@ -1,6 +1,7 @@ import copy import operator from functools import reduce +from typing import List from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) from colossalai.tensor.shape_consistency import CollectiveCommPattern @@ -292,7 +293,7 @@ class BatchNormStrategyGenerator(StrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) - def generate(self): + def collate_strategies(self) -> List[ShardingStrategy]: ''' Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector. 
''' @@ -325,9 +326,4 @@ class BatchNormStrategyGenerator(StrategyGenerator): # S01R = S01R x R WITH SYNC_BN # strategy_list.append(self.split_input_batch_1d(0, 1)) - for strategy in strategy_list: - self.update_communication_cost(strategy) - self.update_compute_cost(strategy) - self.update_memory_cost(strategy) - return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py index fe40cc1a9..427eea671 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py @@ -5,7 +5,8 @@ from functools import reduce from typing import List from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) -from colossalai.auto_parallel.tensor_shard.utils import exception_handler +from colossalai.auto_parallel.tensor_shard.utils import \ + ignore_sharding_exception from colossalai.tensor.shape_consistency import CollectiveCommPattern from .strategy_generator import StrategyGenerator @@ -25,8 +26,8 @@ class ConvStrategyGenerator(StrategyGenerator): For Conv3d, the dim of input data should be 5([N, C, H, W, D]). ''' input_op_data = self.op_data['input'] - assert input_op_data.dim() in (3, 4, - 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' + assert input_op_data.data.dim() in ( + 3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' def update_compute_cost(self, strategy: ShardingStrategy): ''' @@ -99,7 +100,7 @@ class ConvStrategyGenerator(StrategyGenerator): memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost - @exception_handler + @ignore_sharding_exception def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' @@ -146,7 +147,7 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) - @exception_handler + @ignore_sharding_exception def split_input_batch(self, mesh_dim_0): name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR' @@ -183,7 +184,7 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) - @exception_handler + @ignore_sharding_exception def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' @@ -230,7 +231,7 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) - @exception_handler + @ignore_sharding_exception def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1): name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' @@ -270,7 +271,7 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) - @exception_handler + @ignore_sharding_exception def split_input_in_channel_weight_in_channel(self, mesh_dim_0): name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R' @@ -301,7 +302,7 @@ class ConvStrategyGenerator(StrategyGenerator): 
sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) - @exception_handler + @ignore_sharding_exception def split_weight_out_channel(self, mesh_dim_0): name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}' @@ -334,7 +335,7 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) - @exception_handler + @ignore_sharding_exception def non_split(self): name = f'RR = RR x RR' @@ -353,7 +354,7 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={}) - @exception_handler + @ignore_sharding_exception def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' @@ -391,7 +392,7 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) - @exception_handler + @ignore_sharding_exception def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1): name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' dim_partition_dict_mapping = { @@ -421,7 +422,7 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) - @exception_handler + @ignore_sharding_exception def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1): name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' dim_partition_dict_mapping = { @@ -453,7 +454,7 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) - def generate(self) -> List[ShardingStrategy]: + def collate_strategies(self) -> List[ShardingStrategy]: strategies = [] # SS = SR x RS strategies.append(self.split_input_batch_weight_out_channel(0, 1)) @@ -491,20 +492,4 @@ class ConvStrategyGenerator(StrategyGenerator): # RS01 = RR x RS01 strategies.append(self.split_1d_parallel_on_out_channel(0, 1)) - rm_list = [strategy for strategy in strategies if strategy is None] - for rm_element in rm_list: - strategies.remove(rm_element) - illegal_strategy_list = [] - # update mete info on cost - for strategy in strategies: - try: - self.update_communication_cost(strategy) - self.update_compute_cost(strategy) - self.update_memory_cost(strategy) - except AssertionError as e: - illegal_strategy_list.append(strategy) - warnings.warn(f'{e}') - for strategy in illegal_strategy_list: - strategies.remove(strategy) - return strategies diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py index dae168cbb..8b8080b75 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py @@ -1,4 +1,5 @@ import copy +from typing import List from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) from colossalai.tensor.shape_consistency import CollectiveCommPattern @@ -61,7 +62,7 @@ class TensorStrategyGenerator(GetItemStrategyGenerator): Deal with case 1 and 2. 
''' - def generate(self): + def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] for strategy in self.predecessor_node.strategies_vector: dim_partition_dict_mapping = {} @@ -109,7 +110,7 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator): Deal with case 3. ''' - def generate(self): + def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] index = self.op_data["index"].data @@ -133,9 +134,4 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator): strategy_list.append(strategy) - for strategy in strategy_list: - self.update_communication_cost(strategy) - self.update_compute_cost(strategy) - self.update_memory_cost(strategy) - return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py index cf7530fa6..8c7d11437 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py @@ -1,6 +1,7 @@ import copy import operator from functools import reduce +from typing import List from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding, @@ -159,7 +160,7 @@ class LayerNormGenerator(StrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) - def generate(self): + def collate_strategies(self) -> List[ShardingStrategy]: ''' Generate every possible strategies for a LayerNorm node, and record all strategies into the strategies_vector. 
''' @@ -178,11 +179,5 @@ class LayerNormGenerator(StrategyGenerator): # RR = RR x R strategy_list.append(self.non_split()) - # update mete info on cost - - for strategy in strategy_list: - self.update_communication_cost(strategy) - self.update_compute_cost(strategy) - self.update_memory_cost(strategy) return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py index 26fcacc57..175ef6631 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py @@ -3,6 +3,8 @@ from functools import reduce from typing import List from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) +from colossalai.auto_parallel.tensor_shard.utils import \ + ignore_sharding_exception from colossalai.tensor.shape_consistency import CollectiveCommPattern from .strategy_generator import StrategyGenerator @@ -169,7 +171,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): total=fwd_compute_cost + bwd_compute_cost) strategy.compute_cost = compute_cost - def generate(self) -> List[ShardingStrategy]: + def collate_strategies(self) -> List[ShardingStrategy]: strategies = [] # SS = SR x RS @@ -201,14 +203,9 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): # RS01 = RR x RS01 strategies.append(self.split_rhs_2nd_dim_1d(0, 1)) - # update mete info on cost - for strategy in strategies: - self.update_communication_cost(strategy) - self.update_compute_cost(strategy) - self.update_memory_cost(strategy) - return strategies + @ignore_sharding_exception def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1): # handle case SS = SR x RS name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' @@ -249,6 +246,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @ignore_sharding_exception def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): # handle the case SR = SS x SR name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' @@ -289,6 +287,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @ignore_sharding_exception def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' @@ -324,6 +323,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @ignore_sharding_exception def recompute_split_both_contract(self, mesh_dim): name = f'RR = RS{mesh_dim} x S{mesh_dim}R' @@ -351,6 +351,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @ignore_sharding_exception def split_rhs_space_only(self, mesh_dim): name = f'RS{mesh_dim} = RR x RS{mesh_dim}' @@ -380,6 +381,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @ignore_sharding_exception 
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' # get sharding spec @@ -410,6 +412,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communcation_action_mapping) + @ignore_sharding_exception def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' @@ -437,6 +440,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @ignore_sharding_exception def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' @@ -542,7 +546,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): sharding_spec=sharding_spec_mapping['bias'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1]) - communication_action_mappingp['bias'] = bias_comm_spec + communication_action_mapping['bias'] = bias_comm_spec return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -662,7 +666,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) - def generate(self) -> List[ShardingStrategy]: + def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] device_mesh_is_1d = True if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape: diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py index 59f6d89b5..457f51450 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py @@ -25,8 +25,8 @@ class NormalPoolStrategyGenerator(StrategyGenerator): For Pool3d, the dim of input data should be 5([N, C, H, W, D]). ''' input_op_data = self.op_data['input'] - assert input_op_data.dim() in (3, 4, - 5), f'We suppose the dim of input fed into Pool op should in range of [3, 5].' + assert input_op_data.data.dim() in ( + 3, 4, 5), f'We suppose the dim of input fed into Pool op should in range of [3, 5].' 
def update_compute_cost(self, strategy: ShardingStrategy) -> TrainCycleItem: ''' @@ -103,7 +103,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator): return dim_partition_list - def generate(self) -> List[ShardingStrategy]: + def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] dim_partition_list = self.enumerate_all_possible_batch_dimensions_dim_partition(0, 1) @@ -111,9 +111,4 @@ class NormalPoolStrategyGenerator(StrategyGenerator): strategy = self._generate_strategy_with_dim_partition(dim_partition) strategy_list.append(strategy) - for strategy in strategy_list: - self.update_communication_cost(strategy) - self.update_compute_cost(strategy) - self.update_memory_cost(strategy) - return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py index 3d58f0f11..de9dfba67 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py @@ -1,3 +1,5 @@ +from typing import List + from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) from .strategy_generator import OutputStrategyGenerator @@ -30,7 +32,7 @@ class OutputGenerator(OutputStrategyGenerator): memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost - def generate(self): + def collate_strategies(self) -> List[ShardingStrategy]: dim_partition_dict_mapping = { "output": {}, } @@ -47,8 +49,4 @@ class OutputGenerator(OutputStrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) - self.update_communication_cost(strategy) - self.update_compute_cost(strategy) - self.update_memory_cost(strategy) - return [strategy] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py index a106488a8..9023ab0fb 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py @@ -1,3 +1,5 @@ +from typing import List + from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) from .strategy_generator import StrategyGenerator @@ -35,7 +37,7 @@ class PlaceholderGenerator(StrategyGenerator): memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost - def generate(self): + def collate_strategies(self) -> List[ShardingStrategy]: dim_partition_dict_mapping = { "output": {}, } @@ -48,8 +50,4 @@ class PlaceholderGenerator(StrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) - self.update_communication_cost(strategy) - self.update_compute_cost(strategy) - self.update_memory_cost(strategy) - return [strategy] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py index eb5636fc8..8fa5a8137 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py +++ 
b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py @@ -1,4 +1,5 @@ import copy +from typing import List from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) from colossalai.tensor.shape_consistency import CollectiveCommPattern @@ -49,7 +50,7 @@ class ReshapeGenerator(FollowingStrategyGenerator): memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost - def generate(self): + def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] # For reshape function, to keep the computing correctness we keep the sharding # spec of input is fully replicated. In addition, we will keep the output in diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py index 02ecbc9cc..6196e8336 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py @@ -4,13 +4,12 @@ from functools import reduce from typing import Any, Dict, List, Union import torch -from torch.fx import Node - from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, ShardingStrategy, TrainCycleItem) from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec from colossalai.tensor.sharding_spec import ShardingSpec +from torch.fx import Node class StrategyGenerator(ABC): @@ -24,6 +23,9 @@ class StrategyGenerator(ABC): self.op_data = operation_data_mapping self.device_mesh = device_mesh + # validate the whether operation data is of desired value + self.validate() + @property def has_bias(self): """ @@ -102,9 +104,9 @@ class StrategyGenerator(ABC): comm_cost = TrainCycleItem(fwd=0, bwd=0, total=0) - def _compute_and_add(data: OperationData, comm_spec: CommSpec): + def _compute_and_add(op_data: OperationData, comm_spec: CommSpec): num_ele_in_comm = comm_spec.get_comm_cost() - dtype = operand.data.dtype + dtype = op_data.data.dtype size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() for phase, cost in num_ele_in_comm.items(): num_ele_in_comm[phase] = num_ele_in_comm[phase] * size_per_elem_bytes @@ -151,11 +153,30 @@ class StrategyGenerator(ABC): size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() return reduce(operator.mul, sharded_shape) * size_per_elem_bytes - @abstractmethod def generate(self) -> List[ShardingStrategy]: """ Generate all possible sharding strategies for this operation. """ + strategies = self.collate_strategies() + + # some strategies may be None as ignore_sharding_exception may return None + # when ShardingSpecException occurs. 
+ # thus, remove those None values + strategies = [strategy for strategy in strategies if strategy] + + # update the costs + # update meta info on cost + # these update methods are all in-place, the default method will do nothing + # the cost info will only be added if the child class overrides these methods + for strategy in strategies: + self.update_communication_cost(strategy) + self.update_compute_cost(strategy) + self.update_memory_cost(strategy) + + return strategies + + @abstractmethod + def collate_strategies(self) -> List[ShardingStrategy]: pass @abstractmethod diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py index ea582588b..b867a3068 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py @@ -1,4 +1,5 @@ import copy +from typing import List from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) @@ -48,7 +49,7 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator): memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost - def generate(self): + def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] # For element-wise function, we keep the sharding spec of output node same as # the input. Therefore, the different strategies of input node with same @@ -73,9 +74,4 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator): communication_action_mapping=communication_action_mapping) strategy_list.append(strategy) - for strategy in strategy_list: - self.update_communication_cost(strategy) - self.update_compute_cost(strategy) - self.update_memory_cost(strategy) - return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py index 48471cd73..95c8e2efa 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py @@ -1,4 +1,5 @@ import copy +from typing import List from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding, @@ -78,7 +79,7 @@ class WhereGenerator(StrategyGenerator): return dim_partition_list - def generate(self): + def collate_strategies(self) -> List[ShardingStrategy]: ''' Generate every possible strategies for a where node, and record all strategies into the strategies_vector.
''' @@ -90,9 +91,4 @@ class WhereGenerator(StrategyGenerator): strategy = self._generate_strategy_with_dim_partition(dim_partition) strategy_list.append(strategy) - for strategy in strategy_list: - self.update_communication_cost(strategy) - self.update_compute_cost(strategy) - self.update_memory_cost(strategy) - return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/utils/__init__.py b/colossalai/auto_parallel/tensor_shard/utils/__init__.py index bfa9c18f4..d56b80a09 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/utils/__init__.py @@ -1,12 +1,12 @@ from .broadcast import (BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape) from .factory import generate_resharding_costs, generate_sharding_spec -from .misc import exception_handler +from .misc import ignore_sharding_exception from .sharding import (enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding, generate_sharding_size, switch_partition_dim, update_partition_dim) __all__ = [ 'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape', - 'generate_resharding_costs', 'generate_sharding_spec', 'exception_handler', 'switch_partition_dim', + 'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'switch_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding', 'enumerate_all_possible_2d_sharding', 'generate_sharding_size' ] diff --git a/colossalai/auto_parallel/tensor_shard/utils/misc.py b/colossalai/auto_parallel/tensor_shard/utils/misc.py index 7c8e530ff..d174988b8 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/misc.py +++ b/colossalai/auto_parallel/tensor_shard/utils/misc.py @@ -1,16 +1,19 @@ import functools -import warnings -__all__ = ['exception_handler'] +from colossalai.logging import get_dist_logger +from colossalai.tensor.sharding_spec import ShardingSpecException + +__all__ = ['ignore_sharding_exception'] -def exception_handler(func): +def ignore_sharding_exception(func): """ - A function wrapper to handle the AssertionError in the function. + A function wrapper to handle the ShardingSpecException in the function. + If ShardingSpecException occurs, this function will return None. Usage: # mute the assertion error in the function - @exception_handler + @ignore_sharding_exception def do_something(): ... 
""" @@ -18,9 +21,11 @@ def exception_handler(func): @functools.wraps(func) def wrapper(*args, **kwargs): try: + logger = get_dist_logger() rst = func(*args, **kwargs) return rst - except AssertionError as e: - warnings.warn(f'{e}') + except ShardingSpecException as e: + logger.debug(e) + return None return wrapper diff --git a/colossalai/tensor/sharding_spec.py b/colossalai/tensor/sharding_spec.py index fe33baf65..fababb6e7 100644 --- a/colossalai/tensor/sharding_spec.py +++ b/colossalai/tensor/sharding_spec.py @@ -1,10 +1,12 @@ -import torch -from colossalai.device.device_mesh import DeviceMesh -from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator +import operator from copy import deepcopy from enum import Enum from functools import reduce -import operator + +import torch + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.utils import (all_gather_simulator, all_to_all_simulator, shard_simulator) __all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec'] @@ -138,7 +140,19 @@ class _DimSpec: return difference -class ShardingException(Exception): +class ShardingSpecException(Exception): + pass + + +class ShardingOutOfIndexError(ShardingSpecException): + pass + + +class DuplicatedShardingDimensionError(ShardingSpecException): + pass + + +class ShardingNotDivisibleError(ShardingSpecException): pass @@ -156,7 +170,11 @@ class ShardingSpec: sharding_sequence(List[_DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1]. ''' - def __init__(self, device_mesh, entire_shape, dim_partition_dict=None, sharding_sequence=None): + def __init__(self, + device_mesh: DeviceMesh, + entire_shape: torch.Size, + dim_partition_dict=None, + sharding_sequence=None): self.device_mesh = device_mesh self.entire_shape = entire_shape self.dim_partition_dict = dim_partition_dict @@ -174,19 +192,36 @@ class ShardingSpec: return ' '.join(res_list) def _sanity_check(self): - ''' - In sanity check, we need make sure all axes in logical device mesh only be used - once. - ''' - dim_check_list = [i for i in range(self.device_mesh.logical_mesh_id.dim())] + # make sure all axes in logical device mesh only be used once + dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim())) for dim, shard_list in self.dim_partition_dict.items(): for element in shard_list: if element in dim_check_list: dim_check_list.remove(element) else: - raise ValueError( + raise DuplicatedShardingDimensionError( f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.") + # make sure that the dimension is not out of index + for dim in self.dim_partition_dict.keys(): + if dim >= len(self.entire_shape): + raise ShardingOutOfIndexError( + f"The dim_partition_dict specifies to shard dimension {dim} but the entire_shape only has {len(self.entire_shape)} dimensions" + ) + + # make sure that the sharding for a dimension is divisible by the number of devices + for dim, shard_list in self.dim_partition_dict.items(): + tensor_dim_size = self.entire_shape[dim] + num_devices = 1 + + for element in shard_list: + num_devices *= self.device_mesh.mesh_shape[element] + + if tensor_dim_size % num_devices != 0: + raise ShardingNotDivisibleError( + f'The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.' + ) + def convert_dict_to_shard_sequence(self): ''' Convert dim_partition_dict into list of _DimSpec, and assign it to sharding_sequence. 
diff --git a/tests/test_auto_parallel/__init__.py b/tests/test_auto_parallel/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_auto_parallel/test_tensor_shard/__init__.py b/tests/test_auto_parallel/test_tensor_shard/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_handler.py index f83d7ceb7..7adc211cf 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_handler.py @@ -1,12 +1,15 @@ -import torch -from torch.fx import GraphModule -import torch.nn as nn +from cProfile import run + import pytest +import torch +import torch.nn as nn +from torch.fx import GraphModule from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor -from colossalai.fx.tracer.tracer import ColoTracer from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.testing.pytest_wrapper import run_on_environment_flag class ConvModel(nn.Module): @@ -27,6 +30,7 @@ class ConvModel(nn.Module): return x +@run_on_environment_flag(name='AUTO_PARALLEL') def test_conv_handler(): physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_matmul.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_matmul.py index 27120f0ba..426d179f1 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_matmul.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_matmul.py @@ -1,12 +1,13 @@ -import torch -from torch.fx import GraphModule -import torch.nn as nn import pytest +import torch +import torch.nn as nn +from torch.fx import GraphModule from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor -from colossalai.fx.tracer.tracer import ColoTracer from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.testing.pytest_wrapper import run_on_environment_flag class MatmulModel(nn.Module): @@ -20,6 +21,7 @@ class MatmulModel(nn.Module): return x +@run_on_environment_flag(name='AUTO_PARALLEL') def test_conv_handler(): physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/__init__.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/common.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/common.py new file mode 100644 index 000000000..695f79722 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/common.py @@ -0,0 +1,37 @@ +import 
torch + +from colossalai.tensor.sharding_spec import ShardingSpec + + +def is_sharding_spec_valid(sharding_spec: ShardingSpec, tensor: torch.Tensor): + """ + This function checks whether the ShardingSpec is valid for the physical tensor. + This check includes 2 items: + 1. the sharding spec covers all dimensions of the physical tensor + 2. the sharding spec for each dimension is divisible by the number of devices. + # + """ + # make sure all dims are covered in sharding spec + sharding_len = len(sharding_spec.sharding_sequence) + tensor_num_dim = tensor.dim() + num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0] + num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1] + assert sharding_len == tensor_num_dim, \ + f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).' + + # make sure the sharding is valid for each dim + for i in range(tensor_num_dim): + dim_size = tensor.shape[i] + dim_spec = sharding_spec.sharding_sequence[i] + + if str(dim_spec).startswith('S'): + devices_str = str(dim_spec).lstrip('S') + num_devices = 1 + + if '0' in devices_str: + num_devices *= num_devices_in_col + if '1' in devices_str: + num_devices *= num_devices_in_row + + assert dim_size >= num_devices and dim_size % num_devices == 0, \ + f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.' diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py index fdd6a5198..8934571f9 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py @@ -6,12 +6,13 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationDa StrategiesVector) from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer -from colossalai.fx.tracer.meta_patch.patched_module import linear -from colossalai.tensor.sharding_spec import ShardingSpec +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.common import \ + is_sharding_spec_valid def test_linear_module_handler(): model = nn.Sequential(nn.Linear(16, 32).to('meta')) + tracer = ColoTracer() graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')}) gm = ColoGraphModule(model, graph) @@ -91,6 +92,12 @@ def test_linear_module_handler(): bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') output_sharding_spec = strategy.get_sharding_spec_by_name('_0') + # make sure the sharding spec is valid + is_sharding_spec_valid(input_sharding_spec, torch.rand(2, 2, 4, 16)) + is_sharding_spec_valid(weight_sharding_spec, model.get_parameter('0.weight')) + is_sharding_spec_valid(bias_sharding_spec, model.get_parameter('0.bias')) + is_sharding_spec_valid(output_sharding_spec, torch.rand([2, 2, 4, 32])) + # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] @@ -101,7 +108,7 @@ def test_linear_module_handler(): def test_linear_function_handler(): model = nn.Linear(16, 32).to('meta') tracer = ColoTracer() - graph = tracer.trace(model, meta_args={"input": torch.rand(4, 
16).to('meta')}) + graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')}) gm = ColoGraphModule(model, graph) physical_mesh_id = torch.arange(0, 4) @@ -117,11 +124,13 @@ def test_linear_function_handler(): # # check operation data mapping mapping = handler.get_operation_data_mapping() + print(mapping['input'].logical_shape) + assert mapping['input'].name == "input_1" assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 16]) + assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16]) assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 16]) + assert mapping['input'].logical_shape == torch.Size([16, 16]) assert mapping['other'].name == "weight" assert mapping['other'].data.is_meta @@ -137,7 +146,7 @@ def test_linear_function_handler(): assert mapping['output'].name == "linear" assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 32]) + assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32]) assert mapping['output'].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) @@ -167,11 +176,18 @@ def test_linear_function_handler(): for strategy in strategies_vector: strategy: ShardingStrategy + print(strategy) input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + # make sure the sharding spec is valid + is_sharding_spec_valid(input_sharding_spec, torch.rand(2, 2, 4, 16)) + is_sharding_spec_valid(weight_sharding_spec, model.get_parameter('weight')) + is_sharding_spec_valid(bias_sharding_spec, model.get_parameter('bias')) + is_sharding_spec_valid(output_sharding_spec, torch.rand([2, 2, 4, 32])) + # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py index 07f53a6cb..b27c0e412 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py @@ -1,6 +1,5 @@ import torch import torch.nn as nn - from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \ ConvFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import \ diff --git a/tests/test_tensor/test_sharded_linear.py b/tests/test_tensor/test_sharded_linear.py index 96f7d8c0f..7aedb0d5e 100644 --- a/tests/test_tensor/test_sharded_linear.py +++ b/tests/test_tensor/test_sharded_linear.py @@ -1,16 +1,18 @@ +from functools import partial from lib2to3 import pgen2 -import colossalai -import torch + import pytest +import torch import torch.multiprocessing as mp import torch.nn.functional as F + +import colossalai +from colossalai.device.device_mesh import DeviceMesh +from colossalai.nn._ops._utils import gather_forward_split_backward +from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup 
+from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port -from functools import partial -from colossalai.device.device_mesh import DeviceMesh -from colossalai.tensor.sharding_spec import ShardingSpec -from colossalai.tensor import ColoTensor, ColoParameter, ProcessGroup -from colossalai.nn._ops._utils import gather_forward_split_backward def run_dist(rank, world_size, port): @@ -18,7 +20,7 @@ def run_dist(rank, world_size, port): colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # create mlp vars - x = ColoTensor.from_torch_tensor(torch.rand(2, 4, 8, requires_grad=True)).cuda() + x = ColoTensor.from_torch_tensor(torch.rand(4, 4, 8, requires_grad=True)).cuda() w = ColoParameter.from_torch_tensor(torch.rand(16, 8, requires_grad=True)).cuda() b = ColoParameter.from_torch_tensor(torch.rand(16, requires_grad=True)).cuda() diff --git a/tests/test_tensor/test_sharding_spec.py b/tests/test_tensor/test_sharding_spec.py index 1a84c8a27..909c84ef0 100644 --- a/tests/test_tensor/test_sharding_spec.py +++ b/tests/test_tensor/test_sharding_spec.py @@ -1,6 +1,7 @@ import torch -from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec + from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec def test_sharding_spec(): @@ -11,7 +12,7 @@ def test_sharding_spec(): # [8, 9, 10,11], # [12,13,14,15]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - entire_shape = torch.Size((4, 8, 6)) + entire_shape = torch.Size((16, 8, 6)) dim_partition_dict = {0: [0, 1]} # DistSpec: # shard_sequence: S01,R,R
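Taken together, the refactor means a concrete generator only enumerates candidate strategies, while the base class owns the None-filtering and cost bookkeeping that each generate() method used to repeat. A rough sketch of the resulting subclass shape, with hypothetical names (MyGenerator, _build_one_strategy) invented for illustration and the remaining hooks such as validate() and the cost updates omitted:

from typing import List

from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import StrategyGenerator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy
from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception


class MyGenerator(StrategyGenerator):
    # hypothetical subclass; other abstract hooks are omitted for brevity

    @ignore_sharding_exception
    def _build_one_strategy(self, dim_partition_dict) -> ShardingStrategy:
        # constructing the sharding specs here may raise a ShardingSpecException;
        # the decorator turns that into a None entry instead of aborting
        ...

    def collate_strategies(self) -> List[ShardingStrategy]:
        # only enumerate candidates; generate() in the base class drops the None
        # entries and then calls the communication/compute/memory cost updates
        return [self._build_one_strategy(d) for d in ({}, {0: [0]}, {0: [0, 1]})]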