From cdb7d5e7d2dc2a48c31f0625d109728973213af8 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Thu, 20 Oct 2022 19:51:38 +0800 Subject: [PATCH] [hotfix] autoparallel unit test (#1752) --- .../deprecated/op_handler/__init__.py | 11 +++++---- .../test_deprecated/test_deprecated_solver.py | 23 ++++++++++--------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/__init__.py index efcaae795..723e1bcf9 100644 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/__init__.py @@ -1,14 +1,15 @@ -from .operator_handler import OperatorHandler -from .dot_handler import DotHandler -from .conv_handler import ConvHandler from .batch_norm_handler import BatchNormHandler -from .reshape_handler import ReshapeHandler from .bcast_op_handler import BcastOpHandler +from .conv_handler import ConvHandler +from .dot_handler import DotHandler from .embedding_handler import EmbeddingHandler +from .layer_norm_handler import LayerNormHandler +from .operator_handler import OperatorHandler +from .reshape_handler import ReshapeHandler from .unary_elementwise_handler import UnaryElementwiseHandler from .where_handler import WhereHandler __all__ = [ 'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler', - 'UnaryElementwiseHandler', 'EmbeddingHandler', 'WhereHandler' + 'UnaryElementwiseHandler', 'EmbeddingHandler', 'WhereHandler', 'LayerNormHandler' ] diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py index 65bbd6bc3..baa70727a 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py @@ -1,17 +1,18 @@ -import torch -from torch.fx import GraphModule -import torch.nn as nn -import pytest +from copy import deepcopy -from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from colossalai.device.device_mesh import DeviceMesh -from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor +import pytest +import torch +import torch.nn as nn +from torch.fx import GraphModule + +from colossalai.auto_parallel.tensor_shard.deprecated import Solver from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser -from copy import deepcopy -from colossalai.auto_parallel.tensor_shard.deprecated import Solver from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.testing.pytest_wrapper import run_on_environment_flag @@ -60,7 +61,7 @@ def test_solver(): gm = GraphModule(model, graph, model.__class__.__name__) solver_options = SolverOptions(fast=True) - strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options) + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost() cost_graph = CostGraph(strategies_constructor.leaf_strategies)