diff --git a/tests/test_auto_parallel/test_shape_consistency_pass.py b/tests/test_auto_parallel/test_shape_consistency_pass.py index 2a7b745f8..6cb46c1de 100644 --- a/tests/test_auto_parallel/test_shape_consistency_pass.py +++ b/tests/test_auto_parallel/test_shape_consistency_pass.py @@ -65,6 +65,7 @@ def check_apply(rank, world_size, port): solution = list(ret[0]) sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh) shape_consistency_pass(gm) + gm.recompile() nodes = [node for node in gm.graph.nodes] # TODO: wrap the gm to avoid the influence of the user training code output = gm(input, sharding_spec_dict, origin_spec_dict)