diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
index a5e3f649a..87bd8966b 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
@@ -11,6 +11,7 @@ from .layer_norm_handler import LayerNormModuleHandler
 from .linear_handler import LinearFunctionHandler, LinearModuleHandler
 from .matmul_handler import MatMulHandler
 from .normal_pooling_handler import NormPoolingHandler
+from .option import ShardOption
 from .output_handler import OutputHandler
 from .placeholder_handler import PlaceholderHandler
 from .registry import operator_registry
@@ -27,5 +28,5 @@ __all__ = [
     'UnaryElementwiseHandler', 'ReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
     'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
     'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
-    'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler'
+    'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'ShardOption'
 ]
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
index 78dc58c90..fbab2b61e 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
@@ -5,6 +5,7 @@ import torch
 from torch.fx.node import Node
 
 from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
+from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
     OperationDataType,
@@ -35,12 +36,14 @@ class NodeHandler(ABC):
         node: Node,
         device_mesh: DeviceMesh,
         strategies_vector: StrategiesVector,
+        shard_option: ShardOption = ShardOption.STANDARD,
     ) -> None:
         self.node = node
         self.predecessor_node = list(node._input_nodes.keys())
         self.successor_node = list(node.users.keys())
         self.device_mesh = device_mesh
         self.strategies_vector = strategies_vector
+        self.shard_option = shard_option
 
     def update_resharding_cost(self, strategy: ShardingStrategy) -> None:
         """
@@ -181,6 +184,21 @@ class NodeHandler(ABC):
             if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
                 check_sharding_spec_validity(sharding_spec, op_data.data)
 
+        remove_strategy_list = []
+        for strategy in self.strategies_vector:
+            shard_level = 0
+            for op_data, sharding_spec in strategy.sharding_specs.items():
+                if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
+                    for dim, shard_axis in sharding_spec.dim_partition_dict.items():
+                        shard_level += len(shard_axis)
+            if self.shard_option == ShardOption.SHARD and shard_level == 0:
+                remove_strategy_list.append(strategy)
+            if self.shard_option == ShardOption.FULL_SHARD and shard_level <= 1:
+                remove_strategy_list.append(strategy)
+
+        for strategy in remove_strategy_list:
+            self.strategies_vector.remove(strategy)
+
         return self.strategies_vector
 
     def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/option.py b/colossalai/auto_parallel/tensor_shard/node_handler/option.py
new file mode 100644
index 000000000..dffb0386d
--- /dev/null
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/option.py
@@ -0,0 +1,17 @@
+from enum import Enum
+
+__all__ = ['ShardOption']
+
+
+class ShardOption(Enum):
+    """
+    This enum class defines the shard level required in node strategies.
+
+    Notes:
+        STANDARD: We do not add any extra shard requirements.
+        SHARD: We require the node to be sharded across at least one device mesh axis.
+        FULL_SHARD: We require the node to be sharded across all device mesh axes.
+    """
+    STANDARD = 0
+    SHARD = 1
+    FULL_SHARD = 2
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py
new file mode 100644
index 000000000..fda041110
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py
@@ -0,0 +1,112 @@
+from functools import partial
+
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.testing import parameterize
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing.utils import parameterize
+
+
+class LinearModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input, others, bias=None):
+        x = nn.functional.linear(input, others, bias=bias)
+        return x
+
+
+def check_shard_option(shard_option):
+    model = LinearModel().cuda()
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+
+    tracer = ColoTracer()
+    graph = tracer.trace(model,
+                         meta_args={
+                             "input": torch.rand(4, 4, 4, 16).to('meta'),
+                             'others': torch.rand(32, 16).to('meta')
+                         })
+    gm = ColoGraphModule(model, graph)
+    linear_func_node = list(graph.nodes)[2]
+    strategies_vector = StrategiesVector(linear_func_node)
+
+    # build handler
+    handler = LinearFunctionHandler(node=linear_func_node,
+                                    device_mesh=device_mesh,
+                                    strategies_vector=strategies_vector,
+                                    shard_option=shard_option)
+
+    strategies_vector = handler.register_strategy(compute_resharding_cost=False)
+    strategy_name_list = [val.name for val in strategies_vector]
+
+    # SS = SR x RS
+    assert 'S1S0 = S1R x RS0_0' in strategy_name_list
+    assert 'S0S1 = S0R x RS1_1' in strategy_name_list
+    assert 'S0S1 = S0R x RS1_2' in strategy_name_list
+    assert 'S0S1 = S0R x RS1_0' in strategy_name_list
+    assert 'S1S0 = S1R x RS0_1' in strategy_name_list
+    assert 'S1S0 = S1R x RS0_2' in strategy_name_list
+
+    # SR = SS x SR
+    assert 'S0R = S0S1 x S1R_1' in strategy_name_list
+    assert 'S0R = S0S1 x S1R_2' in strategy_name_list
+    assert 'S1R = S1S0 x S0R_0' in strategy_name_list
+    assert 'S0R = S0S1 x S1R_0' in strategy_name_list
+    assert 'S1R = S1S0 x S0R_1' in strategy_name_list
+    assert 'S1R = S1S0 x S0R_2' in strategy_name_list
+
+    # RS = RS x SS
+    assert 'RS0 = RS1 x S1S0' in strategy_name_list
+    assert 'RS1 = RS0 x S0S1' in strategy_name_list
+
+    # S01R = S01R x RR
+    assert 'S01R = S01R x RR_0' in strategy_name_list
+    assert 'S01R = S01R x RR_1' in strategy_name_list
+    assert 'S01R = S01R x RR_2' in strategy_name_list
+
+    # RR = RS01 x S01R
+    assert 'RR = RS01 x S01R' in strategy_name_list
+
+    # RS01 = RR x RS01
+    assert 'RS01 = RR x RS01' in strategy_name_list
+
+    if shard_option == ShardOption.SHARD:
+        # RR = RS x SR
+        assert 'RR = RS0 x S0R' in strategy_name_list
+        assert 'RR = RS1 x S1R' in strategy_name_list
+
+        # RS = RR x RS
+        assert 'RS0 = RR x RS0' in strategy_name_list
+        assert 'RS1 = RR x RS1' in strategy_name_list
+
+    if shard_option == ShardOption.STANDARD:
+        # RR = RS x SR
+        assert 'RR = RS0 x S0R' in strategy_name_list
+        assert 'RR = RS1 x S1R' in strategy_name_list
+
+        # RS = RR x RS
+        assert 'RS0 = RR x RS0' in strategy_name_list
+        assert 'RS1 = RR x RS1' in strategy_name_list
+
+        # RR = RR x RR
+        assert 'RR = RR x RR' in strategy_name_list
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+def test_shard_option():
+    for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD]:
+        check_shard_option(shard_option)
+
+
+if __name__ == '__main__':
+    test_shard_option()
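A minimal sketch (not part of the patch) of the pruning rule that register_strategy now applies with the new shard_option argument: for each strategy it sums, over every tensor operand's dim_partition_dict, the number of device mesh axes used (shard_level), then removes the strategy if the option's threshold is not met. The helper name keep_strategy below is hypothetical and is used only to illustrate the rule; ShardOption is imported from the package __init__ updated above.

from colossalai.auto_parallel.tensor_shard.node_handler import ShardOption


def keep_strategy(shard_level: int, shard_option: ShardOption) -> bool:
    # shard_level: total number of device mesh axes used across all tensor
    # operands of one strategy, counted from sharding_spec.dim_partition_dict.
    if shard_option == ShardOption.SHARD and shard_level == 0:
        # SHARD drops fully replicated strategies such as 'RR = RR x RR'.
        return False
    if shard_option == ShardOption.FULL_SHARD and shard_level <= 1:
        # FULL_SHARD also drops strategies that use at most one mesh axis in total.
        return False
    return True


# STANDARD keeps every strategy, which is why the test above only expects
# 'RR = RR x RR' under ShardOption.STANDARD.
assert keep_strategy(0, ShardOption.STANDARD)
assert not keep_strategy(0, ShardOption.SHARD)
assert not keep_strategy(1, ShardOption.FULL_SHARD)
assert keep_strategy(2, ShardOption.FULL_SHARD)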