From 6c331a5a097f78c7c81a9b1bf141a32bf150f643 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Fri, 14 Oct 2022 13:27:00 +0800
Subject: [PATCH] [autoparallel] refactored the autoparallel module for organization (#1706)

* [autoparallel] refactored the autoparallel module for organization

* polish code
---
 colossalai/auto_parallel/solver/__init__.py | 12 -
 .../{solver => tensor_shard}/constants.py | 0
 .../node_handler/__init__.py | 17 +-
 .../node_handler/batch_norm_handler.py | 11 +-
 .../node_handler/conv_handler.py | 14 +-
 .../node_handler/dot_handler.py | 13 +-
 .../node_handler/getitem_handler.py | 14 +-
 .../node_handler/layer_norm_handler.py | 8 +-
 .../node_handler/node_handler.py | 11 +-
 .../node_handler/normal_pooling_handler.py | 11 +-
 .../node_handler/output_handler.py | 10 +-
 .../node_handler/placeholder_handler.py | 10 +-
 .../node_handler/registry.py | 0
 .../node_handler/reshape_handler.py | 9 +-
 .../node_handler}/strategy/__init__.py | 17 +-
 .../strategy/batch_norm_generator.py | 8 +-
 .../strategy/conv_strategy_generator.py | 18 +-
 .../strategy/getitem_generator.py | 12 +-
 .../strategy/layer_norm_generator.py | 10 +-
 .../strategy/matmul_strategy_generator.py | 9 +-
 .../strategy/normal_pooling_generator.py | 12 +-
 .../strategy/output_generator.py | 11 +-
 .../strategy/placeholder_generator.py | 11 +-
 .../strategy/reshape_generator.py | 11 +-
 .../strategy/strategy_generator.py | 17 +-
 .../strategy/unary_elementwise_generator.py | 11 +-
 .../node_handler}/strategy/where_generator.py | 13 +-
 .../node_handler/unary_elementwise_handler.py | 9 +-
 .../node_handler/where_handler.py | 18 +-
 .../sharding_strategy.py | 18 +-
 .../tensor_shard/solver/__init__.py | 7 +
 .../{ => tensor_shard}/solver/cost_graph.py | 5 +-
 .../solver/graph_analysis.py | 7 +-
 .../{ => tensor_shard}/solver/options.py | 0
 .../{ => tensor_shard}/solver/solver.py | 474 +-----------------
 .../solver/strategies_constructor.py | 27 +-
 .../tensor_shard/utils/__init__.py | 12 +
 .../utils}/broadcast.py | 0
 .../utils/factory.py} | 73 +--
 .../auto_parallel/tensor_shard/utils/misc.py | 26 +
 .../utils/sharding.py} | 51 +-
 .../test_tensor_shard/test_broadcast.py | 6 +-
 .../test_liveness_analysis.py | 7 +-
 .../test_batch_norm_handler.py | 10 +-
 .../test_node_handler/test_bmm_handler.py | 8 +-
 .../test_node_handler/test_conv_handler.py | 9 +-
 .../test_node_handler/test_getitem_handler.py | 13 +-
 .../test_layer_norm_handler.py | 10 +-
 .../test_node_handler/test_linear_handler.py | 10 +-
 .../test_norm_pooling_handler.py | 12 +-
 .../test_node_handler/test_output_handler.py | 8 +-
 .../test_placeholder_handler.py | 8 +-
 .../test_node_handler/test_reshape_handler.py | 11 +-
 .../test_unary_element_wise_handler.py | 13 +-
 .../test_node_handler/test_where_handler.py | 10 +-
 .../test_shape_consistency_pass.py | 26 +-
 .../test_solver_with_resnet_v2.py | 19 +-
 57 files changed, 408 insertions(+), 799 deletions(-)
 delete mode 100644 colossalai/auto_parallel/solver/__init__.py
 rename colossalai/auto_parallel/{solver => tensor_shard}/constants.py (100%)
 rename colossalai/auto_parallel/{solver => tensor_shard}/node_handler/__init__.py (79%)
 rename colossalai/auto_parallel/{solver => tensor_shard}/node_handler/batch_norm_handler.py (88%)
 rename colossalai/auto_parallel/{solver => tensor_shard}/node_handler/conv_handler.py (96%)
 rename colossalai/auto_parallel/{solver => tensor_shard}/node_handler/dot_handler.py (96%)
 rename colossalai/auto_parallel/{solver => tensor_shard}/node_handler/getitem_handler.py (87%)
 rename colossalai/auto_parallel/{solver => tensor_shard}/node_handler/layer_norm_handler.py (91%)
 rename colossalai/auto_parallel/{solver => tensor_shard}/node_handler/node_handler.py (96%)
 rename colossalai/auto_parallel/{solver => tensor_shard}/node_handler/normal_pooling_handler.py (85%)
 rename colossalai/auto_parallel/{solver => tensor_shard}/node_handler/output_handler.py (80%)
 rename colossalai/auto_parallel/{solver => tensor_shard}/node_handler/placeholder_handler.py (72%)
 rename colossalai/auto_parallel/{solver => tensor_shard}/node_handler/registry.py (100%)
 rename colossalai/auto_parallel/{solver => tensor_shard}/node_handler/reshape_handler.py (86%)
 rename colossalai/auto_parallel/{solver => tensor_shard/node_handler}/strategy/__init__.py (77%)
 rename colossalai/auto_parallel/{solver => tensor_shard/node_handler}/strategy/batch_norm_generator.py (98%)
 rename colossalai/auto_parallel/{solver => tensor_shard/node_handler}/strategy/conv_strategy_generator.py (99%)
 rename colossalai/auto_parallel/{solver => tensor_shard/node_handler}/strategy/getitem_generator.py (97%)
 rename colossalai/auto_parallel/{solver => tensor_shard/node_handler}/strategy/layer_norm_generator.py (96%)
 rename colossalai/auto_parallel/{solver => tensor_shard/node_handler}/strategy/matmul_strategy_generator.py (99%)
 rename colossalai/auto_parallel/{solver => tensor_shard/node_handler}/strategy/normal_pooling_generator.py (94%)
 rename colossalai/auto_parallel/{solver => tensor_shard/node_handler}/strategy/output_generator.py (85%)
 rename colossalai/auto_parallel/{solver => tensor_shard/node_handler}/strategy/placeholder_generator.py (86%)
 rename colossalai/auto_parallel/{solver => tensor_shard/node_handler}/strategy/reshape_generator.py (97%)
 rename colossalai/auto_parallel/{solver => tensor_shard/node_handler}/strategy/strategy_generator.py (96%)
 rename colossalai/auto_parallel/{solver => tensor_shard/node_handler}/strategy/unary_elementwise_generator.py (94%)
 rename colossalai/auto_parallel/{solver => tensor_shard/node_handler}/strategy/where_generator.py (90%)
 rename colossalai/auto_parallel/{solver => tensor_shard}/node_handler/unary_elementwise_handler.py (85%)
 rename colossalai/auto_parallel/{solver => tensor_shard}/node_handler/where_handler.py (94%)
 rename colossalai/auto_parallel/{solver => tensor_shard}/sharding_strategy.py (94%)
 create mode 100644 colossalai/auto_parallel/tensor_shard/solver/__init__.py
 rename colossalai/auto_parallel/{ => tensor_shard}/solver/cost_graph.py (98%)
 rename colossalai/auto_parallel/{ => tensor_shard}/solver/graph_analysis.py (98%)
 rename colossalai/auto_parallel/{ => tensor_shard}/solver/options.py (100%)
 rename colossalai/auto_parallel/{ => tensor_shard}/solver/solver.py (50%)
 rename colossalai/auto_parallel/{ => tensor_shard}/solver/strategies_constructor.py (91%)
 create mode 100644 colossalai/auto_parallel/tensor_shard/utils/__init__.py
 rename colossalai/auto_parallel/{solver/node_handler => tensor_shard/utils}/broadcast.py (100%)
 rename colossalai/auto_parallel/{solver/_utils.py => tensor_shard/utils/factory.py} (70%)
 create mode 100644 colossalai/auto_parallel/tensor_shard/utils/misc.py
 rename colossalai/auto_parallel/{solver/node_handler/utils.py => tensor_shard/utils/sharding.py} (62%)
 rename tests/test_auto_parallel/{ => test_tensor_shard}/test_shape_consistency_pass.py (84%)

diff --git a/colossalai/auto_parallel/solver/__init__.py b/colossalai/auto_parallel/solver/__init__.py
deleted file mode 100644
index 15f951b85..000000000
--- a/colossalai/auto_parallel/solver/__init__.py
+++ /dev/null
@@ -1,12 +0,0
@@ -from .sharding_strategy import ShardingStrategy, StrategiesVector -from .graph_analysis import GraphAnalyser -from .solver import Solver -from .cost_graph import CostGraph -from .strategies_constructor import StrategiesConstructor -from .constants import * -from .options import SolverOptions - -__all__ = [ - 'StrategiesVector', 'ShardingStrategy', 'GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph', - 'SolverOptions' -] diff --git a/colossalai/auto_parallel/solver/constants.py b/colossalai/auto_parallel/tensor_shard/constants.py similarity index 100% rename from colossalai/auto_parallel/solver/constants.py rename to colossalai/auto_parallel/tensor_shard/constants.py diff --git a/colossalai/auto_parallel/solver/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py similarity index 79% rename from colossalai/auto_parallel/solver/node_handler/__init__.py rename to colossalai/auto_parallel/tensor_shard/node_handler/__init__.py index 9aad0b91a..8e38d34ca 100644 --- a/colossalai/auto_parallel/solver/node_handler/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py @@ -1,16 +1,17 @@ +from .batch_norm_handler import BatchNormModuleHandler +from .conv_handler import ConvFunctionHandler, ConvModuleHandler from .dot_handler import LinearFunctionHandler, LinearModuleHandler from .layer_norm_handler import LayerNormModuleHandler -from .batch_norm_handler import BatchNormModuleHandler -from .conv_handler import ConvModuleHandler, ConvFunctionHandler -from .where_handler import WhereHandler -from .unary_elementwise_handler import UnaryElementwiseHandler -from .reshape_handler import ReshapeHandler -from .placeholder_handler import PlacehodlerHandler -from .output_handler import OuputHandler from .normal_pooling_handler import NormPoolingHandler +from .output_handler import OuputHandler +from .placeholder_handler import PlacehodlerHandler +from .registry import operator_registry +from .reshape_handler import ReshapeHandler +from .unary_elementwise_handler import UnaryElementwiseHandler +from .where_handler import WhereHandler __all__ = [ 'LinearFunctionHandler', 'LinearModuleHandler', 'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler', 'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', - 'OuputHandler', 'WhereHandler', 'NormPoolingHandler' + 'OuputHandler', 'WhereHandler', 'NormPoolingHandler', 'operator_registry' ] diff --git a/colossalai/auto_parallel/solver/node_handler/batch_norm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py similarity index 88% rename from colossalai/auto_parallel/solver/node_handler/batch_norm_handler.py rename to colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py index 4a5e0fdec..1eaf304cf 100644 --- a/colossalai/auto_parallel/solver/node_handler/batch_norm_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py @@ -1,10 +1,11 @@ +from typing import Dict, List + import torch -import torch.nn.functional as F -from .node_handler import ModuleHandler, NodeHandler -from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData -from ..strategy import BatchNormStrategyGenerator, StrategyGenerator -from typing import List, Dict + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import ModuleHandler from .registry import operator_registry +from .strategy import BatchNormStrategyGenerator, 
StrategyGenerator __all__ = ['BatchNormModuleHandler'] diff --git a/colossalai/auto_parallel/solver/node_handler/conv_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py similarity index 96% rename from colossalai/auto_parallel/solver/node_handler/conv_handler.py rename to colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py index 2074eeb1b..b678c59a5 100644 --- a/colossalai/auto_parallel/solver/node_handler/conv_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py @@ -1,12 +1,14 @@ +from typing import Dict, List + import torch import torch.nn.functional as F -from .node_handler import ModuleHandler, NodeHandler -from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData -from ..strategy import ConvStrategyGenerator, StrategyGenerator -from typing import List, Dict -from .registry import operator_registry -__all__ = ['LinearModuleHandler', 'LinearFunctionHandler'] +from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy) +from .node_handler import ModuleHandler, NodeHandler +from .registry import operator_registry +from .strategy import ConvStrategyGenerator, StrategyGenerator + +__all__ = ['ConvModuleHandler', 'ConvFunctionHandler'] @operator_registry.register(torch.nn.Conv1d) diff --git a/colossalai/auto_parallel/solver/node_handler/dot_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/dot_handler.py similarity index 96% rename from colossalai/auto_parallel/solver/node_handler/dot_handler.py rename to colossalai/auto_parallel/tensor_shard/node_handler/dot_handler.py index 015f71ebd..ad90e4e5b 100644 --- a/colossalai/auto_parallel/solver/node_handler/dot_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/dot_handler.py @@ -1,13 +1,16 @@ +from copy import deepcopy +from typing import Dict, List, Union + import torch import torch.nn.functional as F + +from colossalai.auto_parallel.tensor_shard.utils import (switch_partition_dim, update_partition_dim) from colossalai.tensor.sharding_spec import ShardingException + +from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy) from .node_handler import ModuleHandler, NodeHandler -from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData -from ..strategy import LinearProjectionStrategyGenerator, StrategyGenerator, BatchedMatMulStrategyGenerator -from typing import List, Dict, Union from .registry import operator_registry -from copy import deepcopy -from .utils import switch_partition_dim, update_partition_dim +from .strategy import (BatchedMatMulStrategyGenerator, LinearProjectionStrategyGenerator, StrategyGenerator) __all__ = ['LinearModuleHandler', 'LinearFunctionHandler', 'BMMFunctionHandler'] diff --git a/colossalai/auto_parallel/solver/node_handler/getitem_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py similarity index 87% rename from colossalai/auto_parallel/solver/node_handler/getitem_handler.py rename to colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py index 9c4c8fdf0..25baa7766 100644 --- a/colossalai/auto_parallel/solver/node_handler/getitem_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py @@ -1,10 +1,12 @@ -import torch -from .node_handler import NodeHandler -from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector -from ..strategy import TensorStrategyGenerator, TensorTupleStrategyGenerator, 
StrategyGenerator -from typing import List, Dict -from .registry import operator_registry import operator +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import (StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator) __all__ = ['GetItemHandler'] diff --git a/colossalai/auto_parallel/solver/node_handler/layer_norm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py similarity index 91% rename from colossalai/auto_parallel/solver/node_handler/layer_norm_handler.py rename to colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py index 1bcb55daa..132ac30da 100644 --- a/colossalai/auto_parallel/solver/node_handler/layer_norm_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py @@ -1,9 +1,11 @@ +from typing import Dict, List + import torch + +from ..sharding_strategy import OperationData, OperationDataType from .node_handler import ModuleHandler -from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData -from ..strategy import LayerNormGenerator, StrategyGenerator -from typing import List, Dict from .registry import operator_registry +from .strategy import LayerNormGenerator, StrategyGenerator __all__ = ['LayerNormModuleHandler'] diff --git a/colossalai/auto_parallel/solver/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py similarity index 96% rename from colossalai/auto_parallel/solver/node_handler/node_handler.py rename to colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py index 9d3421acd..bae458782 100644 --- a/colossalai/auto_parallel/solver/node_handler/node_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py @@ -1,11 +1,14 @@ from abc import ABC, abstractmethod +from typing import Dict, List, Union + from torch.fx.node import Node + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, ShardingStrategy, StrategiesVector, + TrainCycleItem) from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from typing import Dict, List, Union -from ..sharding_strategy import ShardingStrategy, StrategiesVector, OperationData, TrainCycleItem -from ..strategy import StrategyGenerator -from .._utils import generate_resharding_costs + +from .strategy import StrategyGenerator class NodeHandler(ABC): diff --git a/colossalai/auto_parallel/solver/node_handler/normal_pooling_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py similarity index 85% rename from colossalai/auto_parallel/solver/node_handler/normal_pooling_handler.py rename to colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py index 7238085a5..1509c05a3 100644 --- a/colossalai/auto_parallel/solver/node_handler/normal_pooling_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py @@ -1,10 +1,11 @@ +from typing import Dict, List + import torch -import torch.nn.functional as F -from .node_handler import ModuleHandler, NodeHandler -from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData -from ..strategy import NormalPoolStrategyGenerator, StrategyGenerator -from typing import List, Dict + +from ..sharding_strategy import OperationData, 
OperationDataType +from .node_handler import ModuleHandler from .registry import operator_registry +from .strategy import NormalPoolStrategyGenerator, StrategyGenerator __all__ = ['NormPoolingHandler'] diff --git a/colossalai/auto_parallel/solver/node_handler/output_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py similarity index 80% rename from colossalai/auto_parallel/solver/node_handler/output_handler.py rename to colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py index a268bcc04..489e40daf 100644 --- a/colossalai/auto_parallel/solver/node_handler/output_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py @@ -1,10 +1,10 @@ +from typing import Dict, List + import torch + +from ..sharding_strategy import OperationData, OperationDataType from .node_handler import NodeHandler -from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector -from colossalai.auto_parallel.solver.strategy import StrategyGenerator -from colossalai.auto_parallel.solver.strategy.output_generator import OutputGenerator -from typing import List, Dict -from .registry import operator_registry +from .strategy import OutputGenerator, StrategyGenerator __all__ = ['OuputHandler'] diff --git a/colossalai/auto_parallel/solver/node_handler/placeholder_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py similarity index 72% rename from colossalai/auto_parallel/solver/node_handler/placeholder_handler.py rename to colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py index ab6b02a7b..88a02428e 100644 --- a/colossalai/auto_parallel/solver/node_handler/placeholder_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py @@ -1,10 +1,8 @@ -import torch +from typing import Dict, List + +from ..sharding_strategy import OperationData, OperationDataType from .node_handler import NodeHandler -from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData -from colossalai.auto_parallel.solver.strategy import StrategyGenerator -from colossalai.auto_parallel.solver.strategy.placeholder_generator import PlaceholderGenerator -from typing import List, Dict -from .registry import operator_registry +from .strategy import PlaceholderGenerator, StrategyGenerator __all__ = ['PlacehodlerHandler'] diff --git a/colossalai/auto_parallel/solver/node_handler/registry.py b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py similarity index 100% rename from colossalai/auto_parallel/solver/node_handler/registry.py rename to colossalai/auto_parallel/tensor_shard/node_handler/registry.py diff --git a/colossalai/auto_parallel/solver/node_handler/reshape_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py similarity index 86% rename from colossalai/auto_parallel/solver/node_handler/reshape_handler.py rename to colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py index 8bb779290..1dd79e542 100644 --- a/colossalai/auto_parallel/solver/node_handler/reshape_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py @@ -1,10 +1,11 @@ +from typing import Dict, List + import torch + +from ..sharding_strategy import OperationData, OperationDataType from .node_handler import NodeHandler -from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector -from ..strategy import ReshapeGenerator, StrategyGenerator -from 
typing import List, Dict from .registry import operator_registry -import operator +from .strategy import ReshapeGenerator, StrategyGenerator __all__ = ['ReshapeHandler'] diff --git a/colossalai/auto_parallel/solver/strategy/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py similarity index 77% rename from colossalai/auto_parallel/solver/strategy/__init__.py rename to colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py index 9d0e98c01..f137f09db 100644 --- a/colossalai/auto_parallel/solver/strategy/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py @@ -1,15 +1,16 @@ -from .strategy_generator import StrategyGenerator -from .matmul_strategy_generator import DotProductStrategyGenerator, MatVecStrategyGenerator, LinearProjectionStrategyGenerator, BatchedMatMulStrategyGenerator -from .conv_strategy_generator import ConvStrategyGenerator from .batch_norm_generator import BatchNormStrategyGenerator -from .unary_elementwise_generator import UnaryElementwiseGenerator -from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator +from .conv_strategy_generator import ConvStrategyGenerator +from .getitem_generator import (GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator) from .layer_norm_generator import LayerNormGenerator -from .where_generator import WhereGenerator -from .reshape_generator import ReshapeGenerator +from .matmul_strategy_generator import (BatchedMatMulStrategyGenerator, DotProductStrategyGenerator, + LinearProjectionStrategyGenerator, MatVecStrategyGenerator) from .normal_pooling_generator import NormalPoolStrategyGenerator -from .placeholder_generator import PlaceholderGenerator from .output_generator import OutputGenerator +from .placeholder_generator import PlaceholderGenerator +from .reshape_generator import ReshapeGenerator +from .strategy_generator import StrategyGenerator +from .unary_elementwise_generator import UnaryElementwiseGenerator +from .where_generator import WhereGenerator __all__ = [ 'StrategyGenerator', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', 'LinearProjectionStrategyGenerator', diff --git a/colossalai/auto_parallel/solver/strategy/batch_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py similarity index 98% rename from colossalai/auto_parallel/solver/strategy/batch_norm_generator.py rename to colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py index 1964c6eb8..70a2cc9b4 100644 --- a/colossalai/auto_parallel/solver/strategy/batch_norm_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py @@ -1,11 +1,11 @@ +import copy import operator from functools import reduce -from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) from colossalai.tensor.shape_consistency import CollectiveCommPattern + from .strategy_generator import StrategyGenerator -from typing import List -from .._utils import exception_handler -import copy __all__ = ['BatchNormStrategyGenerator'] diff --git a/colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py similarity index 99% rename from 
colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py rename to colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py index fcab52012..88d363447 100644 --- a/colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py @@ -1,12 +1,14 @@ -import operator -from functools import reduce -from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost -from colossalai.tensor.shape_consistency import CollectiveCommPattern -from .strategy_generator import StrategyGenerator -from typing import List -from .._utils import exception_handler -import warnings import copy +import operator +import warnings +from functools import reduce +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) +from colossalai.auto_parallel.tensor_shard.utils import exception_handler +from colossalai.tensor.shape_consistency import CollectiveCommPattern + +from .strategy_generator import StrategyGenerator class ConvStrategyGenerator(StrategyGenerator): diff --git a/colossalai/auto_parallel/solver/strategy/getitem_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py similarity index 97% rename from colossalai/auto_parallel/solver/strategy/getitem_generator.py rename to colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py index 646213032..59e0ee4c8 100644 --- a/colossalai/auto_parallel/solver/strategy/getitem_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py @@ -1,12 +1,10 @@ -import operator -from functools import reduce -from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost -from colossalai.tensor.shape_consistency import CollectiveCommPattern -from .strategy_generator import FollowingStrategyGenerator -from typing import List -from .._utils import exception_handler import copy +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) +from colossalai.tensor.shape_consistency import CollectiveCommPattern + +from .strategy_generator import FollowingStrategyGenerator + __all__ = ['GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator'] diff --git a/colossalai/auto_parallel/solver/strategy/layer_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py similarity index 96% rename from colossalai/auto_parallel/solver/strategy/layer_norm_generator.py rename to colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py index 00bb0a8ca..86a70e5d0 100644 --- a/colossalai/auto_parallel/solver/strategy/layer_norm_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py @@ -1,11 +1,13 @@ +import copy import operator from functools import reduce -from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) +from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding) from colossalai.tensor.shape_consistency import CollectiveCommPattern + from .strategy_generator import StrategyGenerator -from typing import List -from .._utils import exception_handler, 
enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding -import copy __all__ = ['LayerNormGenerator'] diff --git a/colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py similarity index 99% rename from colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py rename to colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py index 806959bb3..d36800e29 100644 --- a/colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py @@ -1,11 +1,12 @@ -from audioop import bias import operator from functools import reduce -from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost -from colossalai.tensor.shape_consistency import CollectiveCommPattern -from .strategy_generator import StrategyGenerator from typing import List +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) +from colossalai.tensor.shape_consistency import CollectiveCommPattern + +from .strategy_generator import StrategyGenerator + class MatMulStrategyGenerator(StrategyGenerator): """ diff --git a/colossalai/auto_parallel/solver/strategy/normal_pooling_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py similarity index 94% rename from colossalai/auto_parallel/solver/strategy/normal_pooling_generator.py rename to colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py index f54074622..59f6d89b5 100644 --- a/colossalai/auto_parallel/solver/strategy/normal_pooling_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py @@ -1,11 +1,13 @@ +import copy import operator from functools import reduce -from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost -from colossalai.tensor.shape_consistency import CollectiveCommPattern -from .strategy_generator import StrategyGenerator from typing import List -from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding -import copy + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) +from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding) + +from .strategy_generator import StrategyGenerator class NormalPoolStrategyGenerator(StrategyGenerator): diff --git a/colossalai/auto_parallel/solver/strategy/output_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py similarity index 85% rename from colossalai/auto_parallel/solver/strategy/output_generator.py rename to colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py index bfd2ee9fd..3d58f0f11 100644 --- a/colossalai/auto_parallel/solver/strategy/output_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py @@ -1,11 +1,6 @@ -import operator -from functools import reduce -from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost -from colossalai.tensor.shape_consistency import CollectiveCommPattern +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) + from .strategy_generator import 
OutputStrategyGenerator -from typing import List -from .._utils import exception_handler -import copy __all__ = ['OutputGenerator'] @@ -46,7 +41,7 @@ class OutputGenerator(OutputStrategyGenerator): communication_action_mapping = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - name = f'Replica Output' + name = 'Replica Output' strategy = self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, diff --git a/colossalai/auto_parallel/solver/strategy/placeholder_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py similarity index 86% rename from colossalai/auto_parallel/solver/strategy/placeholder_generator.py rename to colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py index 5e1940166..a106488a8 100644 --- a/colossalai/auto_parallel/solver/strategy/placeholder_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py @@ -1,11 +1,6 @@ -import operator -from functools import reduce -from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost -from colossalai.tensor.shape_consistency import CollectiveCommPattern +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) + from .strategy_generator import StrategyGenerator -from typing import List -from .._utils import exception_handler -import copy __all__ = ['PlaceholderGenerator'] @@ -47,7 +42,7 @@ class PlaceholderGenerator(StrategyGenerator): communication_action_mapping = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - name = f'Replica Placeholder' + name = 'Replica Placeholder' strategy = self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, diff --git a/colossalai/auto_parallel/solver/strategy/reshape_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py similarity index 97% rename from colossalai/auto_parallel/solver/strategy/reshape_generator.py rename to colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py index 4ec45f5d3..eb5636fc8 100644 --- a/colossalai/auto_parallel/solver/strategy/reshape_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py @@ -1,11 +1,10 @@ -import operator -from functools import reduce -from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost -from colossalai.tensor.shape_consistency import CollectiveCommPattern -from .strategy_generator import FollowingStrategyGenerator -from typing import List import copy +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) +from colossalai.tensor.shape_consistency import CollectiveCommPattern + +from .strategy_generator import FollowingStrategyGenerator + __all__ = ['ReshapeGenerator'] diff --git a/colossalai/auto_parallel/solver/strategy/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py similarity index 96% rename from colossalai/auto_parallel/solver/strategy/strategy_generator.py rename to colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py index 06bfe2a35..a643968ba 100644 --- a/colossalai/auto_parallel/solver/strategy/strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py @@ -1,15 +1,16 @@ import operator 
-import torch -from colossalai.tensor.sharding_spec import ShardingSpec -from functools import reduce from abc import ABC, abstractmethod +from functools import reduce +from typing import Any, Dict, List, Union + +import torch +from torch.fx import Node + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, ShardingStrategy, + TrainCycleItem) +from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec from colossalai.tensor.sharding_spec import ShardingSpec -from colossalai.device.device_mesh import DeviceMesh -from typing import Dict, List, Union, Any -from ..sharding_strategy import OperationData, ShardingStrategy, TrainCycleItem, OperationDataType -from torch.fx import Node -import copy class StrategyGenerator(ABC): diff --git a/colossalai/auto_parallel/solver/strategy/unary_elementwise_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py similarity index 94% rename from colossalai/auto_parallel/solver/strategy/unary_elementwise_generator.py rename to colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py index 2a9220ca3..ea582588b 100644 --- a/colossalai/auto_parallel/solver/strategy/unary_elementwise_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py @@ -1,12 +1,9 @@ -import operator -from functools import reduce -from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost -from colossalai.tensor.shape_consistency import CollectiveCommPattern -from .strategy_generator import FollowingStrategyGenerator -from typing import List -from .._utils import exception_handler import copy +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) + +from .strategy_generator import FollowingStrategyGenerator + __all__ = ['UnaryElementwiseGenerator'] diff --git a/colossalai/auto_parallel/solver/strategy/where_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py similarity index 90% rename from colossalai/auto_parallel/solver/strategy/where_generator.py rename to colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py index bbf987ef2..48471cd73 100644 --- a/colossalai/auto_parallel/solver/strategy/where_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py @@ -1,12 +1,11 @@ -import operator -from functools import reduce -from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost -from colossalai.tensor.shape_consistency import CollectiveCommPattern -from .strategy_generator import StrategyGenerator, FollowingStrategyGenerator -from typing import List -from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding import copy +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) +from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding) + +from .strategy_generator import StrategyGenerator + __all__ = ['WhereGenerator'] diff --git a/colossalai/auto_parallel/solver/node_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py similarity index 85% rename from 
colossalai/auto_parallel/solver/node_handler/unary_elementwise_handler.py rename to colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py index 73ea4e6b9..b99d4a071 100644 --- a/colossalai/auto_parallel/solver/node_handler/unary_elementwise_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py @@ -1,10 +1,11 @@ +from typing import Dict, List + import torch + +from ..sharding_strategy import OperationData, OperationDataType from .node_handler import NodeHandler -from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector -from ..strategy import UnaryElementwiseGenerator, StrategyGenerator -from typing import List, Dict from .registry import operator_registry -import operator +from .strategy import StrategyGenerator, UnaryElementwiseGenerator __all__ = ['UnaryElementwiseHandler'] diff --git a/colossalai/auto_parallel/solver/node_handler/where_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py similarity index 94% rename from colossalai/auto_parallel/solver/node_handler/where_handler.py rename to colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py index 1e97ea919..ebcd6c453 100644 --- a/colossalai/auto_parallel/solver/node_handler/where_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py @@ -1,12 +1,14 @@ -import torch -from .node_handler import NodeHandler -from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector -from ..strategy import WhereGenerator, StrategyGenerator -from .broadcast import recover_sharding_spec_for_broadcast_shape -from typing import List, Dict -from .registry import operator_registry -import operator import copy +import operator +from typing import Dict, List + +import torch + +from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy, StrategiesVector) +from ..utils import recover_sharding_spec_for_broadcast_shape +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import StrategyGenerator, WhereGenerator __all__ = ['WhereHandler'] diff --git a/colossalai/auto_parallel/solver/sharding_strategy.py b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py similarity index 94% rename from colossalai/auto_parallel/solver/sharding_strategy.py rename to colossalai/auto_parallel/tensor_shard/sharding_strategy.py index 5973c7250..70402a185 100644 --- a/colossalai/auto_parallel/solver/sharding_strategy.py +++ b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py @@ -1,17 +1,14 @@ from copy import deepcopy from dataclasses import dataclass -from abc import ABC, abstractmethod from enum import Enum -import operator -import torch -from functools import reduce +from typing import Any, Dict, List, Tuple, Union -from colossalai.device.device_mesh import DeviceMesh +import torch +from colossalai.tensor.shape_consistency import CommSpec from colossalai.tensor.sharding_spec import ShardingSpec -from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec -from typing import Dict, List, Union, Tuple, Any from torch.fx.node import Node -from .constants import * + +from .constants import (BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_MODULE_OP, RESHAPE_FUNC_OP) __all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector'] @@ -75,6 +72,11 @@ class TrainCycleItem: @dataclass class MemoryCost: """ + 
MemoryCost is a dataclass which stores the memory usage in the program. + + Args: + activation (int): the memory cost incurred by the activations in bytes. + parameter (int): the memory cost incurred by the module parameter in bytes. """ activation: int = 0 parameter: int = 0 diff --git a/colossalai/auto_parallel/tensor_shard/solver/__init__.py b/colossalai/auto_parallel/tensor_shard/solver/__init__.py new file mode 100644 index 000000000..e9f9ba881 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/solver/__init__.py @@ -0,0 +1,7 @@ +from .cost_graph import CostGraph +from .graph_analysis import GraphAnalyser +from .options import SolverOptions +from .solver import Solver +from .strategies_constructor import StrategiesConstructor + +__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph', 'SolverOptions'] diff --git a/colossalai/auto_parallel/solver/cost_graph.py b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py similarity index 98% rename from colossalai/auto_parallel/solver/cost_graph.py rename to colossalai/auto_parallel/tensor_shard/solver/cost_graph.py index b579f6587..16ce02cf1 100644 --- a/colossalai/auto_parallel/solver/cost_graph.py +++ b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py @@ -1,7 +1,4 @@ -from typing import List -import math -from torch.fx.node import Node -from colossalai.auto_parallel.solver.constants import INFINITY_COST +from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST class CostGraph: diff --git a/colossalai/auto_parallel/solver/graph_analysis.py b/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py similarity index 98% rename from colossalai/auto_parallel/solver/graph_analysis.py rename to colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py index 831e7eadd..be39a74cb 100644 --- a/colossalai/auto_parallel/solver/graph_analysis.py +++ b/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py @@ -1,9 +1,10 @@ from dataclasses import dataclass -from torch.fx.node import Node +from typing import List + from torch.fx.graph import Graph from torch.fx.graph_module import GraphModule -from collections import OrderedDict as ODict -from typing import List, OrderedDict, Union, Any +from torch.fx.node import Node + from colossalai.fx.passes.utils import get_node_module __all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser'] diff --git a/colossalai/auto_parallel/solver/options.py b/colossalai/auto_parallel/tensor_shard/solver/options.py similarity index 100% rename from colossalai/auto_parallel/solver/options.py rename to colossalai/auto_parallel/tensor_shard/solver/options.py diff --git a/colossalai/auto_parallel/solver/solver.py b/colossalai/auto_parallel/tensor_shard/solver/solver.py similarity index 50% rename from colossalai/auto_parallel/solver/solver.py rename to colossalai/auto_parallel/tensor_shard/solver/solver.py index 97674c088..24783f8b0 100644 --- a/colossalai/auto_parallel/solver/solver.py +++ b/colossalai/auto_parallel/tensor_shard/solver/solver.py @@ -1,18 +1,21 @@ -import warnings - -import time -import numpy as np import multiprocessing -from torch.fx.node import Node -from torch.fx.graph import Graph -from . 
import GraphAnalyser -from colossalai.auto_parallel.solver.cost_graph import CostGraph -from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor +import time +import warnings from typing import Dict -from .constants import INFINITY_COST + +import numpy as np +from torch.fx.graph import Graph +from torch.fx.node import Node + +from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST + +from .cost_graph import CostGraph +from .graph_analysis import GraphAnalyser +from .strategies_constructor import StrategiesConstructor + try: import pulp - from pulp import LpVariable, LpProblem, LpMinimize, lpSum, lpDot, LpStatus + from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum except: warnings.warn(f'please install the pulp') @@ -21,454 +24,6 @@ __all___ = ['Solver'] class Solver: - def __init__(self, - graph: Graph, - strategies_constructor: StrategiesConstructor, - cost_graph: CostGraph, - graph_analyser: GraphAnalyser, - memory_budget: float = -1.0, - solution_numbers: int = 1, - memory_increasing_coefficient: float = 1.3): - ''' - Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph. - - Argument: - graph: The computing graph to be optimized. - strategies_constructor: It will provide all the possible strategies for each node in the computing graph. - cost_graph: A graph data structure to simplify the edge cost graph. - graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints. - memory_budget: Memory constraint for the solution. - solution_numbers: If solution_numbers is larger than one, solver will us a serious of solutions based on different memory budget. - memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget. - ''' - self.graph = graph - self.strategies_constructor = strategies_constructor - self.cost_graph = cost_graph - self.graph_analyser = graph_analyser - self.leaf_strategies = self.strategies_constructor.leaf_strategies - self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies] - self.strategy_map = self.strategies_constructor.strategy_map - self.memory_budget = memory_budget - self.solution_numbers = solution_numbers - if self.solution_numbers > 1: - self.memory_increasing_coefficient = memory_increasing_coefficient - else: - self.memory_increasing_coefficient = 1 - self.liveness_list = self.graph_analyser.liveness_analysis() - self.node_index_dict = self._generate_node_index_dict() - # The last solution vector of auto sharding. - self.last_s_val = None - # The last objective value of the best ILP solution. - self.last_objective = None - - def _recover_merged_node_strategy(self): - ''' - During cost graph constructing, some nodes, such as unary element-wise node or ReshapeOp, were merged into the previous node. - Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged - node. 
- ''' - for node_index, node in enumerate(self.nodes): - if node.strategies_vector.check_merge(): - # the merged node has only one input, and its strategies follow the input sharding strategy - input_strategies_vector = node.args[0].strategies_vector - input_best_strategy_index = self.last_s_val[node_index - 1] - input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec - for strategy_index, strategy in enumerate(node.strategies_vector): - if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence: - self.last_s_val[node_index] = strategy_index - break - - def _generate_node_index_dict(self) -> Dict[Node, int]: - node_index_dict = {} - for index, strategies_vector in enumerate(self.leaf_strategies): - node_index_dict[strategies_vector.node] = index - return node_index_dict - - def _prepare_data_for_solver(self): - ''' - Extract information from components for solver. - ''' - node_nums = len(self.leaf_strategies) - memory_budget = self.memory_budget - - # prepare strategies_len - strategies_len = [] - for node in self.nodes: - strategies_len.append(self.cost_graph.node_lens[node]) - strategies_len = np.array(strategies_len) - - # prepare following_nodes - following_nodes = self.cost_graph.following_dict - index_following_nodes = {} - for src, target in following_nodes.items(): - src_index = self.node_index_dict[src] - target_index = self.node_index_dict[target] - index_following_nodes[src_index] = target_index - following_nodes = index_following_nodes - for index in range(node_nums): - if index not in following_nodes: - following_nodes[index] = -1 - - # prepare edge_pairs and resharding costs - edge_pairs = [] - resharding_costs = [] - for pairs, edge_cost in self.cost_graph.edge_costs.items(): - src_node = pairs[0] - dst_node = pairs[1] - src_node_index = self.node_index_dict[src_node] - dst_node_index = self.node_index_dict[dst_node] - edge_pairs.append(src_node_index) - edge_pairs.append(dst_node_index) - - for i in range(strategies_len[src_node_index]): - for j in range(strategies_len[dst_node_index]): - resharding_costs.append(edge_cost[(i, j)]) - edge_pairs = np.array(edge_pairs) - resharding_costs = np.array(resharding_costs) - - # prepare liveness_set - liveness_set = self.liveness_list - - # omit alias_set now - alias_set = None - alias_convert_costs = None - - # prepare compute_costs, communication_costs and memory_costs - compute_costs = [] - communication_costs = [] - memory_costs = [] - extra_node_costs = self.cost_graph.extra_node_costs - for strategies_vector in self.leaf_strategies: - node = strategies_vector.node - for index, strategy in enumerate(strategies_vector): - compute_costs.append(strategy.compute_cost) - # node in extra_node_costs means it has some extra communication - # cost from node merging, so we need to add those extra communication - # cost into - if node in extra_node_costs: - origin_communication_cost = strategy.communication_cost - extra_node_cost = extra_node_costs[node][index] - communication_cost = origin_communication_cost + extra_node_cost - communication_costs.append(communication_cost) - else: - communication_costs.append(strategy.communication_cost) - # temporarily we just consider the forward memory cost - memory_cost = strategy.memory_cost - if isinstance(memory_cost, tuple): - memory_costs.append(memory_cost[0]) - else: - memory_costs.append(memory_cost) - compute_costs = np.array(compute_costs) - communication_costs = np.array(communication_costs) - memory_costs = 
np.array(memory_costs) - - # omit initial value for nodes - s_init_np = None - - return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np - - def _call_solver_serialized_args(self, - node_nums, - memory_budget, - strategies_len, - following_nodes, - edge_pairs, - alias_set, - liveness_set, - compute_costs, - communication_costs, - memory_costs, - resharding_costs, - alias_convert_costs, - s_init_np=None): - """ - Call the solver with serialized arguments. - """ - - tic = time.time() - - for x in [strategies_len, edge_pairs, compute_costs, communication_costs, memory_costs, resharding_costs]: - assert isinstance(x, np.ndarray) - assert len(strategies_len) == node_nums, "strategies_len" - - def get_non_zero_index(binary_vector): - """ - Get the index of non-zero item in a vector. - """ - ct = 0 - ret = None - for i, elem in enumerate(binary_vector): - if pulp.value(elem): - ret = i - ct += 1 - - assert ct == 1 - return ret - - # 0. Unpack flatten numpy arrays - s_follow = following_nodes - - E = edge_pairs.reshape((-1, 2)) # noqa - r = [] - pt = 0 - edge_set = set() - for (i, j) in E: - prod_length = strategies_len[i] * strategies_len[j] - - if (i, j) in edge_set: - raise ValueError(f"Duplicated edges: {(i, j)}") - - edge_set.add((i, j)) - r.append(resharding_costs[pt:pt + prod_length]) - pt += prod_length - assert pt == len(resharding_costs) - - ###################### - # omit alias set now # - ###################### - - # A = alias_set.reshape((-1, 2)) # noqa - # for (i, j) in A: - # prod_length = strategies_len[i] * strategies_len[j] - # v.append(alias_convert_costs[pt:pt + prod_length]) - # pt += prod_length - # assert pt == len(alias_convert_costs) - - # L = [] # noqa - # pt = node_nums - # for i in range(node_nums): - # length = liveness_set[i] - # L.append(liveness_set[pt:pt + length]) - # pt += length - # assert pt == len(liveness_set) - v = [] - pt = 0 - - c = [] - d = [] - m = [] - pt = 0 - for i in range(node_nums): - length = strategies_len[i] - c.append(compute_costs[pt:pt + length]) - d.append(communication_costs[pt:pt + length]) - m.append(memory_costs[pt:pt + length]) - pt += length - assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}" - assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}" - assert pt == len(memory_costs), f"{pt} == {len(memory_costs)}" - - # 1. Create variables - - ############################# - # create variables for node # - ############################# - s = [] - num_nodes = 0 - reverse_follow_backpatch = [] - for i in range(node_nums): - if s_follow[i] < 0: - if strategies_len[i] == 1: - s.append([1]) - else: - num_nodes += 1 - s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary")) - else: - if s_follow[i] < len(s): - s.append(s[s_follow[i]]) - else: - s.append(None) - reverse_follow_backpatch.append(i) - - for i in reverse_follow_backpatch: - s[i] = s[s_follow[i]] - - ############################# - # create variables for edge # - ############################# - e = [] - num_edges = 0 - for (idx, (i, j)) in enumerate(E): - if len(s[i]) == 1: - e.append(s[j]) - elif len(s[j]) == 1: - e.append(s[i]) - else: - num_edges += 1 - e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary")) - assert len(e[idx]) == len(r[idx]) - for element in s: - assert len(element) > 0 - # 2. 
Set initial value - ###################################### - # set a initial value for warm start # - ###################################### - if s_init_np is not None: - s_init = s_init_np.reshape((-1, 3)) - for (idx, value, fix) in s_init: - for i in range(len(s[idx])): - s[idx][i].setInitialValue(i == value) - if fix: - s[idx][i].fixValue() - - # 3. Objective - prob = LpProblem("myProblem", LpMinimize) - ################################################################### - # computing the node cost(computing cost and communication cost) # - ################################################################### - obj = 0 - for i in range(node_nums): - assert len(s[i]) == len(c[i]) - assert len(s[i]) == len(d[i]) - - obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i]) - - ############################################# - # computing the edge cost(resharding cost) # - ############################################# - for i in range(len(E)): - assert len(e[i]) == len(r[i]) - obj += lpDot(e[i], r[i]) - - prob += obj - - # 4. Constraints - # (a). specified by `cat="Binary"` - - # (b) - ################################################# - # make sure each node only choose one strategy # - ################################################# - for i in range(node_nums): - if s_follow[i] < 0: - prob += lpSum(s[i]) == 1 - - # (c) - ################################################# - # compute memory consumption with liveness set # - ################################################# - if memory_budget > 0: - for liveness_stage in liveness_set: - mem = 0 - for live_variable in liveness_stage.unique_live_vars: - node_index = self.node_index_dict[live_variable.node] - mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index]))) - prob += mem <= memory_budget - - # (d). specified by `cat="Binary"` - - for (idx, (i, j)) in enumerate(E): - if strategies_len[i] == 1 or strategies_len[j] == 1: - continue - - # (e) - prob += lpSum(e[idx]) == 1 - - # (f) - for row in range(len(s[i])): - C = len(s[j]) # noqa - prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row] - - # (g) - for col in range(len(s[j])): - R = len(s[i]) # noqa - C = len(s[j]) # noqa - prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col] - - # (h) - ###################### - # omit alias set now # - ###################### - - # alias_set = set() - # for (idx, (i, j)) in enumerate(A): - # R = len(s[i]) # noqa - # C = len(s[j]) # noqa - # if (i, j) in alias_set: - # raise ValueError(f"Duplicated edges: {(i, j)}") - - # alias_set.add((i, j)) - # alias_set.add((j, i)) - - # for row in range(len(s[i])): - # for col in range(len(s[j])): - # if v[idx][row * C + col] > 0.5: - # prob += s[i][row] + s[j][col] <= 1 - - verbose = True - - msg = verbose - time_limit = 600 - assert "COIN_CMD" in pulp.listSolvers( - onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'") - - solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count()) - # solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit) - prob.solve(solver) - - status = prob.status - objective = pulp.value(prob.objective) - objective = float(objective) if objective is not None else -1.0 - if verbose: - print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t" - f"Time: {time.time() - tic}") - print(f"#nodes: {num_nodes}, #edges: {num_edges}") - - if prob.status in [pulp.LpStatusInfeasible]: - raise RuntimeError("Cannot run the function under the given memory budget. 
" - "Please increase the memory budget.") - - # Get and check results - s_val = np.full((node_nums,), -1, dtype=np.int32) - for i in range(node_nums): - s_val[i] = get_non_zero_index(s[i]) - - e_val = np.full((len(E),), -1, dtype=np.int32) - for (idx, (i, j)) in enumerate(E): - e_val[idx] = get_non_zero_index(e[idx]) - i_spec_index = e_val[idx] // len(s[j]) - j_spec_index = e_val[idx] % len(s[j]) - assert i_spec_index == s_val[i], f"e_val[{i}][{j}]" - assert j_spec_index == s_val[j], f"e_val[{i}][{j}]" - if verbose and r[idx][e_val[idx]] > 0: - print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}") - - self.last_s_val = list(s_val) - self._recover_merged_node_strategy() - self.last_objective = objective - - if objective > INFINITY_COST: - warnings.warn("Detect unexpected behaviors in the auto-sharding pass.") - - return self.last_s_val, e_val, self.last_objective, status - - def call_solver_serialized_args(self): - """ - Call the solver with serialized arguments and handle python errors. Additionally, - we could give a serious of solutions with different memory budget. - """ - if self.solution_numbers == 1: - args = self._prepare_data_for_solver() - ret = self._call_solver_serialized_args(*args) - - return ret - - origin_memory_budget = self.memory_budget - memory_budget_list = [ - origin_memory_budget * self.memory_increasing_coefficient**i for i in range(self.solution_numbers) - ] - ret_list = [] - for memory_budget in memory_budget_list: - self.memory_budget = memory_budget - args = self._prepare_data_for_solver() - ret = self._call_solver_serialized_args(*args) - ret_list.append(ret) - - return ret_list - - -class Solver_V2: - def __init__(self, graph: Graph, strategies_constructor: StrategiesConstructor, @@ -480,7 +35,6 @@ class Solver_V2: memory_increasing_coefficient: float = 1.3): ''' Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph. - Argument: graph: The computing graph to be optimized. strategies_constructor: It will provide all the possible strategies for each node in the computing graph. 
diff --git a/colossalai/auto_parallel/solver/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py similarity index 91% rename from colossalai/auto_parallel/solver/strategies_constructor.py rename to colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py index f1bfa78bb..57d5dfa79 100644 --- a/colossalai/auto_parallel/solver/strategies_constructor.py +++ b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py @@ -1,22 +1,19 @@ +import math +import operator +from copy import deepcopy +from typing import Dict, List + +import torch from torch.fx import Graph, Node -from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy -from colossalai.tensor.sharding_spec import ShardingSpec + +from colossalai.auto_parallel.tensor_shard.node_handler import (OuputHandler, PlacehodlerHandler, operator_registry) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (ShardingStrategy, StrategiesVector) +from colossalai.auto_parallel.tensor_shard.utils import (generate_resharding_costs, generate_sharding_spec) from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from colossalai.auto_parallel.solver.node_handler.registry import operator_registry -from colossalai.auto_parallel.solver.node_handler.placeholder_handler import PlacehodlerHandler -from colossalai.auto_parallel.solver.node_handler.output_handler import OuputHandler +from colossalai.tensor.sharding_spec import ShardingSpec + from .options import SolverOptions -from . import ShardingStrategy, StrategiesVector -from .node_handler import * -from .constants import * -from copy import deepcopy -import math -import torch -import operator -from typing import Dict, List -from ._utils import generate_sharding_spec, generate_resharding_costs -import builtins __all__ = ['StrategiesConstructor'] diff --git a/colossalai/auto_parallel/tensor_shard/utils/__init__.py b/colossalai/auto_parallel/tensor_shard/utils/__init__.py new file mode 100644 index 000000000..bfa9c18f4 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/utils/__init__.py @@ -0,0 +1,12 @@ +from .broadcast import (BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape) +from .factory import generate_resharding_costs, generate_sharding_spec +from .misc import exception_handler +from .sharding import (enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding, generate_sharding_size, + switch_partition_dim, update_partition_dim) + +__all__ = [ + 'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape', + 'generate_resharding_costs', 'generate_sharding_spec', 'exception_handler', 'switch_partition_dim', + 'update_partition_dim', 'enumerate_all_possible_1d_sharding', 'enumerate_all_possible_2d_sharding', + 'generate_sharding_size' +] diff --git a/colossalai/auto_parallel/solver/node_handler/broadcast.py b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py similarity index 100% rename from colossalai/auto_parallel/solver/node_handler/broadcast.py rename to colossalai/auto_parallel/tensor_shard/utils/broadcast.py diff --git a/colossalai/auto_parallel/solver/_utils.py b/colossalai/auto_parallel/tensor_shard/utils/factory.py similarity index 70% rename from colossalai/auto_parallel/solver/_utils.py rename to colossalai/auto_parallel/tensor_shard/utils/factory.py index 378a14d03..fd3ba3d41 100644 --- 
a/colossalai/auto_parallel/solver/_utils.py +++ b/colossalai/auto_parallel/tensor_shard/utils/factory.py @@ -1,14 +1,17 @@ -from colossalai.tensor.shape_consistency import ShapeConsistencyManager -import torch -from torch.fx.node import Node -from colossalai.tensor.sharding_spec import ShardingSpec -from colossalai.device.device_mesh import DeviceMesh -from typing import Union, Dict, List, Optional +import operator import warnings from functools import reduce -import functools -import operator -from .constants import INFINITY_COST +from typing import Dict, List, Optional, Union + +import torch +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec +from torch.fx.node import Node + +from ..constants import INFINITY_COST + +__all__ = ['generate_sharding_spec', 'generate_resharding_costs'] def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh, @@ -85,55 +88,3 @@ def generate_resharding_costs(nodes: List[Node], resharding_cost = INFINITY_COST resharding_costs[input_node].append(resharding_cost) return resharding_costs - - -def exception_handler(func): - """ - A function wrapper which executes the function with a specified seed. - """ - - @functools.wraps(func) - def wrapper(*args, **kwargs): - try: - rst = func(*args, **kwargs) - return rst - except AssertionError as e: - warnings.warn(f'{e}') - - return wrapper - - -def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size): - dim_partition_list = [] - # enumerate all the 2D sharding cases - for i in range(dim_size): - for j in range(i + 1, dim_size): - dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]} - dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]} - dim_partition_list.append(dim_partition_dict_0) - dim_partition_list.append(dim_partition_dict_1) - for i in range(dim_size): - dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]} - dim_partition_list.append(dim_partition_dict_flatten) - - return dim_partition_list - - -def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size): - dim_partition_list = [] - # enumerate all the 1D sharding cases - for i in range(dim_size): - dim_partition_dict_0 = {i: [mesh_dim_0]} - dim_partition_list.append(dim_partition_dict_0) - - return dim_partition_list - - -def generate_sharding_size(dim_partition_dict, device_mesh): - total_sharding_size = 1 - for mesh_dim_list in dim_partition_dict.values(): - mesh_dim_sharding_size = [device_mesh.shape[mesh_dim] for mesh_dim in mesh_dim_list] - sharding_size = reduce(operator.mul, mesh_dim_sharding_size) - total_sharding_size *= sharding_size - - return total_sharding_size diff --git a/colossalai/auto_parallel/tensor_shard/utils/misc.py b/colossalai/auto_parallel/tensor_shard/utils/misc.py new file mode 100644 index 000000000..7c8e530ff --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/utils/misc.py @@ -0,0 +1,26 @@ +import functools +import warnings + +__all__ = ['exception_handler'] + + +def exception_handler(func): + """ + A function wrapper to handle the AssertionError in the function. + + Usage: + # mute the assertion error in the function + @exception_handler + def do_something(): + ... 
+ """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + rst = func(*args, **kwargs) + return rst + except AssertionError as e: + warnings.warn(f'{e}') + + return wrapper diff --git a/colossalai/auto_parallel/solver/node_handler/utils.py b/colossalai/auto_parallel/tensor_shard/utils/sharding.py similarity index 62% rename from colossalai/auto_parallel/solver/node_handler/utils.py rename to colossalai/auto_parallel/tensor_shard/utils/sharding.py index 59bd2f535..ae5d250a4 100644 --- a/colossalai/auto_parallel/solver/node_handler/utils.py +++ b/colossalai/auto_parallel/tensor_shard/utils/sharding.py @@ -1,7 +1,16 @@ -import torch -from typing import Dict -from colossalai.tensor.sharding_spec import ShardingSpec +import operator from copy import deepcopy +from functools import reduce +from typing import Dict + +import torch + +from colossalai.tensor.sharding_spec import ShardingSpec + +__all__ = [ + 'switch_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding', + 'enumerate_all_possible_2d_sharding', 'generate_sharding_size' +] def switch_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec: @@ -66,3 +75,39 @@ def update_partition_dim(sharding_spec: ShardingSpec, entire_shape=physical_shape, dim_partition_dict=new_dim_partition_dict) return current_sharding_spec + + +def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size): + dim_partition_list = [] + # enumerate all the 2D sharding cases + for i in range(dim_size): + for j in range(i + 1, dim_size): + dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]} + dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]} + dim_partition_list.append(dim_partition_dict_0) + dim_partition_list.append(dim_partition_dict_1) + for i in range(dim_size): + dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]} + dim_partition_list.append(dim_partition_dict_flatten) + + return dim_partition_list + + +def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size): + dim_partition_list = [] + # enumerate all the 1D sharding cases + for i in range(dim_size): + dim_partition_dict_0 = {i: [mesh_dim_0]} + dim_partition_list.append(dim_partition_dict_0) + + return dim_partition_list + + +def generate_sharding_size(dim_partition_dict, device_mesh): + total_sharding_size = 1 + for mesh_dim_list in dim_partition_dict.values(): + mesh_dim_sharding_size = [device_mesh.shape[mesh_dim] for mesh_dim in mesh_dim_list] + sharding_size = reduce(operator.mul, mesh_dim_sharding_size) + total_sharding_size *= sharding_size + + return total_sharding_size diff --git a/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py b/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py index 1a9279a78..4c35e7de5 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py @@ -1,7 +1,9 @@ import torch -from colossalai.auto_parallel.solver.node_handler.broadcast import is_broadcastable, get_broadcast_shape, recover_sharding_spec_for_broadcast_shape -from colossalai.tensor.sharding_spec import ShardingSpec + +from colossalai.auto_parallel.tensor_shard.utils import (get_broadcast_shape, is_broadcastable, + recover_sharding_spec_for_broadcast_shape) from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.sharding_spec import ShardingSpec def test_is_broadcastable(): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py 
b/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py index f54441729..f5de7bf70 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py @@ -1,7 +1,8 @@ -import torch.nn as nn import torch -from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser -from colossalai.fx import ColoTracer, ColoGraphModule +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.solver import GraphAnalyser +from colossalai.fx import ColoGraphModule, ColoTracer class LinearModel(nn.Module): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py index 3bfb5e875..422474f6d 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py @@ -1,10 +1,12 @@ -from colossalai.fx.tracer.meta_patch.patched_module import linear import torch import torch.nn as nn -from colossalai.fx import ColoTracer, ColoGraphModule -from colossalai.auto_parallel.solver.node_handler.batch_norm_handler import BatchNormModuleHandler -from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector + +from colossalai.auto_parallel.tensor_shard.node_handler.batch_norm_handler import \ + BatchNormModuleHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.fx.tracer.meta_patch.patched_module import linear def test_bn_module_handler(): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py index 9ec536743..f3612a781 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py @@ -1,10 +1,12 @@ import pytest import torch import torch.nn as nn -from colossalai.fx import ColoTracer, ColoGraphModule -from colossalai.auto_parallel.solver.node_handler.dot_handler import BMMFunctionHandler -from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector + +from colossalai.auto_parallel.tensor_shard.node_handler.dot_handler import \ + BMMFunctionHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.testing.pytest_wrapper import run_on_environment_flag diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py index 28643cdf0..ddce9f5eb 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py @@ -1,10 +1,11 @@ -from colossalai.fx.tracer.meta_patch.patched_module import linear import torch import torch.nn as nn -from colossalai.fx import ColoTracer, ColoGraphModule -from 
colossalai.auto_parallel.solver.node_handler.conv_handler import ConvModuleHandler, ConvFunctionHandler -from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector + +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import (ConvFunctionHandler, ConvModuleHandler) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.fx.tracer.meta_patch.patched_module import linear def test_conv_module_handler(): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py index 97c03eae0..37a612de1 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py @@ -1,11 +1,14 @@ -from colossalai.fx.tracer.meta_patch.patched_module import linear import torch import torch.nn as nn -from colossalai.fx import ColoTracer, ColoGraphModule -from colossalai.auto_parallel.solver.node_handler.getitem_handler import GetItemHandler -from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler -from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector + +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \ + ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import \ + GetItemHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.fx.tracer.meta_patch.patched_module import linear class GetItemModel(nn.Module): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py index 3c942bd5e..1a8487e7e 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py @@ -1,10 +1,12 @@ -from colossalai.fx.tracer.meta_patch.patched_module import linear import torch import torch.nn as nn -from colossalai.fx import ColoTracer, ColoGraphModule -from colossalai.auto_parallel.solver.node_handler.layer_norm_handler import LayerNormModuleHandler -from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector + +from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import \ + LayerNormModuleHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.fx.tracer.meta_patch.patched_module import linear def test_ln_module_handler(): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py index 4870a2ce1..fdd6a5198 100644 --- 
a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py @@ -1,10 +1,12 @@ -from colossalai.fx.tracer.meta_patch.patched_module import linear import torch import torch.nn as nn -from colossalai.fx import ColoTracer, ColoGraphModule -from colossalai.auto_parallel.solver.node_handler.dot_handler import LinearModuleHandler, LinearFunctionHandler -from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector, ShardingStrategy + +from colossalai.auto_parallel.tensor_shard.node_handler.dot_handler import (LinearFunctionHandler, LinearModuleHandler) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, ShardingStrategy, + StrategiesVector) from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.fx.tracer.meta_patch.patched_module import linear from colossalai.tensor.sharding_spec import ShardingSpec diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py index 423940558..7ff418f25 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py @@ -1,11 +1,13 @@ -from colossalai.fx.tracer.meta_patch.patched_module import linear +import pytest import torch import torch.nn as nn -from colossalai.fx import ColoTracer, ColoGraphModule -from colossalai.auto_parallel.solver.node_handler.normal_pooling_handler import NormPoolingHandler -from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector + +from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import \ + NormPoolingHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) from colossalai.device.device_mesh import DeviceMesh -import pytest +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.fx.tracer.meta_patch.patched_module import linear from colossalai.testing.pytest_wrapper import run_on_environment_flag diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py index e16bd6ba9..27b0af4fb 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py @@ -1,9 +1,11 @@ import torch import torch.nn as nn -from colossalai.fx import ColoTracer, ColoGraphModule -from colossalai.auto_parallel.solver.node_handler.output_handler import OuputHandler -from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector + +from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import \ + OuputHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer class OutputModel(nn.Module): diff --git 
a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py index 66f31635c..bdec901e9 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py @@ -1,9 +1,11 @@ import torch import torch.nn as nn -from colossalai.fx import ColoTracer, ColoGraphModule -from colossalai.auto_parallel.solver.node_handler.placeholder_handler import PlacehodlerHandler -from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector + +from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import \ + PlacehodlerHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer class PlaceholderModel(nn.Module): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py index 3249d10ee..b35fc64b6 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py @@ -1,10 +1,13 @@ import torch import torch.nn as nn -from colossalai.fx import ColoTracer, ColoGraphModule -from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler -from colossalai.auto_parallel.solver.node_handler.reshape_handler import ReshapeHandler -from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector + +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \ + ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import \ + ReshapeHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer class ReshapeModel(nn.Module): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py index f79c81197..07f53a6cb 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py @@ -1,11 +1,14 @@ -from colossalai.fx.tracer.meta_patch.patched_module import linear import torch import torch.nn as nn -from colossalai.fx import ColoTracer, ColoGraphModule -from colossalai.auto_parallel.solver.node_handler.unary_elementwise_handler import UnaryElementwiseHandler -from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler -from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector + +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \ + ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import \ + UnaryElementwiseHandler +from 
colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.fx.tracer.meta_patch.patched_module import linear class ReLuModel(nn.Module): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py index a81f1695d..9838e2eb0 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py @@ -1,10 +1,12 @@ -from colossalai.fx.tracer.meta_patch.patched_module import linear import torch import torch.nn as nn -from colossalai.fx import ColoTracer, ColoGraphModule -from colossalai.auto_parallel.solver.node_handler.where_handler import WhereHandler -from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector + +from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import \ + WhereHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.fx.tracer.meta_patch.patched_module import linear class ConvModel(nn.Module): diff --git a/tests/test_auto_parallel/test_shape_consistency_pass.py b/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py similarity index 84% rename from tests/test_auto_parallel/test_shape_consistency_pass.py rename to tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py index 27a16a1cf..ae15106b0 100644 --- a/tests/test_auto_parallel/test_shape_consistency_pass.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py @@ -1,24 +1,22 @@ from functools import partial + import pytest import torch import torch.multiprocessing as mp -from torch.fx import GraphModule import torch.nn as nn -import pytest -from colossalai.initialize import launch -from colossalai.utils import free_port -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.logging import disable_existing_loggers -from colossalai.auto_parallel.solver.cost_graph import CostGraph -from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser -from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor +from torch.fx import GraphModule -from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions, + StrategiesConstructor) from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import shape_consistency_pass, solution_annotatation_pass -from colossalai.auto_parallel.solver.solver import Solver_V2 -from colossalai.auto_parallel.solver.options import SolverOptions +from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (shape_consistency_pass, + solution_annotatation_pass) +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing.pytest_wrapper import 
run_on_environment_flag +from colossalai.utils import free_port class ConvModel(nn.Module): @@ -61,7 +59,7 @@ def check_apply(rank, world_size, port): cost_graph = CostGraph(strategies_constructor.leaf_strategies) cost_graph.simplify_graph() graph_analyser = GraphAnalyser(gm) - solver = Solver_V2(gm.graph, strategies_constructor, cost_graph, graph_analyser) + solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) ret = solver.call_solver_serialized_args() solution = list(ret[0]) device_mesh.process_groups_dict = device_mesh.create_process_groups_for_logical_mesh() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py index a75337f10..23d866bbe 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py @@ -1,20 +1,13 @@ import torch from torch.fx import GraphModule -import torch.nn as nn -import pytest +from torchvision.models import resnet50 -from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector -from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP +from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions, + StrategiesConstructor) from colossalai.device.device_mesh import DeviceMesh -from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor -from colossalai.auto_parallel.solver.cost_graph import CostGraph -from copy import deepcopy -from colossalai.auto_parallel.solver.solver import Solver -from torchvision.models import resnet34, resnet50 -from colossalai.auto_parallel.solver.constants import * -from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser -from colossalai.auto_parallel.solver.options import SolverOptions +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.testing.pytest_wrapper import run_on_environment_flag
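The test updates above all funnel through the same refactored entry points: CostGraph, GraphAnalyser, Solver, SolverOptions and StrategiesConstructor are now imported from colossalai.auto_parallel.tensor_shard.solver, and the tests call Solver where they previously called Solver_V2. The sketch below condenses the call sequence visible in check_apply() into one helper; how gm (a traced torch.fx GraphModule) and strategies_constructor are built is not shown in this excerpt and is assumed here.

# Hypothetical helper distilled from the check_apply() sequence shown above.
# `gm` and `strategies_constructor` are assumed to have been prepared already
# (tracing the model and building strategies is not covered by this hunk).
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver


def solve_strategies(gm, strategies_constructor):
    # build and simplify the cost graph from the per-node strategy vectors
    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
    cost_graph.simplify_graph()
    # liveness information for the memory-budget constraint
    graph_analyser = GraphAnalyser(gm)
    # the tests now use Solver from the new package in place of Solver_V2
    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
    ret = solver.call_solver_serialized_args()
    # the first element holds the chosen strategy index for each node
    return list(ret[0])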