[autoparallel] added solver option dataclass (#1588)

This commit is contained in:
Frank Lee
2022-09-13 14:47:09 +08:00
committed by GitHub
parent 82d4376c23
commit 219f66c571
6 changed files with 43 additions and 12 deletions

View File

@@ -4,5 +4,9 @@ from .solver import Solver
from .cost_graph import CostGraph
from .strategies_constructor import StrategiesConstructor
from .constants import *
from .options import SolverOptions
__all__ = ['StrategiesVector', 'ShardingStrategy', 'GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph']
__all__ = [
'StrategiesVector', 'ShardingStrategy', 'GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph',
'SolverOptions'
]

View File

@@ -0,0 +1,11 @@
from dataclasses import dataclass
__all__ = ['SolverOptions']
@dataclass
class SolverOptions:
"""
SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
"""
fast: bool = False

View File

@@ -1,5 +1,8 @@
from torch.fx import Graph, Node
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from .options import SolverOptions
from . import ShardingStrategy, StrategiesVector
from .op_handler import *
from .constants import *
@@ -11,9 +14,20 @@ from typing import Dict, List
class StrategiesConstructor:
"""
StrategiesConstructor is used to construct the parallelization plan for the model execution.
def __init__(self, graph, device_mesh, shape_consistency_manager, solver_options):
Args:
graph (Graph): a Graph object used for analysis and strategy generation.
device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
shape_consistency_manager (ShapeConsistencyManager): a ShapeConsistencyManager object to make sure the sharding specs are consistent.
solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching.
"""
def __init__(self, graph: Graph, device_mesh: DeviceMesh, shape_consistency_manager: ShapeConsistencyManager,
solver_options: SolverOptions):
self.graph = graph
assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes)
self.device_mesh = device_mesh
@@ -77,13 +91,13 @@ class StrategiesConstructor:
strategies_vector = StrategiesVector(node)
# placeholder node
if node.op == 'placeholder':
# For placeholder nodes, if solver_options['fast_mode'] is True, we just let them in
# For placeholder nodes, if solver_options.fast is True, we just let them in
# fully replicate status, then strategies of following node will be treated equally due
# to replicate status has no resharding cost to other status. At the same time, the searching
# space is smaller than enumerating all the possible sharding spec for the placeholder node.
# Otherwise, all the possible sharding spec for the placeholder node will be enumerated.
if self.solver_options['fast_mode']:
if self.solver_options.fast:
# create sharding strategy for placeholder
name = 'Replica Placeholder'
dim_partition_dict = {}
@@ -97,12 +111,12 @@ class StrategiesConstructor:
# get_attr node
if node.op == 'get_attr':
# Same as placeholder nodes, if solver_options['fast_mode'] is True, we just let them in
# Same as placeholder nodes, if solver_options.fast is True, we just let them in
# fully replicate status, then strategies of following node will be treated equally due
# to replicate status has no resharding cost to other status. At the same time, the searching
# space is smaller than enumerating all the possible sharding spec for the get_attr node.
# Otherwise, all the possible sharding spec for the get_attr node will be enumerated.
if self.solver_options['fast_mode']:
if self.solver_options.fast:
# create sharding strategy for get_attr
name = 'Replica Attribute'
dim_partition_dict = {}
@@ -382,7 +396,7 @@ class StrategiesConstructor:
# output node
if node.op == 'output':
if self.solver_options['fast_mode']:
if self.solver_options.fast:
# create sharding strategy for output
name = 'Replica Output'
input_nodes = strategies_vector.predecessor_nodes