mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-24 19:17:30 +00:00
[autoparallel] update binary elementwise handler (#2451)
* [autoparallel] update binary elementwise handler * polish
This commit is contained in:
@@ -32,20 +32,32 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
|
||||
return OperationDataType.ARG
|
||||
|
||||
def _get_arg_value(idx):
|
||||
non_tensor = False
|
||||
if isinstance(self.node.args[idx], Node):
|
||||
meta_data = self.node.args[idx]._meta_data
|
||||
# The meta_data of a Node-type argument may also be a non-tensor object.
|
||||
if not isinstance(meta_data, torch.Tensor):
|
||||
assert isinstance(meta_data, (int, float))
|
||||
meta_data = torch.Tensor([meta_data]).to('meta')
|
||||
non_tensor = True
|
||||
|
||||
else:
|
||||
# this is in fact real data, such as the int 1,
|
||||
# but we can deem it as meta data
|
||||
# as it won't affect the strategy generation
|
||||
assert isinstance(self.node.args[idx], (int, float))
|
||||
meta_data = torch.Tensor([self.node.args[idx]]).to('meta')
|
||||
return meta_data
|
||||
non_tensor = True
|
||||
|
||||
input_meta_data = _get_arg_value(0)
|
||||
other_meta_data = _get_arg_value(1)
|
||||
return meta_data, non_tensor
|
||||
|
||||
input_meta_data, non_tensor_input = _get_arg_value(0)
|
||||
other_meta_data, non_tensor_other = _get_arg_value(1)
|
||||
output_meta_data = self.node._meta_data
|
||||
|
||||
# we need to record op_data with non-tensor data in this list,
|
||||
# and filter out the non-tensor op_data in post_process.
|
||||
self.non_tensor_list = []
|
||||
# assert False
|
||||
input_op_data = OperationData(name=str(self.node.args[0]),
|
||||
type=_get_op_data_type(input_meta_data),
|
||||
data=input_meta_data,
|
||||
@@ -58,6 +70,10 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
|
||||
type=OperationDataType.OUTPUT,
|
||||
data=output_meta_data,
|
||||
logical_shape=bcast_shape)
|
||||
if non_tensor_input:
|
||||
self.non_tensor_list.append(input_op_data)
|
||||
if non_tensor_other:
|
||||
self.non_tensor_list.append(other_op_data)
|
||||
|
||||
mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
|
||||
return mapping
|
||||
@@ -73,9 +89,10 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
|
||||
op_data_mapping = self.get_operation_data_mapping()
|
||||
|
||||
for op_name, op_data in op_data_mapping.items():
|
||||
if not isinstance(op_data.data, torch.Tensor):
|
||||
if op_data in self.non_tensor_list:
|
||||
# remove the sharding spec if the op_data is not a tensor, e.g. torch.pow(tensor, 2)
|
||||
strategy.sharding_specs.pop(op_data)
|
||||
|
||||
else:
|
||||
# convert the logical sharding spec to physical sharding spec if broadcast
|
||||
# e.g. torch.rand(4, 4) + torch.rand(4)
|
||||
|
Reference in New Issue
Block a user