[autoparallel] change the merge node logic (#1533)

This commit is contained in:
YuliangLiu0306
2022-09-07 11:18:19 +08:00
committed by GitHub
parent ae71036cd2
commit 44c866a3e3
3 changed files with 71 additions and 43 deletions

View File

@@ -82,7 +82,8 @@ class CommSpec:
if self.comm_pattern == CollectiveCommPattern.ALLREDUCE:
return self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis)
if self.comm_pattern == CollectiveCommPattern.SHARD:
return 0
# give a tiny cost to shard
return 10
raise RuntimeError(f"Could not find a matching CollectiveCommPattern for {self.comm_pattern}.")
def covert_spec_to_action(self, tensor):