[Pipeline Middleware] Fix deadlock when num_microbatch=num_stage (#2156)

* add splitter

* polish code

* remove comment

* fix async NaN by moving tensors to CPU first

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
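
The last fix in the message refers to a device-synchronization pattern: a CUDA tensor handed to an asynchronous send can be read before the kernels producing it have finished, which surfaces as NaN on the receiving stage. A minimal sketch of the idea, with an illustrative helper name (not the actual ColossalAI API):

import torch

def prepare_for_async_send(tensor: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: .to('cpu') is a synchronizing copy by default,
    # so all pending CUDA kernels writing into `tensor` complete before the
    # bytes are handed to the async transport, avoiding garbage/NaN reads.
    return tensor.detach().to('cpu')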
Author: Ziyue Jiang
Date: 2022-12-23 11:38:43 +08:00
Committed by: GitHub
Parent: 937f404253
Commit: 59e343328d
4 changed files with 84 additions and 58 deletions


@@ -9,6 +9,30 @@ def pipe_split():
    pass

def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int):
    """
    In avgnode_split_pass, simply split the graph by node count: each of the
    pp_size stages receives roughly the same number of nodes.
    """
    mod_graph = gm.graph
    avg_num_node = len(mod_graph.nodes) // pp_size
    accumulate_num_node = 0
    for node in mod_graph.nodes:
        if pp_size <= 1:
            break
        accumulate_num_node += 1
        if accumulate_num_node >= avg_num_node:
            accumulate_num_node = 0
            pp_size -= 1
            # Avoid creating an empty final stage: if the next node is the
            # graph output, insert the split before this node instead of after it.
            if node.next.op == 'output':
                with mod_graph.inserting_before(node):
                    split_node = mod_graph.create_node('call_function', pipe_split)
            else:
                with mod_graph.inserting_after(node):
                    split_node = mod_graph.create_node('call_function', pipe_split)
    gm.recompile()
    return gm
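
For orientation, here is how a split pass like this is typically exercised, assuming a toy module that torch.fx can trace; the tracing calls are standard PyTorch, while avgnode_split_pass is the function added in this diff:

import torch
import torch.fx

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(8, 8)
        self.fc2 = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

gm = torch.fx.symbolic_trace(TinyModel())
gm = avgnode_split_pass(gm, pp_size=2)  # inserts pipe_split markers into the graph
print(gm.graph)                         # pipe_split appears as a call_function node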
def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
    """
    In balanced_split_pass, we split the module by the size of its parameters (weights + bias).
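
The docstring is cut off here, but the measurement it implies can be sketched as follows; the helper name is hypothetical, and only call_module nodes are assumed to own parameters:

import torch
import torch.fx

def node_param_size(gm: torch.fx.GraphModule, node: torch.fx.Node) -> int:
    # Hypothetical helper: number of parameter elements (weights + bias)
    # owned by a node. Only call_module nodes map to submodules that
    # carry parameters; every other node contributes zero.
    if node.op != 'call_module':
        return 0
    submod = gm.get_submodule(node.target)
    return sum(p.numel() for p in submod.parameters())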