[fx] support module with bias addition (#1780)

* [autoparallel] refactor tracer to fix bias addition issue

* [fx] support module with bias addition

* create bias_addition_module

* refactor file structure

* polish code

* fix unit test
This commit is contained in:
YuliangLiu0306
2022-11-01 22:53:51 +08:00
committed by GitHub
parent f3f19a5c47
commit e859380bf7
41 changed files with 624 additions and 259 deletions

View File

@@ -1,7 +1,7 @@
import torch
from torch.fx import symbolic_trace
from torch.fx.node import Node
from colossalai.fx.passes.split_module import split_module
@@ -37,6 +37,21 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
else:
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
if pp_size > 1:
node_counter = 0
for node in mod_graph.nodes:
if pp_size <= 1:
break
if node.op == 'placeholder':
continue
elif node_counter == 0:
node_counter += 1
else:
pp_size -= 1
node_counter = 0
with mod_graph.inserting_before(node):
split_node = mod_graph.create_node('call_function', pipe_split)
gm.recompile()
return gm