[autoparallel] refactor the runtime apply pass and add docstring to passes (#1757)

* [autoparallel] refactor the runtime apply pass and add docstring to passes

* fix unit test

* polish
Author: YuliangLiu0306 (committed via GitHub)
Date: 2022-10-25 14:32:22 +08:00
Commit: 314d8c497f (parent f9a613d660)
6 changed files with 289 additions and 205 deletions


@@ -10,6 +10,8 @@ from torch.fx import GraphModule
 from torchvision.models import resnet34, resnet50
 from colossalai import device
+from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
+from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
 from colossalai.auto_parallel.tensor_shard.constants import *
 from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
 from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
@@ -17,10 +19,6 @@ from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions
 from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
 from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (
-    shape_consistency_pass,
-    solution_annotatation_pass,
-)
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
@@ -153,8 +151,8 @@ def check_apply_bottleneck(rank, world_size, port):
     print(solution)
     for index, node in enumerate(graph.nodes):
         print(node.name, node.strategies_vector[solution[index]].name)
-    sharding_spec_dict, origin_spec_dict, comm_actions_dict = solution_annotatation_pass(gm, solution, device_mesh)
-    shape_consistency_pass(gm)
+    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
+    gm = runtime_apply_pass(gm)
     gm.recompile()
     nodes = [node for node in gm.graph.nodes]
     # TODO: wrap the gm to avoid the influence of the user training code
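
For orientation, the hunks above amount to a one-for-one substitution at the call sites. A minimal sketch of the migration, assuming gm (a traced GraphModule), solution (the solver output), and device_mesh are already set up as in the test:

from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass

# Old flow (experimental passes, dropped by this commit):
#   sharding_spec_dict, origin_spec_dict, comm_actions_dict = \
#       solution_annotatation_pass(gm, solution, device_mesh)
#   shape_consistency_pass(gm)

# New flow: runtime_preparation_pass returns the GraphModule alongside the
# spec/action dicts, and runtime_apply_pass (the renamed shape-consistency
# step) returns the transformed GraphModule, which is rebound to gm.
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
gm = runtime_apply_pass(gm)
gm.recompile()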


@@ -7,6 +7,8 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.fx import GraphModule
+from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
+from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
 from colossalai.auto_parallel.tensor_shard.solver import (
     CostGraph,
     GraphAnalyser,
@@ -15,10 +17,6 @@ from colossalai.auto_parallel.tensor_shard.solver import (
     StrategiesConstructor,
 )
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (
-    shape_consistency_pass,
-    solution_annotatation_pass,
-)
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
@@ -72,8 +70,8 @@ def check_apply(rank, world_size, port):
     solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
     ret = solver.call_solver_serialized_args()
     solution = list(ret[0])
-    sharding_spec_dict, origin_spec_dict, comm_actions_dict = solution_annotatation_pass(gm, solution, device_mesh)
-    shape_consistency_pass(gm)
+    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
+    gm = runtime_apply_pass(gm)
     gm.recompile()
     nodes = [node for node in gm.graph.nodes]
     # TODO: wrap the gm to avoid the influence of the user training code
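
Both test files receive the same mechanical update: the passes move out of colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 into the dedicated colossalai.auto_parallel.passes package, solution_annotatation_pass is superseded by runtime_preparation_pass (which now also returns the GraphModule), and shape_consistency_pass is superseded by runtime_apply_pass, whose return value is rebound to gm at the call sites instead of relying on in-place mutation.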