[autoparallel] fix linear logical conversion issue (#1857)

This commit is contained in:
YuliangLiu0306
2022-11-10 17:19:22 +08:00
committed by GitHub
parent c2947dadf1
commit 1b494ad73c
4 changed files with 40 additions and 10 deletions

View File

@@ -52,7 +52,6 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
if node.op == 'get_attr':
assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.'
new_sharding_spec = target_sharding_specs[0]
user_node = node.strategies_vector.successor_nodes[0]
user_strategy = node.strategies_vector.successor_nodes[0].best_strategy
op_data_in_user = user_strategy.get_op_data_by_name(str(node))
origin_node_sharding_spec_dict[index] = new_sharding_spec