diff --git a/colossalai/auto_parallel/tensor_shard/initialize.py b/colossalai/auto_parallel/tensor_shard/initialize.py
index 23ed0f433..012b0ff43 100644
--- a/colossalai/auto_parallel/tensor_shard/initialize.py
+++ b/colossalai/auto_parallel/tensor_shard/initialize.py
@@ -8,14 +8,9 @@ from torch.fx.graph import Graph
 
 from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
+from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
-from colossalai.auto_parallel.tensor_shard.solver import (
-    CostGraph,
-    GraphAnalyser,
-    Solver,
-    SolverOptions,
-    StrategiesConstructor,
-)
+from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
 from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.graph_module import ColoGraphModule
@@ -69,13 +64,43 @@ def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[f
     pass
 
 
-def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh):
+def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str,
+                               shard_option: str):
     '''
     This method is used to build the strategy_constructor for the given graph.
     After this method, each node in the graph will have a strategies_vector which
     is constructed by the related node handler.
     '''
-    solver_options = SolverOptions()
+    if solver_preference == 'standard':
+        solver_preference = SolverPerference.STANDARD
+    elif solver_preference == 'tp':
+        solver_preference = SolverPerference.TP
+    elif solver_preference == 'dp':
+        solver_preference = SolverPerference.DP
+    else:
+        raise ValueError(f'Invalid solver_preference: {solver_preference}')
+
+    if dataloader_option == 'replicated':
+        dataloader_option = DataloaderOption.REPLICATED
+    elif dataloader_option == 'distributed':
+        dataloader_option = DataloaderOption.DISTRIBUTED
+    else:
+        raise ValueError(f'Invalid dataloader_option: {dataloader_option}')
+
+    if shard_option == 'standard':
+        shard_option = ShardOption.STANDARD
+    elif shard_option == 'shard':
+        shard_option = ShardOption.SHARD
+    elif shard_option == 'shard_last_axis':
+        shard_option = ShardOption.SHARD_LAST_AXIS
+    elif shard_option == 'full_shard':
+        shard_option = ShardOption.FULL_SHARD
+    else:
+        raise ValueError(f'Invalid shard_option: {shard_option}')
+
+    solver_options = SolverOptions(solver_perference=solver_preference,
+                                   dataloader_option=dataloader_option,
+                                   shard_option=shard_option)
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
 
@@ -183,6 +208,9 @@ def initialize_model(model: nn.Module,
                      device_mesh: DeviceMesh,
                      memory_budget: float = -1.0,
                      overlap: bool = False,
+                     solver_preference: str = 'standard',
+                     dataloader_option: str = 'replicated',
+                     shard_option: str = 'standard',
                      save_solver_solution: bool = False,
                      load_solver_solution: bool = False,
                      solution_path: str = None,
@@ -198,6 +226,12 @@
             the memory budget will be infinity.
         overlap(optional): the overlap is used to specify whether to overlap gradient communication and
            backward computing.
+        solver_preference(optional): the solver_preference is used to specify which parallelism algorithm
+            has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'.
+        dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will
+            be used. The valid dataloader_option could be 'replicated' or 'distributed'.
+        shard_option(optional): the shard_option is used to specify how many axes will be used to shard the
+            model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'.
         save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
             to the solution_path.
         load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
@@ -212,7 +246,12 @@
     graph = tracer.trace(root=model, meta_args=meta_args)
     gm = ColoGraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
-    strategies_constructor = build_strategy_constructor(graph, device_mesh)
+
+    strategies_constructor = build_strategy_constructor(graph,
+                                                        device_mesh,
+                                                        solver_preference=solver_preference,
+                                                        dataloader_option=dataloader_option,
+                                                        shard_option=shard_option)
     if load_solver_solution:
         solution = torch.load(solution_path)
     else:
@@ -240,6 +279,9 @@ def autoparallelize(model: nn.Module,
                     alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
                     logical_mesh_shape: Tuple[int] = None,
                     logical_mesh_id: torch.Tensor = None,
+                    solver_preference: str = 'standard',
+                    dataloader_option: str = 'replicated',
+                    shard_option: str = 'standard',
                     save_solver_solution: bool = False,
                     load_solver_solution: bool = False,
                     solver_solution_path: str = None,
@@ -262,6 +304,12 @@
             mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be generated by
             search_best_logical_mesh_shape function.
         logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
+        solver_preference(optional): the solver_preference is used to specify which parallelism algorithm
+            has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'.
+        dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will
+            be used. The valid dataloader_option could be 'replicated' or 'distributed'.
+        shard_option(optional): the shard_option is used to specify how many axes will be used to shard the
+            model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'.
         save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
            to the solution_path.
         load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
@@ -280,6 +328,9 @@
     rst_to_unpack = initialize_model(model,
                                      meta_args,
                                      device_mesh,
+                                     solver_preference=solver_preference,
+                                     dataloader_option=dataloader_option,
+                                     shard_option=shard_option,
                                      save_solver_solution=save_solver_solution,
                                      load_solver_solution=load_solver_solution,
                                      solution_path=solver_solution_path,
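For orientation, a minimal sketch of how the new keyword arguments surface at the public API. The 2 x 2 device mesh, the toy MLP and the meta-argument name are illustrative assumptions, not part of this patch; `autoparallelize` accepts the same three string options.

```python
import torch
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
from colossalai.device.device_mesh import DeviceMesh


class MLP(nn.Module):
    # toy model, only here to illustrate the call signature
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(64, 64)
        self.fc2 = nn.Linear(64, 64)

    def forward(self, x):
        return self.fc2(self.fc1(x))


# assume 4 devices arranged as a 2 x 2 logical mesh
# (init_process_group=True expects torch.distributed to be initialized, e.g. via colossalai.launch)
physical_mesh_id = torch.arange(4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2), init_process_group=True)

meta_args = {'x': torch.rand(8, 64).to('meta')}
module = initialize_model(MLP().cuda(),
                          meta_args,
                          device_mesh,
                          solver_preference='tp',
                          dataloader_option='distributed',
                          shard_option='shard_last_axis')
```

The strings are converted to the enums in `options.py` by `build_strategy_constructor`, so invalid values fail fast with the `ValueError`s shown above.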
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
index 0050358ce..9903ca54e 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
@@ -11,7 +11,6 @@ from .layer_norm_handler import LayerNormModuleHandler
 from .linear_handler import LinearFunctionHandler, LinearModuleHandler
 from .matmul_handler import MatMulHandler
 from .normal_pooling_handler import NormPoolingHandler
-from .option import ShardOption
 from .output_handler import OutputHandler
 from .permute_handler import PermuteHandler
 from .placeholder_handler import PlaceholderHandler
@@ -31,6 +30,6 @@ __all__ = [
     'UnaryElementwiseHandler', 'DefaultReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
     'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
     'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
-    'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'ShardOption',
-    'TransposeHandler', 'SplitHandler'
+    'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'TransposeHandler',
+    'SplitHandler'
 ]
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
index c6f8d035a..136e57c5e 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
@@ -5,7 +5,7 @@ import torch
 from torch.fx.node import Node
 
 from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
-from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption
+from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
     OperationDataType,
@@ -32,19 +32,19 @@ class NodeHandler(ABC):
         strategies_vector (StrategiesVector): all the strategies generated in this handler will be
             recorded into the strategies_vector.
     '''
-    def __init__(
-        self,
-        node: Node,
-        device_mesh: DeviceMesh,
-        strategies_vector: StrategiesVector,
-        shard_option: ShardOption = ShardOption.STANDARD,
-    ) -> None:
+    def __init__(self,
+                 node: Node,
+                 device_mesh: DeviceMesh,
+                 strategies_vector: StrategiesVector,
+                 shard_option: ShardOption = ShardOption.STANDARD,
+                 solver_perference: SolverPerference = SolverPerference.STANDARD) -> None:
         self.node = node
         self.predecessor_node = list(node._input_nodes.keys())
         self.successor_node = list(node.users.keys())
         self.device_mesh = device_mesh
         self.strategies_vector = strategies_vector
         self.shard_option = shard_option
+        self.solver_perference = solver_perference
 
     def update_resharding_cost(self, strategy: ShardingStrategy) -> None:
         """
@@ -187,15 +208,24 @@
 
         remove_strategy_list = []
         for strategy in self.strategies_vector:
-            shard_level = 0
+            shard_axis_list = []
+            last_axis = len(self.device_mesh.mesh_shape) - 1
             for op_data, sharding_spec in strategy.sharding_specs.items():
                 if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
-                    for dim, shard_axis in sharding_spec.dim_partition_dict.items():
-                        shard_level += len(shard_axis)
+                    for dim, shard_axes in sharding_spec.dim_partition_dict.items():
+                        for shard_axis in shard_axes:
+                            if shard_axis not in shard_axis_list:
+                                shard_axis_list.append(shard_axis)
+
+            shard_level = len(shard_axis_list)
+            using_last_axis = last_axis in shard_axis_list or -1 in shard_axis_list
             if self.shard_option == ShardOption.SHARD and shard_level == 0:
                 remove_strategy_list.append(strategy)
             if self.shard_option == ShardOption.FULL_SHARD and shard_level <= 1:
                 remove_strategy_list.append(strategy)
+            if self.shard_option == ShardOption.SHARD_LAST_AXIS:
+                if shard_level != 1 or using_last_axis == False:
+                    remove_strategy_list.append(strategy)
 
         for strategy in remove_strategy_list:
             self.strategies_vector.remove(strategy)
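The reworked `post_process` now counts distinct device-mesh axes instead of summing over tensor dimensions. Below is a standalone sketch of the `SHARD_LAST_AXIS` filter, reduced to a single hypothetical `dim_partition_dict` (the real code aggregates the axes over every operand of a strategy before deciding):

```python
from typing import Dict, List


def survives_shard_last_axis(dim_partition_dict: Dict[int, List[int]], mesh_shape=(2, 2)) -> bool:
    # collect the distinct mesh axes used by this sharding spec
    last_axis = len(mesh_shape) - 1
    shard_axis_list = []
    for shard_axes in dim_partition_dict.values():
        for shard_axis in shard_axes:
            if shard_axis not in shard_axis_list:
                shard_axis_list.append(shard_axis)

    shard_level = len(shard_axis_list)
    using_last_axis = last_axis in shard_axis_list or -1 in shard_axis_list
    # SHARD_LAST_AXIS keeps a strategy only if exactly one mesh axis is used and it is the last one
    return shard_level == 1 and using_last_axis


assert survives_shard_last_axis({0: [1]})         # sharded only along mesh axis 1
assert not survives_shard_last_axis({0: [0]})     # sharded along mesh axis 0 only
assert not survives_shard_last_axis({0: [0, 1]})  # sharded along both axes
```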
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/option.py b/colossalai/auto_parallel/tensor_shard/node_handler/option.py
deleted file mode 100644
index dffb0386d..000000000
--- a/colossalai/auto_parallel/tensor_shard/node_handler/option.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from enum import Enum
-
-__all__ = ['ShardOption']
-
-
-class ShardOption(Enum):
-    """
-    This enum class is to define the shard level required in node strategies.
-
-    Notes:
-        STANDARD: We do not add any extra shard requirements.
-        SHARD: We require the node to be shard using at least one device mesh axis.
-        FULL_SHARD: We require the node to be shard using all device mesh axes.
-    """
-    STANDARD = 0
-    SHARD = 1
-    FULL_SHARD = 2
diff --git a/colossalai/auto_parallel/tensor_shard/options.py b/colossalai/auto_parallel/tensor_shard/options.py
new file mode 100644
index 000000000..f0ea502a6
--- /dev/null
+++ b/colossalai/auto_parallel/tensor_shard/options.py
@@ -0,0 +1,49 @@
+from dataclasses import dataclass
+from enum import Enum
+
+__all__ = ['SolverOptions', 'SolverPerference', 'DataloaderOption', 'ShardOption']
+
+
+class SolverPerference(Enum):
+    """
+    This enum class is to define the solver preference.
+    """
+    STANDARD = 0
+    DP = 1
+    TP = 2
+
+
+class ShardOption(Enum):
+    """
+    This enum class is to define the shard level required in node strategies.
+
+    Notes:
+        STANDARD: We do not add any extra shard requirements.
+        SHARD: We require the node to be shard using at least one device mesh axis.
+        SHARD_LAST_AXIS: We require the node to be shard using the last device mesh axis.
+        FULL_SHARD: We require the node to be shard using all device mesh axes.
+        TP_SHARD: We require the node to be shard using tensor parallel strategies on last device mesh axis.
+        TP_FULL_SHARD: We require the node to be shard using tensor parallel strategies on all device mesh axes.
+    """
+    STANDARD = 0
+    SHARD = 1
+    SHARD_LAST_AXIS = 2
+    FULL_SHARD = 3
+
+
+class DataloaderOption(Enum):
+    """
+    This enum class is to define the dataloader option.
+    """
+    REPLICATED = 0
+    DISTRIBUTED = 1
+
+
+@dataclass
+class SolverOptions:
+    """
+    SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
+    """
+    solver_perference: SolverPerference = SolverPerference.STANDARD
+    dataloader_option: DataloaderOption = DataloaderOption.REPLICATED
+    shard_option: ShardOption = ShardOption.STANDARD
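With the enums consolidated in this new module, a strategy search can be configured directly from it. This is the enum-level equivalent of the string arguments shown earlier (a small sketch):

```python
from colossalai.auto_parallel.tensor_shard.options import (
    DataloaderOption,
    ShardOption,
    SolverOptions,
    SolverPerference,
)

# equivalent to solver_preference='tp', dataloader_option='distributed',
# shard_option='shard_last_axis' in the string-based initialize_model API
solver_options = SolverOptions(solver_perference=SolverPerference.TP,
                               dataloader_option=DataloaderOption.DISTRIBUTED,
                               shard_option=ShardOption.SHARD_LAST_AXIS)
```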
diff --git a/colossalai/auto_parallel/tensor_shard/solver/__init__.py b/colossalai/auto_parallel/tensor_shard/solver/__init__.py
index e9f9ba881..f9e6bd923 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/__init__.py
@@ -1,7 +1,6 @@
 from .cost_graph import CostGraph
 from .graph_analysis import GraphAnalyser
-from .options import SolverOptions
 from .solver import Solver
 from .strategies_constructor import StrategiesConstructor
 
-__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph', 'SolverOptions']
+__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph']
diff --git a/colossalai/auto_parallel/tensor_shard/solver/options.py b/colossalai/auto_parallel/tensor_shard/solver/options.py
deleted file mode 100644
index b52e55708..000000000
--- a/colossalai/auto_parallel/tensor_shard/solver/options.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from dataclasses import dataclass
-from enum import Enum
-
-__all__ = ['SolverOptions']
-
-
-class SolverPerference(Enum):
-    """
-    This enum class is to define the solver preference.
-    """
-    STANDARD = 0
-    DP = 1
-    TP = 2
-
-
-class DataloaderOption(Enum):
-    """
-    This enum class is to define the dataloader option.
-    """
-    REPLICATED = 0
-    DISTRIBUTED = 1
-
-
-@dataclass
-class SolverOptions:
-    """
-    SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
-    """
-    solver_perference: SolverPerference = SolverPerference.STANDARD
-    dataloader_option: DataloaderOption = DataloaderOption.REPLICATED
diff --git a/colossalai/auto_parallel/tensor_shard/solver/solver.py b/colossalai/auto_parallel/tensor_shard/solver/solver.py
index 89d0da223..3bc3e8960 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/solver.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/solver.py
@@ -33,7 +33,7 @@ class Solver:
                  solution_numbers: int = 1,
                  forward_only: bool = False,
                  memory_increasing_coefficient: float = 1.3,
-                 verbose=True):
+                 verbose=False):
        '''
        Solver class will integrate information provided by the components
        and use ILP solver to find a possible optimal strategies combination for target computing graph.
        Argument:
diff --git a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
index 042b9bb4b..40741daca 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
@@ -17,7 +17,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVe
 from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
 from colossalai.device.device_mesh import DeviceMesh
 
-from .options import DataloaderOption, SolverOptions
+from ..options import DataloaderOption, SolverOptions
 
 __all__ = ['StrategiesConstructor']
 
@@ -101,7 +101,11 @@
 
             # get_attr node
             elif node.op == 'get_attr':
-                getattr_handler = GetattrHandler(node, self.device_mesh, strategies_vector)
+                getattr_handler = GetattrHandler(node,
+                                                 self.device_mesh,
+                                                 strategies_vector,
+                                                 shard_option=self.solver_options.shard_option,
+                                                 solver_perference=self.solver_options.solver_perference)
                 getattr_handler.register_strategy()
 
             # call_module node
@@ -109,7 +113,11 @@
                 target = node.target
                 submod = self.root_module.get_submodule(target)
                 submod_type = type(submod)
-                handler = operator_registry.get(submod_type)(node, self.device_mesh, strategies_vector)
+                handler = operator_registry.get(submod_type)(node,
+                                                             self.device_mesh,
+                                                             strategies_vector,
+                                                             shard_option=self.solver_options.shard_option,
+                                                             solver_perference=self.solver_options.solver_perference)
                 handler.register_strategy()
                 # attach metainfo_vector to node
                 if hasattr(handler, 'metainfo_vector'):
@@ -118,7 +126,11 @@
             # call_function node
             elif node.op == 'call_function':
                 target = node.target
-                handler = operator_registry.get(target)(node, self.device_mesh, strategies_vector)
+                handler = operator_registry.get(target)(node,
+                                                        self.device_mesh,
+                                                        strategies_vector,
+                                                        shard_option=self.solver_options.shard_option,
+                                                        solver_perference=self.solver_options.solver_perference)
                 handler.register_strategy()
                 # attach metainfo_vector to node
                 if hasattr(handler, 'metainfo_vector'):
@@ -127,7 +139,11 @@
             # call_method node
             elif node.op == 'call_method':
                 method = getattr(node.args[0]._meta_data.__class__, node.target)
-                handler = operator_registry.get(method)(node, self.device_mesh, strategies_vector)
+                handler = operator_registry.get(method)(node,
+                                                        self.device_mesh,
+                                                        strategies_vector,
+                                                        shard_option=self.solver_options.shard_option,
+                                                        solver_perference=self.solver_options.solver_perference)
                 handler.register_strategy()
                 # attach metainfo_vector to node
                 if hasattr(handler, 'metainfo_vector'):
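For reference, the lower-level pipeline these options feed into, roughly as the updated tests drive it. The toy model, the meta-argument name and the mesh shape are assumptions for illustration, and `verbose=False` simply matches the new default shown above in solver.py:

```python
import torch
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
meta_args = {'input': torch.rand(4, 16).to('meta')}

physical_mesh_id = torch.arange(4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2))

tracer = ColoTracer()
graph = tracer.trace(root=model, meta_args=meta_args)
gm = ColoGraphModule(model, graph, model.__class__.__name__)

solver_options = SolverOptions()    # defaults: STANDARD / REPLICATED / STANDARD
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()

cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
graph_analyser = GraphAnalyser(gm)
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, verbose=False)
ret = solver.call_solver_serialized_args()
```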
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py
index 26ad0d3a0..a6be1928b 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py
@@ -4,13 +4,8 @@ import transformers
 from torch.fx import GraphModule
 
 from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
-from colossalai.auto_parallel.tensor_shard.solver import (
-    CostGraph,
-    GraphAnalyser,
-    Solver,
-    SolverOptions,
-    StrategiesConstructor,
-)
+from colossalai.auto_parallel.tensor_shard.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py
index b8c01d358..60ecd1dd9 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py
@@ -7,8 +7,9 @@ from torch.fx import GraphModule
 
 from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
+from colossalai.auto_parallel.tensor_shard.options import SolverOptions
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType, TrainCycleItem
-from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py
index fda041110..f6895d92a 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py
@@ -5,7 +5,7 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 
 from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
-from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption
+from colossalai.auto_parallel.tensor_shard.options import ShardOption
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
@@ -49,6 +49,15 @@
     strategies_vector = handler.register_strategy(compute_resharding_cost=False)
     strategy_name_list = [val.name for val in strategies_vector]
 
+    if shard_option == ShardOption.SHARD_LAST_AXIS:
+        # RR = RS x SR
+        assert 'RR = RS1 x S1R' in strategy_name_list
+
+        # RS = RR x RS
+        assert 'RS1 = RR x RS1' in strategy_name_list
+
+        return
+
     # SS = SR x RS
     assert 'S1S0 = S1R x RS0_0' in strategy_name_list
     assert 'S0S1 = S0R x RS1_1' in strategy_name_list
@@ -104,7 +113,8 @@ def check_shard_option(shard_option):
 
 
 @run_on_environment_flag(name='AUTO_PARALLEL')
 def test_shard_option():
-    for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD]:
+    # for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD, ShardOption.SHARD_LAST_AXIS]:
+    for shard_option in [ShardOption.SHARD_LAST_AXIS]:
         check_shard_option(shard_option)
 
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
index db76ed9b8..14c8cb296 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
@@ -6,7 +6,8 @@ from torch.fx import GraphModule
 
 from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
-from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
 from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
 from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
 from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py b/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py
index b504d59c9..92f011ba3 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py
@@ -1,13 +1,8 @@
 import torch
 
+from colossalai.auto_parallel.tensor_shard.options import SolverOptions
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType
-from colossalai.auto_parallel.tensor_shard.solver import (
-    CostGraph,
-    GraphAnalyser,
-    Solver,
-    SolverOptions,
-    StrategiesConstructor,
-)
+from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py
index f4a5ae7ac..6f64acd52 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py
@@ -3,13 +3,8 @@ from torch.fx import GraphModule
 from torchvision.models import resnet50
 
 from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
-from colossalai.auto_parallel.tensor_shard.solver import (
-    CostGraph,
-    GraphAnalyser,
-    Solver,
-    SolverOptions,
-    StrategiesConstructor,
-)
+from colossalai.auto_parallel.tensor_shard.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
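For downstream code touched by the relocation, the only change required is the import path, as the test updates above illustrate; a quick before/after sketch:

```python
# before this patch
from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption
from colossalai.auto_parallel.tensor_shard.solver import SolverOptions

# after this patch
from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverOptions
```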