Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-22 05:29:36 +00:00)

commit 1cce6e36ca (parent: 9b39170a5c)

    [autoparallel] use metainfo in handler (#2149)
@@ -28,7 +28,7 @@ def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
         Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
     """
 
-    input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
+    input_tensor = args[0].data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
     inplace = kwargs.get("inplace", False)
 
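Taken together, the hunks in this commit make two coordinated changes: the meta-registry functions (relu, conv, linear, batchnorm, avgpool) stop scanning the operand list for entries of a given OperationDataType and instead index it positionally, and the node handlers further down gain MetaInfo-based cost rewriting. For the first change, here is a minimal sketch of why the two access patterns agree, assuming the input operand is always passed first; the OperationData stand-in below is a simplified illustration, not the library's real class.

# Minimal sketch (illustrative stand-ins, not ColossalAI's real classes) of the
# two operand-access patterns the diff above swaps between.
from dataclasses import dataclass
from enum import Enum
from typing import Any


class OperationDataType(Enum):
    ARG = 0
    PARAM = 1
    OUTPUT = 2


@dataclass
class OperationData:
    type: OperationDataType
    data: Any


args = [
    OperationData(OperationDataType.ARG, "input tensor"),
    OperationData(OperationDataType.PARAM, "weight tensor"),
    OperationData(OperationDataType.OUTPUT, "output tensor"),
]

# Old pattern: scan for the first ARG entry.
input_old = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data

# New pattern: rely on the fixed positional layout (input comes first).
input_new = args[0].data

assert input_old == input_new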
@@ -58,9 +58,12 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
     """
 
     has_bias: bool = False
-    input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
+    input_tensor = args[0].data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
-    weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]
+    if len(args) == 4:
+        weight_tensors = [args[1].data, args[3].data]
+    else:
+        weight_tensors = [args[1].data]
 
     # check if conv has bias
     if len(weight_tensors) > 1:
@@ -66,9 +66,13 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
     """
 
     has_bias: bool = False
-    input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
-    output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
-    weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]
+    input_tensor = args[0].data
+    output_tensor = args[2].data
+
+    if len(args) == 4:
+        weight_tensors = [args[1].data, args[3].data]
+    else:
+        weight_tensors = [args[1].data]
 
     # process the dimension of input and output
     if len(input_tensor.shape) > 2:
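The conv and linear hunks above share a positional convention for the variable-length operand list: judging from the indices in the diff, operands appear to arrive as (input, weight, output) or (input, weight, output, bias), so len(args) == 4 signals a bias operand at index 3. A self-contained sketch of that convention follows; split_operands is a hypothetical helper written for illustration, not part of the codebase.

# Hypothetical sketch of the operand layout implied by the conv/linear hunks:
# (input, weight, output) or (input, weight, output, bias).
def split_operands(args):
    input_tensor = args[0]
    output_tensor = args[2]
    if len(args) == 4:                  # bias present at index 3
        weight_tensors = [args[1], args[3]]
    else:                               # weight only
        weight_tensors = [args[1]]
    has_bias = len(weight_tensors) > 1  # mirrors the "check if conv has bias" test
    return input_tensor, weight_tensors, output_tensor, has_bias


# Example: a linear layer with bias.
print(split_operands(("x", "W", "y", "b")))  # -> ('x', ['W', 'b'], 'y', True)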
@@ -45,7 +45,7 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
         Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
     """
 
-    input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
+    input_tensor = args[0].data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
     weight_tensor = next(filter(lambda x: x.name == "weight", args)).data
     bias_tensor = next(filter(lambda x: x.name == "bias", args)).data
@@ -30,7 +30,7 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
         Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
     """
 
-    input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
+    input_tensor = args[0].data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
 
     # construct forward args for flop mapping
@@ -2,8 +2,10 @@ from typing import Dict, List
 
 import torch
 
-from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import ModuleHandler
+from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
+
+from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
+from .node_handler import MetaInfoModuleHandler, ModuleHandler
 from .registry import operator_registry
 from .strategy import BatchNormStrategyGenerator, StrategyGenerator
 
@@ -13,7 +15,7 @@ __all__ = ['BatchNormModuleHandler']
 @operator_registry.register(torch.nn.BatchNorm1d)
 @operator_registry.register(torch.nn.BatchNorm2d)
 @operator_registry.register(torch.nn.BatchNorm3d)
-class BatchNormModuleHandler(ModuleHandler):
+class BatchNormModuleHandler(MetaInfoModuleHandler):
     """
     A BatchNormModuleHandler which deals with the sharding strategies for nn.BatchNormXd module.
     """
@@ -3,18 +3,12 @@ from typing import Dict, List, Union
 import torch
 from torch.fx.node import Node
 
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
-    CommAction,
-    CommType,
-    OperationData,
-    OperationDataType,
-    ShardingStrategy,
-)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy
 from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
 
 from ..constants import BCAST_FUNC_OP
 from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
-from .node_handler import NodeHandler
+from .node_handler import MetaInfoNodeHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator
 
@@ -22,7 +16,7 @@ __all__ = ['BinaryElementwiseHandler']
 
 
 @operator_registry.register(BCAST_FUNC_OP)
-class BinaryElementwiseHandler(NodeHandler):
+class BinaryElementwiseHandler(MetaInfoNodeHandler):
     """
     An BinaryBcastOpHandler is a node handler which deals with operations which have two
     operands and broadcasting occurs such as torch.add.
@@ -3,9 +3,9 @@ from typing import Dict, List
 import torch
 import torch.nn.functional as F
 
-from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
 from ..utils import transpose_partition_dim
-from .node_handler import ModuleHandler, NodeHandler
+from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import ConvStrategyGenerator, StrategyGenerator
 
@@ -15,7 +15,7 @@ __all__ = ['ConvModuleHandler', 'ConvFunctionHandler']
 @operator_registry.register(torch.nn.Conv1d)
 @operator_registry.register(torch.nn.Conv2d)
 @operator_registry.register(torch.nn.Conv3d)
-class ConvModuleHandler(ModuleHandler):
+class ConvModuleHandler(MetaInfoModuleHandler):
     """
     A ConvModuleHandler which deals with the sharding strategies for nn.Convxd module.
     """
@@ -63,7 +63,7 @@ class ConvModuleHandler(ModuleHandler):
 @operator_registry.register(F.conv1d)
 @operator_registry.register(F.conv2d)
 @operator_registry.register(F.conv3d)
-class ConvFunctionHandler(NodeHandler):
+class ConvFunctionHandler(MetaInfoNodeHandler):
     """
     A ConvFunctionHandler which deals with the sharding strategies for nn.functional.ConvXd functions.
     """
@@ -3,12 +3,16 @@ from typing import Dict, List, Union
 import torch
 import torch.nn.functional as F
 
-from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim, update_partition_dim
+from colossalai.auto_parallel.tensor_shard.utils import (
+    check_sharding_spec_validity,
+    transpose_partition_dim,
+    update_partition_dim,
+)
 from colossalai.logging import get_dist_logger
 from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
 
-from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
-from .node_handler import ModuleHandler, NodeHandler
+from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator
 
@@ -139,7 +143,7 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: ShardingStrategy,
 
 
 @operator_registry.register(torch.nn.Linear)
-class LinearModuleHandler(ModuleHandler):
+class LinearModuleHandler(MetaInfoModuleHandler):
     """
     A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
     """
@@ -199,7 +203,7 @@ class LinearModuleHandler(ModuleHandler):
 
 
 @operator_registry.register(F.linear)
-class LinearFunctionHandler(NodeHandler):
+class LinearFunctionHandler(MetaInfoNodeHandler):
     """
     A LinearFunctionHandler which deals with the sharding strategies for F.Linear.
     """
@@ -4,6 +4,7 @@ from typing import Dict, List, Tuple, Union
 import torch
 from torch.fx.node import Node
 
+from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
     OperationDataType,
@@ -133,6 +134,26 @@ class NodeHandler(ABC):
         strategy.resharding_costs = resharding_costs
         return strategy
 
+    def get_target_function(self) -> callable:
+        """
+        This function is used to get the target function for the node handler.
+        The target function is used to analyze the costs of strategies.
+        """
+        if self.node.op in ('placeholder', 'get_attr', 'output'):
+            return None
+
+        if self.node.op == 'call_module':
+            submod = self.node.graph.owning_module.get_submodule(self.node.target)
+            target = type(submod)
+        elif self.node.op == 'call_function':
+            target = self.node.target
+        elif self.node.op == 'call_method':
+            target = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
+        else:
+            raise ValueError(f'Unsupported node type: {self.node.op}')
+
+        return target
+
     def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
         """
         Register different sharding strategies for the current node.
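The new get_target_function maps a torch.fx opcode to the object that cost analysis is keyed on: the submodule's class for call_module, the function object itself for call_function, and the method looked up on the (meta) tensor's class for call_method. The following standalone demo uses plain torch.fx, independent of ColossalAI, to show what each opcode exposes; Tiny is a throwaway module written only for this example.

import torch
import torch.fx


class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)

    def forward(self, x):
        x = self.conv(x)     # traced as a call_module node
        x = torch.relu(x)    # traced as a call_function node
        return x.flatten(1)  # traced as a call_method node


gm = torch.fx.symbolic_trace(Tiny())
for node in gm.graph.nodes:
    if node.op == 'call_module':
        # For modules, the handler resolves the submodule and takes its class.
        print(node.op, type(gm.get_submodule(node.target)))  # -> <class '...Conv2d'>
    elif node.op in ('call_function', 'call_method'):
        print(node.op, node.target)  # the function object / the method name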
@@ -204,6 +225,29 @@ class NodeHandler(ABC):
         pass
 
 
+class MetaInfoNodeHandler(NodeHandler):
+    """
+    This is a base class to handle the nodes patched in the meta profiler.
+
+    Note: this class will be integrated into the NodeHandler class in the future, after
+    all the functions are patched.
+    """
+
+    def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
+        """
+        This method is inherited from NodeHandler. It will register the strategies first,
+        and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
+        """
+        super().register_strategy(compute_resharding_cost=compute_resharding_cost)
+        target = self.get_target_function()
+        for strategy in self.strategies_vector:
+            metainfo = MetaInfo(strategy, target)
+            strategy.compute_cost = metainfo.compute_cost
+            strategy.memory_cost = metainfo.memory_cost
+
+        return self.strategies_vector
+
+
 class ModuleHandler(NodeHandler):
 
     def __init__(self, *args, **kwargs) -> None:
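MetaInfoNodeHandler is a template-method style post-pass: the parent class still generates and registers the strategies, after which each strategy's generator-estimated costs are overwritten with values derived from MetaInfo. A schematic sketch of that control flow follows; Strategy and FakeMetaInfo are invented stand-ins, since the real MetaInfo(strategy, target) queries the patched meta profiler. Note that because super().register_strategy(...) runs first, strategy enumeration and resharding costs are untouched; only compute_cost and memory_cost are replaced.

# Schematic stand-ins (hypothetical, not the real ColossalAI API) for the
# cost-rewrite loop that MetaInfoNodeHandler.register_strategy performs.
class Strategy:
    def __init__(self, name, compute_cost, memory_cost):
        self.name = name
        self.compute_cost = compute_cost  # coarse estimate from the generator
        self.memory_cost = memory_cost


class FakeMetaInfo:
    """Stand-in for MetaInfo: pretends to re-profile a strategy for a target op."""

    def __init__(self, strategy, target):
        self.compute_cost = strategy.compute_cost * 0.9  # pretend profiled value
        self.memory_cost = strategy.memory_cost * 1.1


strategies_vector = [Strategy("S0", 100, 50), Strategy("S1", 80, 70)]
target = "conv2d"  # what get_target_function() would hand back

# The rewrite loop from the diff: replace generator estimates with profiled costs.
for strategy in strategies_vector:
    metainfo = FakeMetaInfo(strategy, target)
    strategy.compute_cost = metainfo.compute_cost
    strategy.memory_cost = metainfo.memory_cost

print([(s.name, s.compute_cost, s.memory_cost) for s in strategies_vector])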
@@ -221,3 +265,26 @@ class ModuleHandler(NodeHandler):
         self.module = module
         self.named_parameters = named_parameters
         self.named_buffers = named_buffers
+
+
+class MetaInfoModuleHandler(ModuleHandler):
+    """
+    This is a base class to handle the module patched in the meta profiler.
+
+    Note: this class will be integrated into the ModuleHandler class in the future, after
+    all the modules are patched.
+    """
+
+    def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
+        """
+        This method is inherited from NodeHandler. It will register the strategies first,
+        and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
+        """
+        super().register_strategy(compute_resharding_cost=compute_resharding_cost)
+        target = self.get_target_function()
+        for strategy in self.strategies_vector:
+            metainfo = MetaInfo(strategy, target)
+            strategy.compute_cost = metainfo.compute_cost
+            strategy.memory_cost = metainfo.memory_cost
+
+        return self.strategies_vector
@@ -3,7 +3,7 @@ from typing import Dict, List
 import torch
 
 from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import ModuleHandler
+from .node_handler import MetaInfoModuleHandler, ModuleHandler
 from .registry import operator_registry
 from .strategy import NormalPoolStrategyGenerator, StrategyGenerator
 
@@ -16,7 +16,7 @@ __all__ = ['NormPoolingHandler']
 @operator_registry.register(torch.nn.AvgPool1d)
 @operator_registry.register(torch.nn.AvgPool2d)
 @operator_registry.register(torch.nn.AvgPool3d)
-class NormPoolingHandler(ModuleHandler):
+class NormPoolingHandler(MetaInfoModuleHandler):
     """
     A NormPoolingHandler which deals with the sharding strategies for nn.MaxPoolxd module.
     """
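Every remaining handler change in this commit is the same one-line edit: a handler opts into profiled costs purely by swapping its base class, so no handler body has to change. A minimal hypothetical sketch of that opt-in pattern:

# Hypothetical sketch of the opt-in-by-base-class pattern used across the handlers.
class NodeHandler:
    def register_strategy(self):
        return ["strategy with generator-estimated costs"]


class MetaInfoNodeHandler(NodeHandler):
    def register_strategy(self):
        strategies = super().register_strategy()
        # ...rewrite each strategy's costs via the meta profiler here...
        return strategies


# Before this commit:  class ConvFunctionHandler(NodeHandler): ...
# After this commit:   class ConvFunctionHandler(MetaInfoNodeHandler): ...
class ConvFunctionHandler(MetaInfoNodeHandler):
    pass


print(ConvFunctionHandler().register_strategy())

Keeping the opt-in at the inheritance level is what makes the "Note" in the new docstrings plausible: once every op is patched in the meta profiler, MetaInfoNodeHandler and MetaInfoModuleHandler can be folded back into the base classes without touching the concrete handlers.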