mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 21:22:49 +00:00
[autoparallel] change the following nodes strategies generation logic (#1636)
* [autoparallel] change the following nodes strategies generation logic * fix unit test
This commit is contained in:
@@ -48,22 +48,23 @@ class UnaryElementwiseHandler(OperatorHandler):
|
||||
# For element-wise function, we keep the sharding spec of output node same as
|
||||
# the input. Therefore, the different strategies of input node with same
|
||||
# output sharding spec will generate same strategy for element-wise function.
|
||||
sharding_spec_checklist = []
|
||||
for strategy in self.input_node.strategies_vector:
|
||||
|
||||
for index, strategy in enumerate(self.input_node.strategies_vector):
|
||||
# It looks a little bit confusing, the input of the processing node
|
||||
# is the output of the input_node.
|
||||
input_sharding_spec = strategy.output_sharding_spec
|
||||
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
||||
if input_sharding_spec in sharding_spec_checklist:
|
||||
continue
|
||||
sharding_spec_checklist.append(input_sharding_spec)
|
||||
|
||||
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
|
||||
try:
|
||||
output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict)
|
||||
except AssertionError as e:
|
||||
warnings.warn(f'{e}')
|
||||
continue
|
||||
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
|
||||
# add index into name to pass the duplicated check
|
||||
# we keep same strategies with different name for node merging, and it will not increase the searching space,
|
||||
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
|
||||
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}_{index}'
|
||||
# TODO: use meta_info_prop to profile memory cost and compute cost
|
||||
compute_cost = self.output_data.numel()
|
||||
memory_cost = 0
|
||||
|
Reference in New Issue
Block a user