[fx] update split module pass and add customized policy (#1373)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [fx]update split module pass and add customized policy
This commit is contained in:
YuliangLiu0306
2022-07-27 13:40:54 +08:00
committed by GitHub
parent be229217ce
commit 52bc2dc271
2 changed files with 85 additions and 20 deletions

View File

@@ -61,6 +61,8 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
for node in mod_graph.nodes:
if pp_size <= 1:
break
if 'pipe_split' in node.name:
continue
accumulate_node_size += node.node_size
if accumulate_node_size >= partition_size:
accumulate_node_size = 0