Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 09:07:51 +00:00)
[autoparallel] add non_split linear strategy (#2078)
* [autoparallel] add non_split linear strategy
* polish
@@ -263,6 +263,9 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         # RS01 = RR x RS01
         strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
 
+        # RR = RR x RR
+        strategies.append(self.non_split())
+
         return strategies
 
     @ignore_sharding_exception
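A note on the naming convention used above (not part of the diff): in names like 'RS01 = RR x RS01', each letter describes one tensor dimension of output = input x weight, with R meaning replicated and S01 meaning sharded jointly across device-mesh dimensions 0 and 1. The standalone helper below illustrates that convention, assuming the dim_partition_dict layout used in the next hunk (tensor dim -> list of mesh dims); spec_name is a hypothetical name for this sketch, not a ColossalAI function.

# Illustration only: render a dim_partition_dict in the R/S notation used by the strategy names.
def spec_name(num_dims: int, dim_partition_dict: dict) -> str:
    parts = []
    for dim in range(num_dims):
        mesh_dims = dim_partition_dict.get(dim)
        if mesh_dims:
            # e.g. [0, 1] -> "S01": this tensor dim is sharded over mesh dims 0 and 1
            parts.append("S" + "".join(str(m) for m in mesh_dims))
        else:
            parts.append("R")  # not sharded -> replicated along this dim
    return "".join(parts)

print(spec_name(2, {}))           # 'RR'   -> the layout every operand gets in non_split()
print(spec_name(2, {1: [0, 1]}))  # 'RS01' -> second dim sharded over both mesh dims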
@@ -665,6 +668,29 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
 
+    @ignore_sharding_exception
+    def non_split(self):
+        name = f'RR = RR x RR'
+
+        # get sharding spec
+        dim_partition_dict_mapping = {
+            "input": {},
+            "other": {},
+            "bias": {},
+            "output": {},
+        }
+
+        # We don't have to do anything special for bias here, because
+        # the bias is already the same sharding spec as the output.
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+        # get communication action
+        communication_action_mapping = {}
+
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+
     def validate(self) -> bool:
         assert "input" in self.op_data
         assert "other" in self.op_data
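As intuition for the new strategy (a sketch under the assumption of a plain replicated setup, not ColossalAI code): with 'RR = RR x RR' every device keeps the full input and the full weight, computes the full output locally, and therefore needs no collective communication, which is why communication_action_mapping stays empty.

import numpy as np

num_devices = 4                      # arbitrary device count for the sketch
x = np.random.rand(8, 16)            # replicated input, spec "RR"
w = np.random.rand(16, 32)           # replicated weight, spec "RR"

# Each device performs the same full matmul; no all-reduce/all-gather is needed.
outputs = [x @ w for _ in range(num_devices)]
assert all(np.allclose(outputs[0], out) for out in outputs)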
@@ -204,9 +204,15 @@ class ShardingStrategy:
         def _deepcopy_dict_vals(data: Dict):
             return {k: deepcopy(v) for k, v in data.items()}
 
-        sharding_specs = _deepcopy_dict_vals(self.sharding_specs) if self.sharding_specs else None
-        communication_actions = _deepcopy_dict_vals(self.communication_actions) if self.communication_actions else None
-        resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs else None
+        sharding_specs = _deepcopy_dict_vals(self.sharding_specs) if self.sharding_specs is not None else None
+        # We need to deepcopy when self.communication_actions is not None, rather than checking its __bool__ value.
+        # Consider the example below:
+        # if self.communication_actions is an empty dictionary {}, it is not None, but its __bool__ value is False.
+        # In that case, setting the new object's attribute to None would crash the program once communication_actions.items() is accessed.
+        communication_actions = _deepcopy_dict_vals(
+            self.communication_actions) if self.communication_actions is not None else None
+        # same reason as communication_actions
+        resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs is not None else None
         compute_cost = deepcopy(self.compute_cost)
         communication_cost = deepcopy(self.communication_cost)
         memory_cost = deepcopy(self.memory_cost)
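The comments added in this hunk come down to the difference between a truthiness check and an `is not None` check on an empty dict; a minimal standalone reproduction of the failure mode they describe:

actions = {}  # a valid but empty communication_actions mapping

clone_wrong = dict(actions) if actions else None              # old check: empty dict -> None
clone_right = dict(actions) if actions is not None else None  # new check: empty dict -> {}

print(clone_wrong)  # None -> a later clone_wrong.items() call raises AttributeError
print(clone_right)  # {}   -> .items() still works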