Mirror of https://github.com/hpcaitech/ColossalAI.git
[autoparallel] fix bias addition module (#1800)
@@ -93,7 +93,7 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
                 # substitute the origin node with shape_consistency_node
                 origin_index_args = new_args.index(node)
                 new_args[origin_index_args] = shape_consistency_node
-                user_node.args = new_args
+                user_node.args = tuple(new_args)
             elif str(node) in new_kwargs:
                 # substitute the origin node with shape_consistency_node
                 new_kwargs[str(node)] = shape_consistency_node
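Note on the `tuple(new_args)` changes (this hunk and the matching one in `_comm_spec_apply` further down): torch.fx exposes `Node.args` as a tuple of arguments, so assigning a plain list back violates that contract and can break later passes; the fix converts the edited list back to a tuple before assigning it. Below is a minimal sketch of the same rewrite pattern on a toy traced module (not ColossalAI code; the `Toy` module and the replacement value are made up for illustration):

import operator

import torch
import torch.fx


class Toy(torch.nn.Module):
    def forward(self, x):
        return x + 1


gm = torch.fx.symbolic_trace(Toy())

# Same pattern as the hunk above: copy the node's args into a list, edit the
# list, then assign it back as a *tuple*, since torch.fx keeps Node.args as a
# tuple of arguments.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is operator.add:
        new_args = list(node.args)
        new_args[1] = 2                # swap the constant operand (illustrative)
        node.args = tuple(new_args)    # tuple(...), mirroring the fix

gm.recompile()
assert torch.equal(gm(torch.zeros(3)), torch.full((3,), 2.0))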
@@ -118,10 +118,12 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
         comm_actions = node.best_strategy.communication_actions
         for op_data, comm_action in comm_actions.items():
 
-            if op_data.type == OperationDataType.PARAM:
+            if comm_action.comm_type == CommType.HOOK:
                 continue
             if comm_action.comm_type == CommType.BEFORE:
-                if comm_action.key_for_kwarg is not None:
+                if op_data.type == OperationDataType.OUTPUT:
+                    comm_object = node
+                elif comm_action.key_for_kwarg is not None:
                     comm_object = node.kwargs[comm_action.key_for_kwarg]
                 else:
                     comm_object = node.args[comm_action.arg_index]
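For context on the branches touched here: `CommType.BEFORE` and `CommType.AFTER` insert a runtime communication call before or after the annotated node in the fx graph, while `CommType.HOOK` marks parameter communication that is presumably handled by gradient hooks rather than by an inserted graph node, hence the early `continue`. The sketch below shows only the generic torch.fx "insert before, then substitute the argument" pattern this pass relies on; `runtime_noop`, `insert_identity_before`, and the ReLU filter are hypothetical stand-ins, not ColossalAI APIs:

import torch
import torch.fx


def insert_identity_before(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Toy pass: wrap the first positional input of every torch.relu call in a
    no-op function inserted *before* the node, mirroring the inserting_before /
    arg-substitution structure used for CommType.BEFORE above."""

    def runtime_noop(t):  # stand-in for a runtime communication routine
        return t

    # Copy the node list first, because we mutate the graph while iterating.
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target is torch.relu:
            comm_object = node.args[0]  # analogous to node.args[comm_action.arg_index]
            with gm.graph.inserting_before(node):
                wrapper = gm.graph.call_function(runtime_noop, args=(comm_object,))
            new_args = list(node.args)
            new_args[0] = wrapper
            node.args = tuple(new_args)  # tuple(...), as in the fix
    gm.graph.lint()
    gm.recompile()
    return gm


class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)


gm = insert_identity_before(torch.fx.symbolic_trace(M()))
print(gm.code)  # the generated code now routes the relu input through runtime_noop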
@@ -140,7 +142,7 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
                     # substitute the origin node with comm_spec_apply_node
                     new_args = list(node.args)
                     new_args[comm_action.arg_index] = comm_spec_apply_node
-                    node.args = new_args
+                    node.args = tuple(new_args)
 
             elif comm_action.comm_type == CommType.AFTER:
                 with mod_graph.inserting_after(node):
@@ -163,7 +165,6 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
                     # substitute the origin node with comm_spec_apply_node
                     new_kwargs[str(node)] = comm_spec_apply_node
                     user.kwargs = new_kwargs
-
 
     return gm
 