diff --git a/colossalai/auto_parallel/solver/__init__.py b/colossalai/auto_parallel/solver/__init__.py
index ec7817dfb..15f951b85 100644
--- a/colossalai/auto_parallel/solver/__init__.py
+++ b/colossalai/auto_parallel/solver/__init__.py
@@ -4,5 +4,9 @@ from .solver import Solver
 from .cost_graph import CostGraph
 from .strategies_constructor import StrategiesConstructor
 from .constants import *
+from .options import SolverOptions
 
-__all__ = ['StrategiesVector', 'ShardingStrategy', 'GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph']
+__all__ = [
+    'StrategiesVector', 'ShardingStrategy', 'GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph',
+    'SolverOptions'
+]
diff --git a/colossalai/auto_parallel/solver/options.py b/colossalai/auto_parallel/solver/options.py
new file mode 100644
index 000000000..2d34f5c64
--- /dev/null
+++ b/colossalai/auto_parallel/solver/options.py
@@ -0,0 +1,11 @@
+from dataclasses import dataclass
+
+__all__ = ['SolverOptions']
+
+
+@dataclass
+class SolverOptions:
+    """
+    SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
+    """
+    fast: bool = False
diff --git a/colossalai/auto_parallel/solver/strategies_constructor.py b/colossalai/auto_parallel/solver/strategies_constructor.py
index 546a30978..08867c591 100644
--- a/colossalai/auto_parallel/solver/strategies_constructor.py
+++ b/colossalai/auto_parallel/solver/strategies_constructor.py
@@ -1,5 +1,8 @@
 from torch.fx import Graph, Node
 from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from .options import SolverOptions
 from . import ShardingStrategy, StrategiesVector
 from .op_handler import *
 from .constants import *
@@ -11,9 +14,20 @@ from typing import Dict, List
 
 
 class StrategiesConstructor:
+    """
+    StrategiesConstructor is used to construct the parallelization plan for the model execution.
 
-    def __init__(self, graph, device_mesh, shape_consistency_manager, solver_options):
+    Args:
+        graph (Graph): a Graph object used for analysis and strategy generation.
+        device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
+        shape_consistency_manager (ShapeConsistencyManager): a ShapeConsistencyManager object to make sure the sharding specs are consistent.
+        solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching.
+    """
+
+    def __init__(self, graph: Graph, device_mesh: DeviceMesh, shape_consistency_manager: ShapeConsistencyManager,
+                 solver_options: SolverOptions):
         self.graph = graph
+        assert graph.owning_module is not None, 'The given graph is not associated with an owning_module'
         self.root_module = self.graph.owning_module
         self.nodes = list(graph.nodes)
         self.device_mesh = device_mesh
@@ -77,13 +91,13 @@ class StrategiesConstructor:
 
             strategies_vector = StrategiesVector(node)
             # placeholder node
             if node.op == 'placeholder':
-                # For placeholder nodes, if solver_options['fast_mode'] is True, we just let them in
+                # For placeholder nodes, if solver_options.fast is True, we just let them in
                 # fully replicate status, then strategies of following node will be treated equally due
                 # to replicate status has no resharding cost to other status. At the same time, the searching
                 # space is smaller than enumerating all the possible sharding spec for the placeholder node.
                # Otherwise, all the possible sharding spec for the placeholder node will be enumerated.
-                if self.solver_options['fast_mode']:
+                if self.solver_options.fast:
                     # create sharding strategy for placeholder
                     name = 'Replica Placeholder'
                     dim_partition_dict = {}
@@ -97,12 +111,12 @@ class StrategiesConstructor:
 
             # get_attr node
             if node.op == 'get_attr':
-                # Same as placeholder nodes, if solver_options['fast_mode'] is True, we just let them in
+                # Same as placeholder nodes, if solver_options.fast is True, we just let them in
                 # fully replicate status, then strategies of following node will be treated equally due
                 # to replicate status has no resharding cost to other status. At the same time, the searching
                 # space is smaller than enumerating all the possible sharding spec for the get_attr node.
                 # Otherwise, all the possible sharding spec for the get_attr node will be enumerated.
-                if self.solver_options['fast_mode']:
+                if self.solver_options.fast:
                     # create sharding strategy for get_attr
                     name = 'Replica Attribute'
                     dim_partition_dict = {}
@@ -382,7 +396,7 @@ class StrategiesConstructor:
 
             # output node
             if node.op == 'output':
-                if self.solver_options['fast_mode']:
+                if self.solver_options.fast:
                     # create sharding strategy for output
                     name = 'Replica Output'
                     input_nodes = strategies_vector.predecessor_nodes
diff --git a/tests/test_auto_parallel/test_cost_graph.py b/tests/test_auto_parallel/test_cost_graph.py
index 7d8232867..1bee5e35f 100644
--- a/tests/test_auto_parallel/test_cost_graph.py
+++ b/tests/test_auto_parallel/test_cost_graph.py
@@ -10,6 +10,7 @@ from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
 from colossalai.auto_parallel.solver.cost_graph import CostGraph
+from colossalai.auto_parallel.solver.options import SolverOptions
 from copy import deepcopy
@@ -52,7 +53,7 @@ def test_cost_graph():
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
 
-    solver_options = {'fast_mode': True}
+    solver_options = SolverOptions(fast=True)
     strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
 
     strategies_constructor.build_strategies_and_cost()
diff --git a/tests/test_auto_parallel/test_solver.py b/tests/test_auto_parallel/test_solver.py
index 56b1052a3..ce8d2ba09 100644
--- a/tests/test_auto_parallel/test_solver.py
+++ b/tests/test_auto_parallel/test_solver.py
@@ -11,6 +11,7 @@ from colossalai.auto_parallel.solver.cost_graph import CostGraph
 from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
 from copy import deepcopy
 from colossalai.auto_parallel.solver import Solver
+from colossalai.auto_parallel.solver.options import SolverOptions
 
 
 class ConvModel(nn.Module):
@@ -39,7 +40,6 @@ def test_solver():
     #     [[0, 1]
     #      [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-    entire_shape = torch.Size((4, 16, 64, 64))
     shape_consistency_manager = ShapeConsistencyManager()
 
     tracer = ColoTracer()
@@ -57,9 +57,8 @@ def test_solver():
     #     return relu
     graph = tracer.trace(root=model, meta_args=input_sample)
     gm = GraphModule(model, graph, model.__class__.__name__)
-    gm.recompile()
 
-    solver_options = {'fast_mode': True}
+    solver_options = SolverOptions(fast=True)
     strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager,
                                                   solver_options)
     strategies_constructor.build_strategies_and_cost()
diff --git a/tests/test_auto_parallel/test_strategies_constructor.py b/tests/test_auto_parallel/test_strategies_constructor.py
index 37769d3c6..955bf43dd 100644
--- a/tests/test_auto_parallel/test_strategies_constructor.py
+++ b/tests/test_auto_parallel/test_strategies_constructor.py
@@ -11,6 +11,7 @@ from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy,
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
+from colossalai.auto_parallel.solver.options import SolverOptions
 from copy import deepcopy
@@ -47,7 +48,7 @@ def test_strategies_constructor():
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
 
-    solver_options = {'fast_mode': True}
+    solver_options = SolverOptions(fast=True)
     strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
 
     assert strategies_constructor.leaf_strategies == []
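
Usage sketch (for reviewers): the snippet below shows how the pieces introduced in this patch fit together, mirroring the updated tests: a SolverOptions instance now replaces the old {'fast_mode': True} dict when constructing a StrategiesConstructor. The ConvModel body, the ColoTracer import path, and the input shape are illustrative assumptions borrowed from the surrounding test code, not part of this diff.

    import torch
    import torch.nn as nn
    from torch.fx import GraphModule

    from colossalai.fx.tracer.tracer import ColoTracer  # assumed tracer path, as used in the tests
    from colossalai.device.device_mesh import DeviceMesh
    from colossalai.tensor.shape_consistency import ShapeConsistencyManager
    from colossalai.auto_parallel.solver.options import SolverOptions
    from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor


    class ConvModel(nn.Module):
        """Hypothetical toy model, assumed similar to the one defined in the tests."""

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(16, 32, 3, padding=1)

        def forward(self, x):
            return torch.relu(self.conv(x))


    # 4 physical devices arranged as a 2x2 logical mesh:
    #     [[0, 1]
    #      [2, 3]]
    physical_mesh_id = torch.arange(0, 4)
    device_mesh = DeviceMesh(physical_mesh_id, (2, 2))
    shape_consistency_manager = ShapeConsistencyManager()

    # trace the model symbolically on meta tensors; the input shape is illustrative
    model = ConvModel()
    tracer = ColoTracer()
    input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
    graph = tracer.trace(root=model, meta_args=input_sample)
    # constructing the GraphModule sets graph.owning_module,
    # which the new assert in StrategiesConstructor.__init__ requires
    gm = GraphModule(model, graph, model.__class__.__name__)

    # fast=True keeps placeholder/get_attr/output nodes fully replicated,
    # which shrinks the strategy search space
    solver_options = SolverOptions(fast=True)
    strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
    strategies_constructor.build_strategies_and_cost()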