[autoparallel] update binary elementwise handler (#2451)

* [autoparallel] update binary elementwise handler

* polish
YuliangLiu0306 authored 2023-01-12 09:35:10 +08:00, committed by GitHub
parent c9ec5190a0, commit 8221fd7485
3 changed files with 74 additions and 23 deletions

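For context: when torch.fx traces a binary elementwise op whose second operand is a Python scalar, that operand appears in node.args as a plain int or float rather than a Node, so it carries no meta tensor. A minimal standalone sketch (illustrative, not part of this patch) of the situation the handler must cope with:

    import torch
    from torch.fx import symbolic_trace

    def fn(x):
        return torch.pow(x, 2)    # the exponent is a plain Python int

    gm = symbolic_trace(fn)
    for node in gm.graph.nodes:
        if node.op == 'call_function':
            # args[1] is the int 2: neither a Node nor a torch.Tensor
            print(node.target, node.args)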

@@ -32,20 +32,32 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
             return OperationDataType.ARG
 
         def _get_arg_value(idx):
+            non_tensor = False
             if isinstance(self.node.args[idx], Node):
                 meta_data = self.node.args[idx]._meta_data
+                # the meta_data of a Node-type argument may itself be a non-tensor object
+                if not isinstance(meta_data, torch.Tensor):
+                    assert isinstance(meta_data, (int, float))
+                    meta_data = torch.Tensor([meta_data]).to('meta')
+                    non_tensor = True
             else:
                 # this is in fact real data, like the int 1,
                 # but we can treat it as meta data
                 # as it won't affect the strategy generation
                 assert isinstance(self.node.args[idx], (int, float))
                 meta_data = torch.Tensor([self.node.args[idx]]).to('meta')
-            return meta_data
+                non_tensor = True
 
-        input_meta_data = _get_arg_value(0)
-        other_meta_data = _get_arg_value(1)
+            return meta_data, non_tensor
+
+        input_meta_data, non_tensor_input = _get_arg_value(0)
+        other_meta_data, non_tensor_other = _get_arg_value(1)
         output_meta_data = self.node._meta_data
+        # we need to record op_data holding non-tensor data in this list,
+        # and filter the non-tensor op_data out in post_process
+        self.non_tensor_list = []
+        # assert False
         input_op_data = OperationData(name=str(self.node.args[0]),
                                       type=_get_op_data_type(input_meta_data),
                                       data=input_meta_data,
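Once a scalar operand has been wrapped into a 1-element meta tensor this way, a plain isinstance check can no longer tell it apart from a real tensor operand, which is why the handler now records non-tensor operands explicitly in non_tensor_list (see the last hunk below). A quick standalone illustration:

    import torch

    wrapped = torch.Tensor([2]).to('meta')     # how _get_arg_value wraps the scalar 2
    print(isinstance(wrapped, torch.Tensor))   # True: type alone no longer reveals the scalar
    print(wrapped.device)                      # meta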
@@ -58,6 +70,10 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
                                        type=OperationDataType.OUTPUT,
                                        data=output_meta_data,
                                        logical_shape=bcast_shape)
+        if non_tensor_input:
+            self.non_tensor_list.append(input_op_data)
+        if non_tensor_other:
+            self.non_tensor_list.append(other_op_data)
 
         mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
         return mapping
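The bcast_shape used as the output's logical_shape is the broadcast shape of the two operands (computed elsewhere in the handler); it matches PyTorch's own broadcasting rules, e.g.:

    import torch

    # logical shape of torch.rand(4, 4) + torch.rand(4)
    print(torch.broadcast_shapes((4, 4), (4,)))    # torch.Size([4, 4])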
@@ -73,9 +89,10 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
         op_data_mapping = self.get_operation_data_mapping()
         for op_name, op_data in op_data_mapping.items():
-            if not isinstance(op_data.data, torch.Tensor):
+            if op_data in self.non_tensor_list:
                 # remove the sharding spec if the op_data is not a tensor, e.g. the 2 in torch.pow(tensor, 2)
                 strategy.sharding_specs.pop(op_data)
             else:
                 # convert the logical sharding spec to a physical sharding spec if broadcast,
                 # e.g. torch.rand(4, 4) + torch.rand(4)
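With these pieces in place, post_process drops sharding specs by membership in self.non_tensor_list rather than by type, since after the wrapping above every operand's data field is a torch.Tensor. A condensed, hypothetical sketch of that filtering step (the function name and argument types are stand-ins, not the real Colossal-AI API):

    def filter_non_tensor_specs(strategy, non_tensor_list):
        # drop the sharding spec of operands that merely wrap a scalar,
        # e.g. the exponent 2 in torch.pow(tensor, 2)
        for op_data in list(strategy.sharding_specs):
            if op_data in non_tensor_list:
                strategy.sharding_specs.pop(op_data)
        return strategy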