[autoparallel] remove redundancy comm node (#1893)

YuliangLiu0306
2022-11-15 10:53:41 +08:00
committed by GitHub
parent 9183e0dec5
commit 36c0f3ea5b
5 changed files with 23 additions and 20 deletions


@@ -81,6 +81,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
             continue
         for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
+            if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
+                continue
             with mod_graph.inserting_before(user_node):
                 shape_consistency_node = mod_graph.create_node('call_function',
                                                                runtime_apply,
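
The change itself is small: before inserting a runtime_apply conversion node between a producer and one of its consumers, the pass now compares the producer's sharding spec with the spec that consumer expects (sharding_sequence_difference(...) == 0) and skips the insertion when they already match, so no redundant communication node is created. The sketch below shows that pattern in isolation on a plain torch.fx graph; it is not the Colossal-AI pass. The specs stored in node.meta, the 'S0R' / 'RR' strings, and the runtime_convert stand-in are illustrative assumptions, and only the torch.fx calls are the real library API.

# Minimal, self-contained sketch of the "skip redundant conversion" check.
# Assumptions: sharding metadata lives in node.meta, and runtime_convert is a
# dummy stand-in for the real runtime_apply conversion function.
import torch
import torch.fx


def runtime_convert(tensor, src_spec, dst_spec):
    # Hypothetical stand-in for the real conversion; returns the tensor as-is.
    return tensor


class TwoLinear(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(4, 4)
        self.l2 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.l2(self.l1(x))


gm = torch.fx.symbolic_trace(TwoLinear())
graph = gm.graph

# Attach illustrative specs: l1's output is sharded ('S0R') while l2 expects a
# replicated input ('RR'), so exactly one conversion node should be inserted.
for node in graph.nodes:
    if node.op == 'call_module' and node.target == 'l1':
        node.meta['sharding_spec'] = 'S0R'
        node.meta['target_sharding_specs'] = {user: 'RR' for user in node.users}

for node in list(graph.nodes):
    src_spec = node.meta.get('sharding_spec')
    target_specs = node.meta.get('target_sharding_specs', {})
    if src_spec is None:
        continue
    for user in list(node.users):
        dst_spec = target_specs.get(user)
        # The redundancy check this commit adds: if the source spec already
        # matches the target spec, inserting a conversion node is a no-op.
        if dst_spec is None or dst_spec == src_spec:
            continue
        with graph.inserting_before(user):
            conv = graph.create_node('call_function', runtime_convert,
                                     args=(node, src_spec, dst_spec))
        user.replace_input_with(node, conv)

graph.lint()
gm.recompile()
print(gm.code)    # the rewritten forward routes l1's output through runtime_convert

In this toy setup only the l1 -> l2 edge gets a conversion node; an edge whose specs already agree is left untouched, which is what the commit title means by removing redundant comm nodes.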