fixed layout converter caching and updated tester

This commit is contained in:
Edenzzzz
2024-03-26 17:22:27 +08:00
parent a7790a92e8
commit 61da3fbc52
2 changed files with 13 additions and 3 deletions

View File

@@ -440,7 +440,10 @@ class LayoutConverter(metaclass=SingletonMeta):
total_steps = 0
transform_path = []
comm_action_sequence: List[CommSpec] = []
spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence))
src_shape = source_layout.get_sharded_shape_per_device()
dst_shape = target_layout.get_sharded_shape_per_device()
spec_pairs = ((str(source_spec.sharding_sequence), src_shape), (str(target_spec.sharding_sequence), dst_shape))
if spec_pairs in self.cached_solution:
# Solution Cache hit