From 1cce6e36cafde6b7ca1292655d740aef2d38ed2c Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Tue, 20 Dec 2022 10:31:22 +0800
Subject: [PATCH] [autoparallel] use metainfo in handler (#2149)

---
 .../meta_profiler/meta_registry/activation.py |  2 +-
 .../meta_profiler/meta_registry/conv.py       |  7 +-
 .../meta_profiler/meta_registry/linear.py     | 10 ++-
 .../meta_profiler/meta_registry/norm.py       |  2 +-
 .../meta_profiler/meta_registry/pooling.py    |  2 +-
 .../node_handler/batch_norm_handler.py        |  8 ++-
 .../binary_elementwise_handler.py             | 12 +---
 .../tensor_shard/node_handler/conv_handler.py |  8 +--
 .../node_handler/linear_handler.py            | 14 ++--
 .../tensor_shard/node_handler/node_handler.py | 67 +++++++++++++++++++
 .../node_handler/normal_pooling_handler.py    |  4 +-
 11 files changed, 105 insertions(+), 31 deletions(-)

diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
index dc62005f0..7b2f8dfa4 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
@@ -28,7 +28,7 @@ def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Lis
         Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
     """
 
-    input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
+    input_tensor = args[0].data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
     inplace = kwargs.get("inplace", False)
 
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py
index f7d55529f..fd6c5184a 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py
@@ -58,9 +58,12 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     """
 
     has_bias: bool = False
-    input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
+    input_tensor = args[0].data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
-    weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]
+    if len(args) == 4:
+        weight_tensors = [args[1].data, args[3].data]
+    else:
+        weight_tensors = [args[1].data]
 
     # check if conv has bias
     if len(weight_tensors) > 1:
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
index b48748fa9..bb7935d0f 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
@@ -66,9 +66,13 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     """
 
     has_bias: bool = False
-    input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
-    output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
-    weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]
+
+    input_tensor = args[0].data
+    output_tensor = args[2].data
+    if len(args) == 4:
+        weight_tensors = [args[1].data, args[3].data]
+    else:
+        weight_tensors = [args[1].data]
 
     # process the dimension of input and output
     if len(input_tensor.shape) > 2:
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py
index 395eecdbb..b88bed88b 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py
@@ -45,7 +45,7 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
         Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
     """
 
-    input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
+    input_tensor = args[0].data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
     weight_tensor = next(filter(lambda x: x.name == "weight", args)).data
     bias_tensor = next(filter(lambda x: x.name == "bias", args)).data
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
index 63f321519..1c04bdc73 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
@@ -30,7 +30,7 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
         Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
     """
 
-    input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
+    input_tensor = args[0].data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
 
     # construct forward args for flop mapping
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py
index 6bdd15d16..57b623b01 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py
@@ -2,8 +2,10 @@ from typing import Dict, List
 
 import torch
 
-from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import ModuleHandler
+from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
+
+from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
+from .node_handler import MetaInfoModuleHandler, ModuleHandler
 from .registry import operator_registry
 from .strategy import BatchNormStrategyGenerator, StrategyGenerator
 
@@ -13,7 +15,7 @@ __all__ = ['BatchNormModuleHandler']
 @operator_registry.register(torch.nn.BatchNorm1d)
 @operator_registry.register(torch.nn.BatchNorm2d)
 @operator_registry.register(torch.nn.BatchNorm3d)
-class BatchNormModuleHandler(ModuleHandler):
+class BatchNormModuleHandler(MetaInfoModuleHandler):
     """
     A BatchNormModuleHandler which deals with the sharding strategies for nn.BatchNormXd module.
""" diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py index 5b600e735..f510f7477 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py @@ -3,18 +3,12 @@ from typing import Dict, List, Union import torch from torch.fx.node import Node -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommAction, - CommType, - OperationData, - OperationDataType, - ShardingStrategy, -) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager from ..constants import BCAST_FUNC_OP from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape -from .node_handler import NodeHandler +from .node_handler import MetaInfoNodeHandler, NodeHandler from .registry import operator_registry from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator @@ -22,7 +16,7 @@ __all__ = ['BinaryElementwiseHandler'] @operator_registry.register(BCAST_FUNC_OP) -class BinaryElementwiseHandler(NodeHandler): +class BinaryElementwiseHandler(MetaInfoNodeHandler): """ An BinaryBcastOpHandler is a node handler which deals with operations which have two operands and broadcasting occurs such as torch.add. diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py index 0c00160ef..272b1c856 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py @@ -3,9 +3,9 @@ from typing import Dict, List import torch import torch.nn.functional as F -from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector from ..utils import transpose_partition_dim -from .node_handler import ModuleHandler, NodeHandler +from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler from .registry import operator_registry from .strategy import ConvStrategyGenerator, StrategyGenerator @@ -15,7 +15,7 @@ __all__ = ['ConvModuleHandler', 'ConvFunctionHandler'] @operator_registry.register(torch.nn.Conv1d) @operator_registry.register(torch.nn.Conv2d) @operator_registry.register(torch.nn.Conv3d) -class ConvModuleHandler(ModuleHandler): +class ConvModuleHandler(MetaInfoModuleHandler): """ A ConvModuleHandler which deals with the sharding strategies for nn.Convxd module. """ @@ -63,7 +63,7 @@ class ConvModuleHandler(ModuleHandler): @operator_registry.register(F.conv1d) @operator_registry.register(F.conv2d) @operator_registry.register(F.conv3d) -class ConvFunctionHandler(NodeHandler): +class ConvFunctionHandler(MetaInfoNodeHandler): """ A ConvFunctionHandler which deals with the sharding strategies for nn.functional.ConvXd functions. 
""" diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py index d8e3ce6a5..37ff3c3ab 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py @@ -3,12 +3,16 @@ from typing import Dict, List, Union import torch import torch.nn.functional as F -from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim, update_partition_dim +from colossalai.auto_parallel.tensor_shard.utils import ( + check_sharding_spec_validity, + transpose_partition_dim, + update_partition_dim, +) from colossalai.logging import get_dist_logger from colossalai.tensor.sharding_spec import ShardingNotDivisibleError -from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy -from .node_handler import ModuleHandler, NodeHandler +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector +from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler from .registry import operator_registry from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator @@ -139,7 +143,7 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha @operator_registry.register(torch.nn.Linear) -class LinearModuleHandler(ModuleHandler): +class LinearModuleHandler(MetaInfoModuleHandler): """ A LinearModuleHandler which deals with the sharding strategies for nn.Linear module. """ @@ -199,7 +203,7 @@ class LinearModuleHandler(ModuleHandler): @operator_registry.register(F.linear) -class LinearFunctionHandler(NodeHandler): +class LinearFunctionHandler(MetaInfoNodeHandler): """ A LinearFunctionHandler which deals with the sharding strategies for F.Linear. """ diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py index 27957ca63..6d603f63e 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py @@ -4,6 +4,7 @@ from typing import Dict, List, Tuple, Union import torch from torch.fx.node import Node +from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( OperationData, OperationDataType, @@ -133,6 +134,26 @@ class NodeHandler(ABC): strategy.resharding_costs = resharding_costs return strategy + def get_target_function(self) -> callable: + """ + This function is used to get the target function for the node handler. + The target function is used to analyze the costs of strategies. + """ + if self.node.op in ('placeholder', 'get_attr', 'output'): + return None + + if self.node.op == 'call_module': + submod = self.node.graph.owning_module.get_submodule(self.node.target) + target = type(submod) + elif self.node.op == 'call_function': + target = self.node.target + elif self.node.op == 'call_method': + target = getattr(self.node.args[0]._meta_data.__class__, self.node.target) + else: + raise ValueError(f'Unsupported node type: {self.node.op}') + + return target + def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector: """ Register different sharding strategies for the current node. 
@@ -204,6 +225,29 @@ class NodeHandler(ABC):
         pass
 
 
+class MetaInfoNodeHandler(NodeHandler):
+    """
+    This is a base class to handle the nodes patched in the meta profiler.
+
+    Note: this class will be integrated into the NodeHandler class in the future, after
+    all the functions are patched.
+    """
+
+    def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
+        """
+        This method is inherited from NodeHandler. It will register the strategies first,
+        and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
+        """
+        super().register_strategy(compute_resharding_cost=compute_resharding_cost)
+        target = self.get_target_function()
+        for strategy in self.strategies_vector:
+            metainfo = MetaInfo(strategy, target)
+            strategy.compute_cost = metainfo.compute_cost
+            strategy.memory_cost = metainfo.memory_cost
+
+        return self.strategies_vector
+
+
 class ModuleHandler(NodeHandler):
 
     def __init__(self, *args, **kwargs) -> None:
@@ -221,3 +265,26 @@ class ModuleHandler(NodeHandler):
         self.module = module
         self.named_parameters = named_parameters
         self.named_buffers = named_buffers
+
+
+class MetaInfoModuleHandler(ModuleHandler):
+    """
+    This is a base class to handle the module patched in the meta profiler.
+
+    Note: this class will be integrated into the ModuleHandler class in the future, after
+    all the modules are patched.
+    """
+
+    def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
+        """
+        This method is inherited from NodeHandler. It will register the strategies first,
+        and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
+        """
+        super().register_strategy(compute_resharding_cost=compute_resharding_cost)
+        target = self.get_target_function()
+        for strategy in self.strategies_vector:
+            metainfo = MetaInfo(strategy, target)
+            strategy.compute_cost = metainfo.compute_cost
+            strategy.memory_cost = metainfo.memory_cost
+
+        return self.strategies_vector
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py
index 1509c05a3..4e71ccba9 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py
@@ -3,7 +3,7 @@ from typing import Dict, List
 import torch
 
 from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import ModuleHandler
+from .node_handler import MetaInfoModuleHandler, ModuleHandler
 from .registry import operator_registry
 from .strategy import NormalPoolStrategyGenerator, StrategyGenerator
 
@@ -16,7 +16,7 @@ __all__ = ['NormPoolingHandler']
 @operator_registry.register(torch.nn.AvgPool1d)
 @operator_registry.register(torch.nn.AvgPool2d)
 @operator_registry.register(torch.nn.AvgPool3d)
-class NormPoolingHandler(ModuleHandler):
+class NormPoolingHandler(MetaInfoModuleHandler):
     """
     A NormPoolingHandler which deals with the sharding strategies for nn.MaxPoolxd module.
     """
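
For readers skimming the patch, the snippet below restates, outside the diff, the cost-rewrite pattern that MetaInfoNodeHandler and MetaInfoModuleHandler introduce. It is an illustrative sketch only: the helper name rewrite_costs_with_metainfo is hypothetical, and it assumes a handler whose register_strategy() has already populated strategies_vector.

# Illustrative sketch only -- not part of the patch above.
# Assumes `handler` is a NodeHandler subclass whose register_strategy() has
# already populated `handler.strategies_vector`; the helper name is hypothetical.
from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo


def rewrite_costs_with_metainfo(handler):
    # Resolve the profiling target the same way the patch does: the module class
    # for 'call_module' nodes, the function itself for 'call_function' nodes, or
    # the class attribute of the argument's type for 'call_method' nodes.
    target = handler.get_target_function()
    for strategy in handler.strategies_vector:
        # MetaInfo profiles the target under this sharding strategy and exposes
        # compute and memory costs, which replace the generator's original estimates.
        metainfo = MetaInfo(strategy, target)
        strategy.compute_cost = metainfo.compute_cost
        strategy.memory_cost = metainfo.memory_cost
    return handler.strategies_vector

This is the same loop both new base classes run after super().register_strategy(...); subclassing MetaInfoNodeHandler or MetaInfoModuleHandler, as the batch norm, binary elementwise, conv, linear, and pooling handlers now do, is what opts a handler into the profiled costs.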