mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[fx] support module with bias addition (#1780)
* [autoparallel] refactor tracer to fix bias addition issue * [fx] support module with bias addition * create bias_addition_module * refactor file structure * polish code * fix unit test
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
|
||||
from torch.fx import symbolic_trace
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.fx.passes.split_module import split_module
|
||||
|
||||
|
||||
@@ -37,6 +37,21 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
|
||||
else:
|
||||
with mod_graph.inserting_after(node):
|
||||
split_node = mod_graph.create_node('call_function', pipe_split)
|
||||
if pp_size > 1:
|
||||
node_counter = 0
|
||||
for node in mod_graph.nodes:
|
||||
if pp_size <= 1:
|
||||
break
|
||||
if node.op == 'placeholder':
|
||||
continue
|
||||
elif node_counter == 0:
|
||||
node_counter += 1
|
||||
else:
|
||||
pp_size -= 1
|
||||
node_counter = 0
|
||||
with mod_graph.inserting_before(node):
|
||||
split_node = mod_graph.create_node('call_function', pipe_split)
|
||||
|
||||
gm.recompile()
|
||||
return gm
|
||||
|
||||
|
Reference in New Issue
Block a user