[autoparallel] refactored the autoparallel module for organization (#1706)

* [autoparallel] refactored the autoparallel module for organization

* polish code
Frank Lee 2022-10-14 13:27:00 +08:00 committed by GitHub
parent 91cd34e6e0
commit 6c331a5a09
57 changed files with 408 additions and 799 deletions


@@ -1,12 +0,0 @@
from .sharding_strategy import ShardingStrategy, StrategiesVector
from .graph_analysis import GraphAnalyser
from .solver import Solver
from .cost_graph import CostGraph
from .strategies_constructor import StrategiesConstructor
from .constants import *
from .options import SolverOptions
__all__ = [
'StrategiesVector', 'ShardingStrategy', 'GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph',
'SolverOptions'
]


@@ -1,16 +1,17 @@
from .batch_norm_handler import BatchNormModuleHandler
from .conv_handler import ConvFunctionHandler, ConvModuleHandler
from .dot_handler import LinearFunctionHandler, LinearModuleHandler
from .layer_norm_handler import LayerNormModuleHandler
from .batch_norm_handler import BatchNormModuleHandler
from .conv_handler import ConvModuleHandler, ConvFunctionHandler
from .where_handler import WhereHandler
from .unary_elementwise_handler import UnaryElementwiseHandler
from .reshape_handler import ReshapeHandler
from .placeholder_handler import PlacehodlerHandler
from .output_handler import OuputHandler
from .normal_pooling_handler import NormPoolingHandler
from .output_handler import OuputHandler
from .placeholder_handler import PlacehodlerHandler
from .registry import operator_registry
from .reshape_handler import ReshapeHandler
from .unary_elementwise_handler import UnaryElementwiseHandler
from .where_handler import WhereHandler
__all__ = [
'LinearFunctionHandler', 'LinearModuleHandler', 'LayerNormModuleHandler', 'BatchNormModuleHandler',
'ConvModuleHandler', 'ConvFunctionHandler', 'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler',
'OuputHandler', 'WhereHandler', 'NormPoolingHandler'
'OuputHandler', 'WhereHandler', 'NormPoolingHandler', 'operator_registry'
]


@@ -1,10 +1,11 @@
from typing import Dict, List
import torch
import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import OperationData, OperationDataType
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
from .node_handler import ModuleHandler
from ..strategy import BatchNormStrategyGenerator, StrategyGenerator
from typing import List, Dict
from .registry import operator_registry
from .strategy import BatchNormStrategyGenerator, StrategyGenerator
__all__ = ['BatchNormModuleHandler']


@@ -1,12 +1,14 @@
from typing import Dict, List
import torch
import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
from ..strategy import ConvStrategyGenerator, StrategyGenerator
from typing import List, Dict
from .registry import operator_registry
__all__ = ['LinearModuleHandler', 'LinearFunctionHandler']
from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy)
from .node_handler import ModuleHandler, NodeHandler
from .registry import operator_registry
from .strategy import ConvStrategyGenerator, StrategyGenerator
__all__ = ['ConvModuleHandler', 'ConvFunctionHandler']
@operator_registry.register(torch.nn.Conv1d)


@@ -1,13 +1,16 @@
from copy import deepcopy
from typing import Dict, List, Union
import torch
import torch.nn.functional as F
from colossalai.auto_parallel.tensor_shard.utils import (switch_partition_dim, update_partition_dim)
from colossalai.tensor.sharding_spec import ShardingException
from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy)
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
from ..strategy import LinearProjectionStrategyGenerator, StrategyGenerator, BatchedMatMulStrategyGenerator
from typing import List, Dict, Union
from .registry import operator_registry
from copy import deepcopy
from .strategy import (BatchedMatMulStrategyGenerator, LinearProjectionStrategyGenerator, StrategyGenerator)
from .utils import switch_partition_dim, update_partition_dim
__all__ = ['LinearModuleHandler', 'LinearFunctionHandler', 'BMMFunctionHandler']


@@ -1,10 +1,12 @@
import torch
from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector
from ..strategy import TensorStrategyGenerator, TensorTupleStrategyGenerator, StrategyGenerator
from typing import List, Dict
from .registry import operator_registry
import operator
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import (StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator)
__all__ = ['GetItemHandler']


@@ -1,9 +1,11 @@
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import ModuleHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
from ..strategy import LayerNormGenerator, StrategyGenerator
from typing import List, Dict
from .registry import operator_registry
from .strategy import LayerNormGenerator, StrategyGenerator
__all__ = ['LayerNormModuleHandler']


@@ -1,11 +1,14 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Union
from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, ShardingStrategy, StrategiesVector,
TrainCycleItem)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from typing import Dict, List, Union
from ..sharding_strategy import ShardingStrategy, StrategiesVector, OperationData, TrainCycleItem
from .strategy import StrategyGenerator
from ..strategy import StrategyGenerator
from .._utils import generate_resharding_costs
class NodeHandler(ABC):


@@ -1,10 +1,11 @@
from typing import Dict, List
import torch
import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import OperationData, OperationDataType
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
from .node_handler import ModuleHandler
from ..strategy import NormalPoolStrategyGenerator, StrategyGenerator
from typing import List, Dict
from .registry import operator_registry
from .strategy import NormalPoolStrategyGenerator, StrategyGenerator
__all__ = ['NormPoolingHandler']


@@ -1,10 +1,10 @@
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector
from .strategy import OutputGenerator, StrategyGenerator
from colossalai.auto_parallel.solver.strategy import StrategyGenerator
from colossalai.auto_parallel.solver.strategy.output_generator import OutputGenerator
from typing import List, Dict
from .registry import operator_registry
__all__ = ['OuputHandler']


@@ -1,10 +1,8 @@
import torch
from typing import Dict, List
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
from .strategy import PlaceholderGenerator, StrategyGenerator
from colossalai.auto_parallel.solver.strategy import StrategyGenerator
from colossalai.auto_parallel.solver.strategy.placeholder_generator import PlaceholderGenerator
from typing import List, Dict
from .registry import operator_registry
__all__ = ['PlacehodlerHandler']


@@ -1,10 +1,11 @@
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector
from ..strategy import ReshapeGenerator, StrategyGenerator
from typing import List, Dict
from .registry import operator_registry
import operator
from .strategy import ReshapeGenerator, StrategyGenerator
__all__ = ['ReshapeHandler']


@@ -1,15 +1,16 @@
from .strategy_generator import StrategyGenerator
from .matmul_strategy_generator import DotProductStrategyGenerator, MatVecStrategyGenerator, LinearProjectionStrategyGenerator, BatchedMatMulStrategyGenerator
from .conv_strategy_generator import ConvStrategyGenerator
from .batch_norm_generator import BatchNormStrategyGenerator
from .unary_elementwise_generator import UnaryElementwiseGenerator
from .conv_strategy_generator import ConvStrategyGenerator
from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator
from .getitem_generator import (GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator)
from .layer_norm_generator import LayerNormGenerator
from .where_generator import WhereGenerator
from .reshape_generator import ReshapeGenerator
from .matmul_strategy_generator import (BatchedMatMulStrategyGenerator, DotProductStrategyGenerator,
LinearProjectionStrategyGenerator, MatVecStrategyGenerator)
from .normal_pooling_generator import NormalPoolStrategyGenerator
from .placeholder_generator import PlaceholderGenerator
from .output_generator import OutputGenerator
from .placeholder_generator import PlaceholderGenerator
from .reshape_generator import ReshapeGenerator
from .strategy_generator import StrategyGenerator
from .unary_elementwise_generator import UnaryElementwiseGenerator
from .where_generator import WhereGenerator
__all__ = [
'StrategyGenerator', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', 'LinearProjectionStrategyGenerator',


@@ -1,11 +1,11 @@
import copy
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
from typing import List
from .._utils import exception_handler
import copy
__all__ = ['BatchNormStrategyGenerator']


@@ -1,12 +1,14 @@
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
from typing import List
from .._utils import exception_handler
import warnings
import copy
import operator
import warnings
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import exception_handler
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
class ConvStrategyGenerator(StrategyGenerator):


@@ -1,12 +1,10 @@
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import FollowingStrategyGenerator
from typing import List
from .._utils import exception_handler
import copy
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import FollowingStrategyGenerator
__all__ = ['GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator']


@@ -1,11 +1,13 @@
import copy
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
from typing import List
from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
import copy
__all__ = ['LayerNormGenerator']


@@ -1,11 +1,12 @@
from audioop import bias
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
class MatMulStrategyGenerator(StrategyGenerator):
"""


@@ -1,11 +1,13 @@
import copy
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
from typing import List
from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
import copy
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding)
from .strategy_generator import StrategyGenerator
class NormalPoolStrategyGenerator(StrategyGenerator):


@@ -1,11 +1,6 @@
import operator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import OutputStrategyGenerator
from typing import List
from .._utils import exception_handler
import copy
__all__ = ['OutputGenerator']
@@ -46,7 +41,7 @@ class OutputGenerator(OutputStrategyGenerator):
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = f'Replica Output'
name = 'Replica Output'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,


@@ -1,11 +1,6 @@
import operator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
from typing import List
from .._utils import exception_handler
import copy
__all__ = ['PlaceholderGenerator']
@@ -47,7 +42,7 @@ class PlaceholderGenerator(StrategyGenerator):
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = f'Replica Placeholder'
name = 'Replica Placeholder'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,


@@ -1,11 +1,10 @@
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import FollowingStrategyGenerator
from typing import List
import copy
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import FollowingStrategyGenerator
__all__ = ['ReshapeGenerator']


@@ -1,15 +1,16 @@
import operator
import torch
from colossalai.tensor.sharding_spec import ShardingSpec
from functools import reduce
from abc import ABC, abstractmethod
from functools import reduce
from typing import Any, Dict, List, Union
import torch
from torch.fx import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, ShardingStrategy,
TrainCycleItem)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from typing import Dict, List, Union, Any
from ..sharding_strategy import OperationData, ShardingStrategy, TrainCycleItem, OperationDataType
from torch.fx import Node
import copy
class StrategyGenerator(ABC):


@@ -1,12 +1,9 @@
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import FollowingStrategyGenerator
from typing import List
from .._utils import exception_handler
import copy
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from .strategy_generator import FollowingStrategyGenerator
__all__ = ['UnaryElementwiseGenerator']


@@ -1,12 +1,11 @@
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator, FollowingStrategyGenerator
from typing import List
from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
import copy
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding)
from .strategy_generator import StrategyGenerator
__all__ = ['WhereGenerator']


@@ -1,10 +1,11 @@
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector
from ..strategy import UnaryElementwiseGenerator, StrategyGenerator
from typing import List, Dict
from .registry import operator_registry
import operator
from .strategy import StrategyGenerator, UnaryElementwiseGenerator
__all__ = ['UnaryElementwiseHandler']


@@ -1,12 +1,14 @@
import torch
from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector
from ..strategy import WhereGenerator, StrategyGenerator
from .broadcast import recover_sharding_spec_for_broadcast_shape
from typing import List, Dict
from .registry import operator_registry
import operator
import copy
import operator
from typing import Dict, List
import torch
from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy, StrategiesVector)
from ..utils import recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, WhereGenerator
__all__ = ['WhereHandler']


@@ -1,17 +1,14 @@
from copy import deepcopy
from dataclasses import dataclass
from abc import ABC, abstractmethod
from enum import Enum
import operator
from typing import Any, Dict, List, Tuple, Union
import torch
from functools import reduce
from colossalai.device.device_mesh import DeviceMesh
import torch
from colossalai.tensor.shape_consistency import CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
from typing import Dict, List, Union, Tuple, Any
from torch.fx.node import Node
from .constants import *
from .constants import (BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_MODULE_OP, RESHAPE_FUNC_OP)
__all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector']
@@ -75,6 +72,11 @@ class TrainCycleItem:
@dataclass
class MemoryCost:
"""
MemoryCost is a dataclass which stores the memory usage in the program.
Args:
activation (int): the memory cost incurred by the activations in bytes.
parameter (int): the memory cost incurred by the module parameter in bytes.
"""
activation: int = 0
parameter: int = 0
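
The new docstring spells out what the two MemoryCost fields mean. As a point of reference, a minimal sketch of filling them in for a single forward pass; the tensor shapes and byte counts are illustrative and not taken from this commit:

from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost

# a float32 activation of shape (64, 1024) and a 1024 x 1024 float32 weight
activation_bytes = 64 * 1024 * 4
parameter_bytes = 1024 * 1024 * 4
fwd_memory_cost = MemoryCost(activation=activation_bytes, parameter=parameter_bytes)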


@@ -0,0 +1,7 @@
from .cost_graph import CostGraph
from .graph_analysis import GraphAnalyser
from .options import SolverOptions
from .solver import Solver
from .strategies_constructor import StrategiesConstructor
__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph', 'SolverOptions']
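
With this commit the solver components are re-exported from colossalai.auto_parallel.tensor_shard.solver rather than the colossalai.auto_parallel.solver package whose __init__ is removed at the top of this diff, so downstream code picks them up from the new location. A sketch based on the updated test imports in this diff:

# new location introduced by this commit
from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions,
                                                          StrategiesConstructor)

# previous location, removed by this commit
# from colossalai.auto_parallel.solver import CostGraph, GraphAnalyser, Solver, SolverOptions, StrategiesConstructor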


@@ -1,7 +1,4 @@
from typing import List
from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
import math
from torch.fx.node import Node
from colossalai.auto_parallel.solver.constants import INFINITY_COST
class CostGraph:


@@ -1,9 +1,10 @@
from dataclasses import dataclass
from torch.fx.node import Node
from typing import List
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from collections import OrderedDict as ODict
from torch.fx.node import Node
from typing import List, OrderedDict, Union, Any
from colossalai.fx.passes.utils import get_node_module
__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser']


@@ -1,18 +1,21 @@
import warnings
import time
import numpy as np
import multiprocessing
from torch.fx.node import Node
from torch.fx.graph import Graph
import time
import warnings
from . import GraphAnalyser
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from typing import Dict
from .constants import INFINITY_COST
import numpy as np
from torch.fx.graph import Graph
from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
from .cost_graph import CostGraph
from .graph_analysis import GraphAnalyser
from .strategies_constructor import StrategiesConstructor
try:
import pulp
from pulp import LpVariable, LpProblem, LpMinimize, lpSum, lpDot, LpStatus
from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum
except:
warnings.warn(f'please install the pulp')
@@ -21,454 +24,6 @@ __all___ = ['Solver']
class Solver:
def __init__(self,
graph: Graph,
strategies_constructor: StrategiesConstructor,
cost_graph: CostGraph,
graph_analyser: GraphAnalyser,
memory_budget: float = -1.0,
solution_numbers: int = 1,
memory_increasing_coefficient: float = 1.3):
'''
Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph.
Argument:
graph: The computing graph to be optimized.
strategies_constructor: It will provide all the possible strategies for each node in the computing graph.
cost_graph: A graph data structure to simplify the edge cost graph.
graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints.
memory_budget: Memory constraint for the solution.
solution_numbers: If solution_numbers is larger than one, the solver will use a series of solutions based on different memory budgets.
memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget.
'''
self.graph = graph
self.strategies_constructor = strategies_constructor
self.cost_graph = cost_graph
self.graph_analyser = graph_analyser
self.leaf_strategies = self.strategies_constructor.leaf_strategies
self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
self.strategy_map = self.strategies_constructor.strategy_map
self.memory_budget = memory_budget
self.solution_numbers = solution_numbers
if self.solution_numbers > 1:
self.memory_increasing_coefficient = memory_increasing_coefficient
else:
self.memory_increasing_coefficient = 1
self.liveness_list = self.graph_analyser.liveness_analysis()
self.node_index_dict = self._generate_node_index_dict()
# The last solution vector of auto sharding.
self.last_s_val = None
# The last objective value of the best ILP solution.
self.last_objective = None
def _recover_merged_node_strategy(self):
'''
During cost graph construction, some nodes, such as unary element-wise nodes or ReshapeOp, were merged into the previous node.
Therefore, the index of those strategies is copied from the previous node. This method is used to recover the strategy index of those merged
nodes.
'''
for node_index, node in enumerate(self.nodes):
if node.strategies_vector.check_merge():
# the merged node has only one input, and its strategies follow the input sharding strategy
input_strategies_vector = node.args[0].strategies_vector
input_best_strategy_index = self.last_s_val[node_index - 1]
input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec
for strategy_index, strategy in enumerate(node.strategies_vector):
if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence:
self.last_s_val[node_index] = strategy_index
break
def _generate_node_index_dict(self) -> Dict[Node, int]:
node_index_dict = {}
for index, strategies_vector in enumerate(self.leaf_strategies):
node_index_dict[strategies_vector.node] = index
return node_index_dict
def _prepare_data_for_solver(self):
'''
Extract information from components for solver.
'''
node_nums = len(self.leaf_strategies)
memory_budget = self.memory_budget
# prepare strategies_len
strategies_len = []
for node in self.nodes:
strategies_len.append(self.cost_graph.node_lens[node])
strategies_len = np.array(strategies_len)
# prepare following_nodes
following_nodes = self.cost_graph.following_dict
index_following_nodes = {}
for src, target in following_nodes.items():
src_index = self.node_index_dict[src]
target_index = self.node_index_dict[target]
index_following_nodes[src_index] = target_index
following_nodes = index_following_nodes
for index in range(node_nums):
if index not in following_nodes:
following_nodes[index] = -1
# prepare edge_pairs and resharding costs
edge_pairs = []
resharding_costs = []
for pairs, edge_cost in self.cost_graph.edge_costs.items():
src_node = pairs[0]
dst_node = pairs[1]
src_node_index = self.node_index_dict[src_node]
dst_node_index = self.node_index_dict[dst_node]
edge_pairs.append(src_node_index)
edge_pairs.append(dst_node_index)
for i in range(strategies_len[src_node_index]):
for j in range(strategies_len[dst_node_index]):
resharding_costs.append(edge_cost[(i, j)])
edge_pairs = np.array(edge_pairs)
resharding_costs = np.array(resharding_costs)
# prepare liveness_set
liveness_set = self.liveness_list
# omit alias_set now
alias_set = None
alias_convert_costs = None
# prepare compute_costs, communication_costs and memory_costs
compute_costs = []
communication_costs = []
memory_costs = []
extra_node_costs = self.cost_graph.extra_node_costs
for strategies_vector in self.leaf_strategies:
node = strategies_vector.node
for index, strategy in enumerate(strategies_vector):
compute_costs.append(strategy.compute_cost)
# node in extra_node_costs means it has some extra communication
# cost from node merging, so we need to add those extra communication
# cost into
if node in extra_node_costs:
origin_communication_cost = strategy.communication_cost
extra_node_cost = extra_node_costs[node][index]
communication_cost = origin_communication_cost + extra_node_cost
communication_costs.append(communication_cost)
else:
communication_costs.append(strategy.communication_cost)
# temporarily we just consider the forward memory cost
memory_cost = strategy.memory_cost
if isinstance(memory_cost, tuple):
memory_costs.append(memory_cost[0])
else:
memory_costs.append(memory_cost)
compute_costs = np.array(compute_costs)
communication_costs = np.array(communication_costs)
memory_costs = np.array(memory_costs)
# omit initial value for nodes
s_init_np = None
return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np
def _call_solver_serialized_args(self,
node_nums,
memory_budget,
strategies_len,
following_nodes,
edge_pairs,
alias_set,
liveness_set,
compute_costs,
communication_costs,
memory_costs,
resharding_costs,
alias_convert_costs,
s_init_np=None):
"""
Call the solver with serialized arguments.
"""
tic = time.time()
for x in [strategies_len, edge_pairs, compute_costs, communication_costs, memory_costs, resharding_costs]:
assert isinstance(x, np.ndarray)
assert len(strategies_len) == node_nums, "strategies_len"
def get_non_zero_index(binary_vector):
"""
Get the index of non-zero item in a vector.
"""
ct = 0
ret = None
for i, elem in enumerate(binary_vector):
if pulp.value(elem):
ret = i
ct += 1
assert ct == 1
return ret
# 0. Unpack flatten numpy arrays
s_follow = following_nodes
E = edge_pairs.reshape((-1, 2)) # noqa
r = []
pt = 0
edge_set = set()
for (i, j) in E:
prod_length = strategies_len[i] * strategies_len[j]
if (i, j) in edge_set:
raise ValueError(f"Duplicated edges: {(i, j)}")
edge_set.add((i, j))
r.append(resharding_costs[pt:pt + prod_length])
pt += prod_length
assert pt == len(resharding_costs)
######################
# omit alias set now #
######################
# A = alias_set.reshape((-1, 2)) # noqa
# for (i, j) in A:
# prod_length = strategies_len[i] * strategies_len[j]
# v.append(alias_convert_costs[pt:pt + prod_length])
# pt += prod_length
# assert pt == len(alias_convert_costs)
# L = [] # noqa
# pt = node_nums
# for i in range(node_nums):
# length = liveness_set[i]
# L.append(liveness_set[pt:pt + length])
# pt += length
# assert pt == len(liveness_set)
v = []
pt = 0
c = []
d = []
m = []
pt = 0
for i in range(node_nums):
length = strategies_len[i]
c.append(compute_costs[pt:pt + length])
d.append(communication_costs[pt:pt + length])
m.append(memory_costs[pt:pt + length])
pt += length
assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}"
assert pt == len(memory_costs), f"{pt} == {len(memory_costs)}"
# 1. Create variables
#############################
# create variables for node #
#############################
s = []
num_nodes = 0
reverse_follow_backpatch = []
for i in range(node_nums):
if s_follow[i] < 0:
if strategies_len[i] == 1:
s.append([1])
else:
num_nodes += 1
s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
else:
if s_follow[i] < len(s):
s.append(s[s_follow[i]])
else:
s.append(None)
reverse_follow_backpatch.append(i)
for i in reverse_follow_backpatch:
s[i] = s[s_follow[i]]
#############################
# create variables for edge #
#############################
e = []
num_edges = 0
for (idx, (i, j)) in enumerate(E):
if len(s[i]) == 1:
e.append(s[j])
elif len(s[j]) == 1:
e.append(s[i])
else:
num_edges += 1
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
assert len(e[idx]) == len(r[idx])
for element in s:
assert len(element) > 0
# 2. Set initial value
######################################
# set a initial value for warm start #
######################################
if s_init_np is not None:
s_init = s_init_np.reshape((-1, 3))
for (idx, value, fix) in s_init:
for i in range(len(s[idx])):
s[idx][i].setInitialValue(i == value)
if fix:
s[idx][i].fixValue()
# 3. Objective
prob = LpProblem("myProblem", LpMinimize)
###################################################################
# computing the node cost(computing cost and communication cost) #
###################################################################
obj = 0
for i in range(node_nums):
assert len(s[i]) == len(c[i])
assert len(s[i]) == len(d[i])
obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
#############################################
# computing the edge cost(resharding cost) #
#############################################
for i in range(len(E)):
assert len(e[i]) == len(r[i])
obj += lpDot(e[i], r[i])
prob += obj
# 4. Constraints
# (a). specified by `cat="Binary"`
# (b)
#################################################
# make sure each node only choose one strategy #
#################################################
for i in range(node_nums):
if s_follow[i] < 0:
prob += lpSum(s[i]) == 1
# (c)
#################################################
# compute memory consumption with liveness set #
#################################################
if memory_budget > 0:
for liveness_stage in liveness_set:
mem = 0
for live_variable in liveness_stage.unique_live_vars:
node_index = self.node_index_dict[live_variable.node]
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
prob += mem <= memory_budget
# (d). specified by `cat="Binary"`
for (idx, (i, j)) in enumerate(E):
if strategies_len[i] == 1 or strategies_len[j] == 1:
continue
# (e)
prob += lpSum(e[idx]) == 1
# (f)
for row in range(len(s[i])):
C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row]
# (g)
for col in range(len(s[j])):
R = len(s[i]) # noqa
C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col]
# (h)
######################
# omit alias set now #
######################
# alias_set = set()
# for (idx, (i, j)) in enumerate(A):
# R = len(s[i]) # noqa
# C = len(s[j]) # noqa
# if (i, j) in alias_set:
# raise ValueError(f"Duplicated edges: {(i, j)}")
# alias_set.add((i, j))
# alias_set.add((j, i))
# for row in range(len(s[i])):
# for col in range(len(s[j])):
# if v[idx][row * C + col] > 0.5:
# prob += s[i][row] + s[j][col] <= 1
verbose = True
msg = verbose
time_limit = 600
assert "COIN_CMD" in pulp.listSolvers(
onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'")
solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count())
# solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit)
prob.solve(solver)
status = prob.status
objective = pulp.value(prob.objective)
objective = float(objective) if objective is not None else -1.0
if verbose:
print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t"
f"Time: {time.time() - tic}")
print(f"#nodes: {num_nodes}, #edges: {num_edges}")
if prob.status in [pulp.LpStatusInfeasible]:
raise RuntimeError("Cannot run the function under the given memory budget. "
"Please increase the memory budget.")
# Get and check results
s_val = np.full((node_nums,), -1, dtype=np.int32)
for i in range(node_nums):
s_val[i] = get_non_zero_index(s[i])
e_val = np.full((len(E),), -1, dtype=np.int32)
for (idx, (i, j)) in enumerate(E):
e_val[idx] = get_non_zero_index(e[idx])
i_spec_index = e_val[idx] // len(s[j])
j_spec_index = e_val[idx] % len(s[j])
assert i_spec_index == s_val[i], f"e_val[{i}][{j}]"
assert j_spec_index == s_val[j], f"e_val[{i}][{j}]"
if verbose and r[idx][e_val[idx]] > 0:
print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")
self.last_s_val = list(s_val)
self._recover_merged_node_strategy()
self.last_objective = objective
if objective > INFINITY_COST:
warnings.warn("Detect unexpected behaviors in the auto-sharding pass.")
return self.last_s_val, e_val, self.last_objective, status
def call_solver_serialized_args(self):
"""
Call the solver with serialized arguments and handle python errors. Additionally,
we could give a series of solutions with different memory budgets.
"""
if self.solution_numbers == 1:
args = self._prepare_data_for_solver()
ret = self._call_solver_serialized_args(*args)
return ret
origin_memory_budget = self.memory_budget
memory_budget_list = [
origin_memory_budget * self.memory_increasing_coefficient**i for i in range(self.solution_numbers)
]
ret_list = []
for memory_budget in memory_budget_list:
self.memory_budget = memory_budget
args = self._prepare_data_for_solver()
ret = self._call_solver_serialized_args(*args)
ret_list.append(ret)
return ret_list
class Solver_V2:
def __init__(self,
graph: Graph,
strategies_constructor: StrategiesConstructor,
@@ -480,7 +35,6 @@ class Solver_V2:
memory_increasing_coefficient: float = 1.3):
'''
Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph.
Argument:
graph: The computing graph to be optimized.
strategies_constructor: It will provide all the possible strategies for each node in the computing graph.
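
The retained docstring lists the components the solver consumes. Below is a minimal sketch of wiring them together under the new package layout; the Solver constructor and call_solver_serialized_args follow the code shown in this diff, while the DeviceMesh setup, the tracing call, and the constructor arguments of StrategiesConstructor, CostGraph and GraphAnalyser are assumptions for illustration only:

import torch
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions,
                                                          StrategiesConstructor)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
# a 2x2 logical device mesh; the DeviceMesh arguments are assumed for illustration
device_mesh = DeviceMesh(torch.arange(4), (2, 2), init_process_group=False)

# trace the model into an FX graph that the node handlers can annotate (assumed tracing style)
graph = ColoTracer().trace(model, meta_args={'input': torch.rand(4, 16).to('meta')})
gm = ColoGraphModule(model, graph)

strategies_constructor = StrategiesConstructor(graph, device_mesh, SolverOptions())  # assumed signature
strategies_constructor.build_strategies_and_cost()                                   # assumed builder method
cost_graph = CostGraph(strategies_constructor.leaf_strategies)                       # assumed signature
graph_analyser = GraphAnalyser(gm)                                                   # assumed signature

solver = Solver(graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1.0)
strategy_indices, edge_choices, objective, status = solver.call_solver_serialized_args()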


@@ -1,22 +1,19 @@
import math
import operator
from copy import deepcopy
from typing import Dict, List
import torch
from torch.fx import Graph, Node
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.tensor_shard.node_handler import (OuputHandler, PlacehodlerHandler, operator_registry)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.auto_parallel.tensor_shard.utils import (generate_resharding_costs, generate_sharding_spec)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.auto_parallel.solver.node_handler.registry import operator_registry
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.solver.node_handler.placeholder_handler import PlacehodlerHandler
from colossalai.auto_parallel.solver.node_handler.output_handler import OuputHandler
from .options import SolverOptions
from . import ShardingStrategy, StrategiesVector
from .node_handler import *
from .constants import *
from copy import deepcopy
import math
import torch
import operator
from typing import Dict, List
from ._utils import generate_sharding_spec, generate_resharding_costs
import builtins
__all__ = ['StrategiesConstructor']


@@ -0,0 +1,12 @@
from .broadcast import (BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape)
from .factory import generate_resharding_costs, generate_sharding_spec
from .misc import exception_handler
from .sharding import (enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding, generate_sharding_size,
switch_partition_dim, update_partition_dim)
__all__ = [
'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
'generate_resharding_costs', 'generate_sharding_spec', 'exception_handler', 'switch_partition_dim',
'update_partition_dim', 'enumerate_all_possible_1d_sharding', 'enumerate_all_possible_2d_sharding',
'generate_sharding_size'
]


@@ -1,14 +1,17 @@
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
import operator
import torch
from torch.fx.node import Node
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from typing import Union, Dict, List, Optional
import warnings
from functools import reduce
import functools
from typing import Dict, List, Optional, Union
import operator
from .constants import INFINITY_COST
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from torch.fx.node import Node
from ..constants import INFINITY_COST
__all__ = ['generate_sharding_spec', 'generate_resharding_costs']
def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
@@ -85,55 +88,3 @@ def generate_resharding_costs(nodes: List[Node],
resharding_cost = INFINITY_COST
resharding_costs[input_node].append(resharding_cost)
return resharding_costs
def exception_handler(func):
"""
A function wrapper which executes the function with a specified seed.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
rst = func(*args, **kwargs)
return rst
except AssertionError as e:
warnings.warn(f'{e}')
return wrapper
def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size):
dim_partition_list = []
# enumerate all the 2D sharding cases
for i in range(dim_size):
for j in range(i + 1, dim_size):
dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
dim_partition_list.append(dim_partition_dict_0)
dim_partition_list.append(dim_partition_dict_1)
for i in range(dim_size):
dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
dim_partition_list.append(dim_partition_dict_flatten)
return dim_partition_list
def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size):
dim_partition_list = []
# enumerate all the 1D sharding cases
for i in range(dim_size):
dim_partition_dict_0 = {i: [mesh_dim_0]}
dim_partition_list.append(dim_partition_dict_0)
return dim_partition_list
def generate_sharding_size(dim_partition_dict, device_mesh):
total_sharding_size = 1
for mesh_dim_list in dim_partition_dict.values():
mesh_dim_sharding_size = [device_mesh.shape[mesh_dim] for mesh_dim in mesh_dim_list]
sharding_size = reduce(operator.mul, mesh_dim_sharding_size)
total_sharding_size *= sharding_size
return total_sharding_size


@@ -0,0 +1,26 @@
import functools
import warnings
__all__ = ['exception_handler']
def exception_handler(func):
"""
A function wrapper to handle the AssertionError in the function.
Usage:
# mute the assertion error in the function
@exception_handler
def do_something():
...
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
rst = func(*args, **kwargs)
return rst
except AssertionError as e:
warnings.warn(f'{e}')
return wrapper
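
The strategy generators import this decorator from colossalai.auto_parallel.tensor_shard.utils (see the updated imports earlier in the diff), so a failed assertion inside a generator surfaces as a warning instead of aborting strategy construction. A small self-contained sketch of the behaviour; register_dim_strategy is a hypothetical function, not part of the diff:

from colossalai.auto_parallel.tensor_shard.utils import exception_handler

@exception_handler
def register_dim_strategy(shard_dim: int) -> str:
    # generators typically assert that a dimension is shardable before registering a strategy
    assert shard_dim >= 0, f'cannot shard negative dimension {shard_dim}'
    return f'strategy for dim {shard_dim}'

register_dim_strategy(1)     # returns 'strategy for dim 1'
register_dim_strategy(-1)    # emits a warning and returns None instead of raising AssertionError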


@@ -1,7 +1,16 @@
import torch
import operator
from typing import Dict
from colossalai.tensor.sharding_spec import ShardingSpec
from copy import deepcopy
from functools import reduce
from typing import Dict
import torch
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = [
'switch_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
]
def switch_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
@@ -66,3 +75,39 @@ def update_partition_dim(sharding_spec: ShardingSpec,
entire_shape=physical_shape,
dim_partition_dict=new_dim_partition_dict)
return current_sharding_spec
def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size):
dim_partition_list = []
# enumerate all the 2D sharding cases
for i in range(dim_size):
for j in range(i + 1, dim_size):
dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
dim_partition_list.append(dim_partition_dict_0)
dim_partition_list.append(dim_partition_dict_1)
for i in range(dim_size):
dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
dim_partition_list.append(dim_partition_dict_flatten)
return dim_partition_list
def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size):
dim_partition_list = []
# enumerate all the 1D sharding cases
for i in range(dim_size):
dim_partition_dict_0 = {i: [mesh_dim_0]}
dim_partition_list.append(dim_partition_dict_0)
return dim_partition_list
def generate_sharding_size(dim_partition_dict, device_mesh):
total_sharding_size = 1
for mesh_dim_list in dim_partition_dict.values():
mesh_dim_sharding_size = [device_mesh.shape[mesh_dim] for mesh_dim in mesh_dim_list]
sharding_size = reduce(operator.mul, mesh_dim_sharding_size)
total_sharding_size *= sharding_size
return total_sharding_size
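
To make the enumeration helpers concrete, here is what they return for a two-dimensional tensor on a 2x2 mesh; this is a hand-worked example, not part of the diff, and the mesh stub only mimics the .shape attribute that generate_sharding_size reads:

from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
                                                         enumerate_all_possible_2d_sharding,
                                                         generate_sharding_size)

# every way to shard a 2-D tensor over mesh dims 0 and 1
print(enumerate_all_possible_2d_sharding(0, 1, dim_size=2))
# -> [{0: [0], 1: [1]}, {0: [1], 1: [0]}, {0: [0, 1]}, {1: [0, 1]}]
print(enumerate_all_possible_1d_sharding(0, dim_size=2))
# -> [{0: [0]}, {1: [0]}]

# generate_sharding_size only reads device_mesh.shape, so a stand-in object suffices here
class _MeshStub:
    shape = (2, 2)

print(generate_sharding_size({0: [0, 1]}, _MeshStub()))   # 2 * 2 = 4 shards for that dimension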


@@ -1,7 +1,9 @@
import torch
from colossalai.auto_parallel.solver.node_handler.broadcast import is_broadcastable, get_broadcast_shape, recover_sharding_spec_for_broadcast_shape
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.tensor_shard.utils import (get_broadcast_shape, is_broadcastable,
recover_sharding_spec_for_broadcast_shape)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
def test_is_broadcastable():


@@ -1,7 +1,8 @@
import torch.nn as nn
import torch
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.tensor_shard.solver import GraphAnalyser
from colossalai.fx import ColoGraphModule, ColoTracer
class LinearModel(nn.Module):


@@ -1,10 +1,12 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.batch_norm_handler import BatchNormModuleHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.auto_parallel.tensor_shard.node_handler.batch_norm_handler import \
BatchNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
def test_bn_module_handler():


@@ -1,10 +1,12 @@
import pytest
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.dot_handler import BMMFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.auto_parallel.tensor_shard.node_handler.dot_handler import \
BMMFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag


@@ -1,10 +1,11 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvModuleHandler, ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import (ConvFunctionHandler, ConvModuleHandler)
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
def test_conv_module_handler():


@@ -1,11 +1,14 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.getitem_handler import GetItemHandler
from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import \
GetItemHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
class GetItemModel(nn.Module):


@@ -1,10 +1,12 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.layer_norm_handler import LayerNormModuleHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import \
LayerNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
def test_ln_module_handler():


@@ -1,10 +1,12 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.dot_handler import LinearModuleHandler, LinearFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector, ShardingStrategy
from colossalai.auto_parallel.tensor_shard.node_handler.dot_handler import (LinearFunctionHandler, LinearModuleHandler)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, ShardingStrategy,
StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.tensor.sharding_spec import ShardingSpec

View File

@@ -1,11 +1,13 @@
-from colossalai.fx.tracer.meta_patch.patched_module import linear
+import pytest
 import torch
 import torch.nn as nn
-from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.node_handler.normal_pooling_handler import NormPoolingHandler
-from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
+from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import \
+    NormPoolingHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
 from colossalai.device.device_mesh import DeviceMesh
-import pytest
+from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.fx.tracer.meta_patch.patched_module import linear
 from colossalai.testing.pytest_wrapper import run_on_environment_flag

View File

@@ -1,9 +1,11 @@
 import torch
 import torch.nn as nn
-from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.node_handler.output_handler import OuputHandler
-from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
+from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import \
+    OuputHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx import ColoGraphModule, ColoTracer
 class OutputModel(nn.Module):

View File

@@ -1,9 +1,11 @@
 import torch
 import torch.nn as nn
-from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.node_handler.placeholder_handler import PlacehodlerHandler
-from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
+from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import \
+    PlacehodlerHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx import ColoGraphModule, ColoTracer
 class PlaceholderModel(nn.Module):

View File

@@ -1,10 +1,13 @@
 import torch
 import torch.nn as nn
-from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler
-from colossalai.auto_parallel.solver.node_handler.reshape_handler import ReshapeHandler
-from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
+from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
+    ConvFunctionHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import \
+    ReshapeHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx import ColoGraphModule, ColoTracer
 class ReshapeModel(nn.Module):

View File

@@ -1,11 +1,14 @@
-from colossalai.fx.tracer.meta_patch.patched_module import linear
 import torch
 import torch.nn as nn
-from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.node_handler.unary_elementwise_handler import UnaryElementwiseHandler
-from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler
-from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
+from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
+    ConvFunctionHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import \
+    UnaryElementwiseHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.fx.tracer.meta_patch.patched_module import linear
 class ReLuModel(nn.Module):

View File

@@ -1,10 +1,12 @@
-from colossalai.fx.tracer.meta_patch.patched_module import linear
 import torch
 import torch.nn as nn
-from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.node_handler.where_handler import WhereHandler
-from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
+from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import \
+    WhereHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.fx.tracer.meta_patch.patched_module import linear
 class ConvModel(nn.Module):

View File

@@ -1,24 +1,22 @@
 from functools import partial
 import pytest
 import torch
 import torch.multiprocessing as mp
-from torch.fx import GraphModule
 import torch.nn as nn
-import pytest
+from torch.fx import GraphModule
-from colossalai.initialize import launch
-from colossalai.utils import free_port
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.logging import disable_existing_loggers
-from colossalai.auto_parallel.solver.cost_graph import CostGraph
-from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
-from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
-from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions,
+                                                          StrategiesConstructor)
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import shape_consistency_pass, solution_annotatation_pass
-from colossalai.auto_parallel.solver.solver import Solver_V2
-from colossalai.auto_parallel.solver.options import SolverOptions
+from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (shape_consistency_pass,
+                                                                                solution_annotatation_pass)
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.utils import free_port
 class ConvModel(nn.Module):
@@ -61,7 +59,7 @@ def check_apply(rank, world_size, port):
     cost_graph = CostGraph(strategies_constructor.leaf_strategies)
     cost_graph.simplify_graph()
     graph_analyser = GraphAnalyser(gm)
-    solver = Solver_V2(gm.graph, strategies_constructor, cost_graph, graph_analyser)
+    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
     ret = solver.call_solver_serialized_args()
     solution = list(ret[0])
     device_mesh.process_groups_dict = device_mesh.create_process_groups_for_logical_mesh()
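
Beyond the import shuffle, this file carries the one user-visible rename in the tests: Solver_V2 becomes Solver, exported from the consolidated colossalai.auto_parallel.tensor_shard.solver package. Below is a hedged sketch of the full pipeline those imports feed; SolverOptions(fast=True), the StrategiesConstructor(graph, device_mesh, solver_options) signature and build_strategies_and_cost() are assumptions carried over from the surrounding test code rather than shown in this hunk.

import torch
import torch.nn as nn
from torch.fx import GraphModule

from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions,
                                                          StrategiesConstructor)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer


class ConvModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(4, 16, 3, padding=1)

    def forward(self, x):
        return self.conv(x)


# Trace the model symbolically on meta tensors.
model = ConvModel()
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={'x': torch.rand(4, 4, 8, 8).to('meta')})
gm = GraphModule(model, graph, model.__class__.__name__)

# Build per-node sharding strategies over a 2x2 logical device mesh.
device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2))
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()

# Simplify the cost graph, analyse liveness, then run the (renamed) ILP solver.
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
graph_analyser = GraphAnalyser(gm)

solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()
solution = list(ret[0])    # one chosen strategy index per node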

View File

@@ -1,20 +1,13 @@
 import torch
 from torch.fx import GraphModule
-import torch.nn as nn
-import pytest
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from torchvision.models import resnet50
+from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
+from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions,
+                                                          StrategiesConstructor)
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
-from colossalai.auto_parallel.solver.cost_graph import CostGraph
-from copy import deepcopy
-from colossalai.auto_parallel.solver.solver import Solver
-from torchvision.models import resnet34, resnet50
-from colossalai.auto_parallel.solver.constants import *
-from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
-from colossalai.auto_parallel.solver.options import SolverOptions
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
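
This last file replaces the scattered colossalai.auto_parallel.solver imports with the consolidated tensor_shard.constants and tensor_shard.solver entry points, and drops the now-unused pytest, deepcopy, nn and resnet34 imports. A short sketch of how a traced torchvision model meets the relocated BATCHNORM_MODULE_OP constant follows; the meta_args key 'x', the batch shape and the membership check are illustrative assumptions, not part of this diff.

import torch
from torch.fx import GraphModule
from torchvision.models import resnet50

from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
from colossalai.fx.tracer.tracer import ColoTracer

# Meta tensors keep the trace free of real allocations, so a large batch is cheap.
model = resnet50(num_classes=10)
tracer = ColoTracer()
graph = tracer.trace(root=model, meta_args={'x': torch.rand(128, 3, 224, 224).to('meta')})
gm = GraphModule(model, graph, model.__class__.__name__)

# BATCHNORM_MODULE_OP now lives in tensor_shard.constants; assuming it is the list of
# nn.BatchNorm classes the solver treats specially, it can be used to spot-check the trace.
bn_call_modules = [
    node for node in graph.nodes
    if node.op == 'call_module' and type(gm.get_submodule(node.target)) in BATCHNORM_MODULE_OP
]
print(f'traced {len(list(graph.nodes))} nodes, {len(bn_call_modules)} batch-norm modules')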