Mirror of https://github.com/hpcaitech/ColossalAI.git
[Pipeline Middleware] Fix deadlock when num_microbatch=num_stage (#2156)
* add splitter
* polish code
* remove comment
* fix async nan by moving to cpu first

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
@@ -9,6 +9,30 @@ def pipe_split():
     pass
 
 
+def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int):
+    """
+    In avgnode_split_pass, simply split the graph by node number.
+    """
+    mod_graph = gm.graph
+    avg_num_node = len(mod_graph.nodes) // pp_size
+    accumulate_num_node = 0
+    for node in mod_graph.nodes:
+        if pp_size <= 1:
+            break
+        accumulate_num_node += 1
+        if accumulate_num_node >= avg_num_node:
+            accumulate_num_node = 0
+            pp_size -= 1
+            if node.next.op == 'output':
+                with mod_graph.inserting_before(node):
+                    split_node = mod_graph.create_node('call_function', pipe_split)
+            else:
+                with mod_graph.inserting_after(node):
+                    split_node = mod_graph.create_node('call_function', pipe_split)
+    gm.recompile()
+    return gm
+
+
 def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
     """
     In balanced_split_pass, we split the module by the size of its parameters (weights + bias).
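
For context, here is a minimal usage sketch of the new pass (not part of the commit): trace a small module with torch.fx, then let avgnode_split_pass insert pipe_split markers at roughly even node counts. The import path is an assumption about where these passes live in ColossalAI. Note how the node.next.op == 'output' check in the diff keeps a split marker from being inserted after the graph's output node, which torch.fx would reject.

# Minimal usage sketch, not part of the commit. The import path is an
# assumption; adjust it to wherever this pass file lives in your checkout.
import torch
import torch.fx
from colossalai.fx.passes.adding_split_node_pass import avgnode_split_pass


class TwoLayerMLP(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(16, 32)
        self.fc2 = torch.nn.Linear(32, 8)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


gm = torch.fx.symbolic_trace(TwoLayerMLP())
gm = avgnode_split_pass(gm, pp_size=2)

# The traced graph now contains a pipe_split call marking the stage boundary.
print(gm.graph)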
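
The hunk is truncated just after the balanced_split_pass docstring, so its body is not shown. Purely as an illustration of the idea the docstring describes, and explicitly not the commit's code, a pass in the same spirit might accumulate parameter counts per call_module node and split once a stage holds roughly 1/pp_size of the total:

# Illustrative sketch only, NOT the code from this commit: split on
# accumulated parameter size instead of node count. All names here are
# hypothetical except pipe_split, which mirrors the marker in the diff.
import torch
import torch.fx


def pipe_split():
    # no-op marker, as in the file being patched
    pass


def naive_param_split_pass(gm: torch.fx.GraphModule, pp_size: int):
    mod_graph = gm.graph
    total_param = sum(p.numel() for p in gm.parameters())
    per_stage = total_param // pp_size
    accumulated = 0
    for node in mod_graph.nodes:
        if pp_size <= 1:
            break
        if node.op == 'call_module':
            # Count the parameters owned by the submodule this node calls.
            submod = gm.get_submodule(node.target)
            accumulated += sum(p.numel() for p in submod.parameters())
        if accumulated >= per_stage:
            accumulated = 0
            pp_size -= 1
            # (The real pass would also need to guard against inserting
            # after the graph's output node, as avgnode_split_pass does.)
            with mod_graph.inserting_after(node):
                mod_graph.create_node('call_function', pipe_split)
    gm.recompile()
    return gm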