Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-16 22:52:25 +00:00
[autoparallel] refactor the runtime apply pass and add docstring to passes (#1757)
* [autoparallel] refactor the runtime apply pass and add docstring to passes
* fix unit test
* polish
@@ -10,6 +10,8 @@ from torch.fx import GraphModule
 from torchvision.models import resnet34, resnet50
 
 from colossalai import device
+from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
+from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
 from colossalai.auto_parallel.tensor_shard.constants import *
 from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
 from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
@@ -17,10 +19,6 @@ from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions
 from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
 from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (
-    shape_consistency_pass,
-    solution_annotatation_pass,
-)
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
@@ -153,8 +151,8 @@ def check_apply_bottleneck(rank, world_size, port):
     print(solution)
     for index, node in enumerate(graph.nodes):
         print(node.name, node.strategies_vector[solution[index]].name)
-    sharding_spec_dict, origin_spec_dict, comm_actions_dict = solution_annotatation_pass(gm, solution, device_mesh)
-    shape_consistency_pass(gm)
+    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
+    gm = runtime_apply_pass(gm)
     gm.recompile()
     nodes = [node for node in gm.graph.nodes]
     # TODO: wrap the gm to avoid the influence of the user training code
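Read together, the hunks above show the whole refactor: the experimental solution_annotatation_pass / shape_consistency_pass pair from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 is replaced by runtime_preparation_pass / runtime_apply_pass from the new colossalai.auto_parallel.passes package. A minimal before/after sketch of that call-site migration, assuming gm, solution, and device_mesh are already set up as in the test:

# Before (removed): two experimental passes; gm was only mutated in place.
#   sharding_spec_dict, origin_spec_dict, comm_actions_dict = \
#       solution_annotatation_pass(gm, solution, device_mesh)
#   shape_consistency_pass(gm)

# After (added): runtime_preparation_pass now also returns the GraphModule,
# and runtime_apply_pass returns the transformed GraphModule.
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
gm = runtime_apply_pass(gm)
gm.recompile()

The second test file receives the identical migration: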
@@ -7,6 +7,8 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.fx import GraphModule
 
+from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
+from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
 from colossalai.auto_parallel.tensor_shard.solver import (
     CostGraph,
     GraphAnalyser,
@@ -15,10 +17,6 @@ from colossalai.auto_parallel.tensor_shard.solver import (
     StrategiesConstructor,
 )
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (
-    shape_consistency_pass,
-    solution_annotatation_pass,
-)
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
@@ -72,8 +70,8 @@ def check_apply(rank, world_size, port):
     solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
     ret = solver.call_solver_serialized_args()
     solution = list(ret[0])
-    sharding_spec_dict, origin_spec_dict, comm_actions_dict = solution_annotatation_pass(gm, solution, device_mesh)
-    shape_consistency_pass(gm)
+    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
+    gm = runtime_apply_pass(gm)
     gm.recompile()
     nodes = [node for node in gm.graph.nodes]
     # TODO: wrap the gm to avoid the influence of the user training code
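For context, here is a sketch of how a test like check_apply wires these pieces together end to end, reconstructed from the imports and the visible hunk. Everything outside the quoted diff lines (the toy MLP model, the 2 x 2 mesh, and the exact constructor arguments of SolverOptions, StrategiesConstructor, CostGraph, and GraphAnalyser) is an assumption for illustration, not part of this commit:

import torch
import torch.nn as nn
from torch.fx import GraphModule

from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.solver import (
    CostGraph,
    GraphAnalyser,
    Solver,
    SolverOptions,
    StrategiesConstructor,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer


class MLP(nn.Module):
    """Hypothetical stand-in for the model used by the test."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


# Assumption: launch() has already initialized the process group on 4 GPUs,
# arranged here as a 2 x 2 logical device mesh.
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2), init_process_group=True)

# Trace the model symbolically with meta tensors to obtain a GraphModule.
model = MLP().cuda()
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={'x': torch.rand(4, 4).to('meta')})
gm = GraphModule(model, graph, model.__class__.__name__)

# Build per-node sharding strategies and solve for the cheapest combination.
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
graph_analyser = GraphAnalyser(gm)
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()
solution = list(ret[0])

# The refactored passes from this commit: prepare sharding metadata, then
# insert the runtime conversion nodes, then recompile the rewritten graph.
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
gm = runtime_apply_pass(gm)
gm.recompile()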