[autoparallel] fix bias addition module (#1800)

This commit is contained in:
YuliangLiu0306
2022-11-08 16:21:25 +08:00
committed by GitHub
parent 6e9730d7ab
commit f6032ddb17
9 changed files with 438 additions and 20 deletions

View File

@@ -93,7 +93,7 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
# substitute the origin node with shape_consistency_node
origin_index_args = new_args.index(node)
new_args[origin_index_args] = shape_consistency_node
user_node.args = new_args
user_node.args = tuple(new_args)
elif str(node) in new_kwargs:
# substitute the origin node with shape_consistency_node
new_kwargs[str(node)] = shape_consistency_node
@@ -118,10 +118,12 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
comm_actions = node.best_strategy.communication_actions
for op_data, comm_action in comm_actions.items():
if op_data.type == OperationDataType.PARAM:
if comm_action.comm_type == CommType.HOOK:
continue
if comm_action.comm_type == CommType.BEFORE:
if comm_action.key_for_kwarg is not None:
if op_data.type == OperationDataType.OUTPUT:
comm_object = node
elif comm_action.key_for_kwarg is not None:
comm_object = node.kwargs[comm_action.key_for_kwarg]
else:
comm_object = node.args[comm_action.arg_index]
@@ -140,7 +142,7 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
# substitute the origin node with comm_spec_apply_node
new_args = list(node.args)
new_args[comm_action.arg_index] = comm_spec_apply_node
node.args = new_args
node.args = tuple(new_args)
elif comm_action.comm_type == CommType.AFTER:
with mod_graph.inserting_after(node):
@@ -163,7 +165,6 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
# substitute the origin node with comm_spec_apply_node
new_kwargs[str(node)] = comm_spec_apply_node
user.kwargs = new_kwargs
return gm