[autoparallel] change the following nodes strategies generation logic (#1636)

* [autoparallel] change the following nodes strategies generation logic

* fix unit test
Authored by YuliangLiu0306 on 2022-09-27 11:20:52 +08:00, committed by GitHub
parent 59f100510a
commit 03978aad45
2 changed files with 9 additions and 8 deletions


@@ -48,22 +48,23 @@ class UnaryElementwiseHandler(OperatorHandler):
         # For an element-wise function, we keep the sharding spec of the output node the same as
         # the input's. Therefore, different strategies of the input node with the same
         # output sharding spec would generate the same strategy for the element-wise function.
         sharding_spec_checklist = []
-        for strategy in self.input_node.strategies_vector:
+        for index, strategy in enumerate(self.input_node.strategies_vector):
             # It may look a little confusing: the input of the processing node
             # is the output of the input_node.
             input_sharding_spec = strategy.output_sharding_spec
             assert isinstance(input_sharding_spec, ShardingSpec), 'The input node should NOT be a tuple of tensors.'
-            if input_sharding_spec in sharding_spec_checklist:
-                continue
             sharding_spec_checklist.append(input_sharding_spec)
             dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
             try:
                 output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict)
             except AssertionError as e:
                 warnings.warn(f'{e}')
                 continue
-            name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
+            # append the index to the name so it passes the duplicated-name check;
+            # we keep identical strategies under different names for node merging. This does not enlarge the
+            # search space, because the solver merges this node into other nodes and creates no new variable for it.
+            name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}_{index}'
             # TODO: use meta_info_prop to profile memory cost and compute cost
             compute_cost = self.output_data.numel()
             memory_cost = 0
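
The comment at the top of the hunk describes the propagation rule this handler relies on: an element-wise op preserves the input tensor's shape, so the output can reuse the input's dim partitioning unchanged. A minimal standalone sketch of that rule, with a hypothetical helper name and a plain dict standing in for dim_partition_dict (not the ColossalAI API):

    from copy import deepcopy

    def propagate_elementwise_partition(dim_partition_dict):
        # The output of an element-wise op has the same shape as its input,
        # so the input's dim -> mesh-axes partitioning stays valid as-is.
        return deepcopy(dim_partition_dict)

    # Input sharded along dim 0 on mesh axis 0 and dim 1 on mesh axis 1:
    in_partition = {0: [0], 1: [1]}
    out_partition = propagate_elementwise_partition(in_partition)
    assert out_partition == in_partition and out_partition is not in_partition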
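The `_{index}` suffix on the strategy name is needed because duplicate sharding specs are no longer skipped: two input strategies can now yield the same sharding sequence, and each resulting strategy must still pass the duplicated-name check. A minimal sketch of the effect, using made-up sharding sequences rather than real ShardingSpec output:

    # Duplicates are kept now; suffixing the loop index keeps every name unique.
    sharding_sequences = ['[S0, R] -> [S0, R]', '[S0, R] -> [S0, R]', '[R, R] -> [R, R]']
    names = [f'{seq}_{index}' for index, seq in enumerate(sharding_sequences)]
    assert len(set(names)) == len(names)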