fixed layout converter caching and updated tester

This commit is contained in:
Edenzzzz
2024-03-26 17:22:27 +08:00
parent a7790a92e8
commit 61da3fbc52
2 changed files with 13 additions and 3 deletions

View File

@@ -123,8 +123,15 @@ def check_layout_converting(rank, world_size, port):
assert comm_action_sequence[2].logical_process_axis == 1
# checkout chached_spec_pairs_transform_path
assert layout_converter.cached_solution[("[R, S01, R]", "[S01, R, R]")][0] == transform_path
assert layout_converter.cached_solution[("[R, S01, R]", "[S01, R, R]")][1] == comm_action_sequence
src_shape = source_layout.get_sharded_shape_per_device()
dst_shape = target_layout.get_sharded_shape_per_device()
assert (
layout_converter.cached_solution[(("[R, S01, R]", src_shape), ("[S01, R, R]", dst_shape))][0] == transform_path
)
assert (
layout_converter.cached_solution[(("[R, S01, R]", src_shape), ("[S01, R, R]", dst_shape))][1]
== comm_action_sequence
)
comm_cost = layout_converter.get_total_comm_cost(source_layout, target_layout)