mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 04:24:47 +00:00
[hotfix] autoparallel unit test (#1752)
This commit is contained in:
@@ -1,17 +1,18 @@
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
import torch.nn as nn
|
||||
import pytest
|
||||
from copy import deepcopy
|
||||
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated import Solver
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
|
||||
from copy import deepcopy
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated import Solver
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
|
||||
|
||||
@@ -60,7 +61,7 @@ def test_solver():
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
|
||||
solver_options = SolverOptions(fast=True)
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
|
||||
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
||||
|
Reference in New Issue
Block a user