[fx]add split module pass and unit test from pipeline passes (#1242)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [fx]add split module pass and unit test from pipeline passes

* fix MNASNet bug

* polish
This commit is contained in:
YuliangLiu0306
2022-07-12 13:45:01 +08:00
committed by GitHub
parent 762905da68
commit 30b4fc0eb0
11 changed files with 702 additions and 3 deletions

View File

@@ -2,7 +2,7 @@ import torch
from torch.fx import symbolic_trace
from torch.fx.node import Node
from torch.fx.passes.split_module import split_module
from colossalai.fx.passes.split_module import split_module
def pipe_split():
@@ -26,8 +26,14 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
if accumulate_param_amount >= params_per_partition:
accumulate_param_amount = 0
pp_size -= 1
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
# If the next node is output node, we will insert split annotation before
# node to make sure there is at least one node in last partition.
if node.next.op == 'output':
with mod_graph.inserting_before(node):
split_node = mod_graph.create_node('call_function', pipe_split)
else:
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
gm.recompile()
return gm