[autoparallel] add non_split linear strategy (#2078)

* [autoparallel] add non_split linear stategy

* polish
This commit is contained in:
YuliangLiu0306
2022-12-06 10:19:33 +08:00
committed by GitHub
parent cf0268da93
commit cdf537a648
4 changed files with 120 additions and 24 deletions

View File

@@ -204,9 +204,15 @@ class ShardingStrategy:
def _deepcopy_dict_vals(data: Dict):
return {k: deepcopy(v) for k, v in data.items()}
sharding_specs = _deepcopy_dict_vals(self.sharding_specs) if self.sharding_specs else None
communication_actions = _deepcopy_dict_vals(self.communication_actions) if self.communication_actions else None
resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs else None
sharding_specs = _deepcopy_dict_vals(self.sharding_specs) if self.sharding_specs is not None else None
# We need to deepcopy it when self.communication_actions is not None, instead of checking its __bool__ value.
# Consider the examples below:
# If self.communication_actions is an empty dictionary {}, then self.communication_actions is not None, but its __bool__ value is False.
# In this case, if we set None to the new object, program will crash when we try to access the communication_actions.items.
communication_actions = _deepcopy_dict_vals(
self.communication_actions) if self.communication_actions is not None else None
# same reason as communication_actions
resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs is not None else None
compute_cost = deepcopy(self.compute_cost)
communication_cost = deepcopy(self.communication_cost)
memory_cost = deepcopy(self.memory_cost)