mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 17:46:42 +00:00
[autoparallel] added solver option dataclass (#1588)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from pickletools import optimize
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
import torch.nn as nn
|
||||
@@ -10,6 +11,7 @@ from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
|
||||
from colossalai.auto_parallel.solver.cost_graph import CostGraph
|
||||
from colossalai.auto_parallel.solver.options import SolverOptions
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
@@ -52,7 +54,7 @@ def test_cost_graph():
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
|
||||
solver_options = {'fast_mode': True}
|
||||
solver_options = SolverOptions(fast=True)
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
|
||||
|
@@ -11,6 +11,7 @@ from colossalai.auto_parallel.solver.cost_graph import CostGraph
|
||||
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
|
||||
from copy import deepcopy
|
||||
from colossalai.auto_parallel.solver import Solver
|
||||
from colossalai.auto_parallel.solver.options import SolverOptions
|
||||
|
||||
|
||||
class ConvModel(nn.Module):
|
||||
@@ -39,7 +40,6 @@ def test_solver():
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
entire_shape = torch.Size((4, 16, 64, 64))
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
tracer = ColoTracer()
|
||||
@@ -57,9 +57,8 @@ def test_solver():
|
||||
# return relu
|
||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
|
||||
solver_options = {'fast_mode': True}
|
||||
solver_options = SolverOptions(fast=True)
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
|
||||
|
@@ -11,6 +11,7 @@ from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy,
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
|
||||
from colossalai.auto_parallel.solver.options import SolverOptions
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
@@ -47,7 +48,7 @@ def test_strategies_constructor():
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
|
||||
solver_options = {'fast_mode': True}
|
||||
solver_options = SolverOptions(fast=True)
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
|
||||
|
||||
assert strategies_constructor.leaf_strategies == []
|
||||
|
Reference in New Issue
Block a user