[autoparallel] add non_split linear strategy (#2078)

* [autoparallel] add non_split linear strategy

* polish
This commit is contained in:
YuliangLiu0306
2022-12-06 10:19:33 +08:00
committed by GitHub
parent cf0268da93
commit cdf537a648
4 changed files with 120 additions and 24 deletions

View File

@@ -263,6 +263,9 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# RS01 = RR x RS01
strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
# RR = RR x RR
strategies.append(self.non_split())
return strategies
@ignore_sharding_exception
@@ -665,6 +668,29 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def non_split(self):
    """Build the 'RR = RR x RR' strategy, in which no operand is partitioned.

    Every operand ("input", "other", "bias", "output") is given an empty
    dim-partition dict, and no communication action is registered.

    Returns:
        The result of ``self.get_sharding_strategy`` for these specs.
        NOTE(review): the ``ignore_sharding_exception`` decorator is defined
        outside this view and may alter what callers receive -- confirm.
    """
    # Plain string literal: there are no placeholders, so no f-prefix is needed.
    name = 'RR = RR x RR'

    # get sharding spec
    # Bias needs no special handling here: its partition dict is empty,
    # exactly like the output's.
    dim_partition_dict_mapping = {
        "input": {},
        "other": {},
        "bias": {},
        "output": {},
    }
    sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

    # get communication action
    communication_action_mapping = {}

    return self.get_sharding_strategy(name=name,
                                      sharding_spec_mapping=sharding_spec_mapping,
                                      communication_action_mapping=communication_action_mapping)
def validate(self) -> bool:
assert "input" in self.op_data
assert "other" in self.op_data

View File

@@ -204,9 +204,15 @@ class ShardingStrategy:
def _deepcopy_dict_vals(data: Dict):
return {k: deepcopy(v) for k, v in data.items()}
sharding_specs = _deepcopy_dict_vals(self.sharding_specs) if self.sharding_specs else None
communication_actions = _deepcopy_dict_vals(self.communication_actions) if self.communication_actions else None
resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs else None
sharding_specs = _deepcopy_dict_vals(self.sharding_specs) if self.sharding_specs is not None else None
# We must deepcopy whenever self.communication_actions is not None, rather than relying on its truthiness.
# Consider the following example:
# If self.communication_actions is an empty dictionary {}, it is not None, but its truth value is False.
# In that case, if we set the new object's communication_actions to None, the program will crash when we later try to access communication_actions.items().
communication_actions = _deepcopy_dict_vals(
self.communication_actions) if self.communication_actions is not None else None
# same reason as communication_actions
resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs is not None else None
compute_cost = deepcopy(self.compute_cost)
communication_cost = deepcopy(self.communication_cost)
memory_cost = deepcopy(self.memory_cost)