[autoparallel] added binary elementwise node handler (#1758)

* [autoparallel] added binary elementwise node handler

* polish code
This commit is contained in:
Frank Lee
2022-10-25 14:32:01 +08:00
committed by GitHub
parent d2fc067231
commit f9a613d660
8 changed files with 395 additions and 8 deletions

View File

@@ -54,6 +54,11 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
physical_shape (torch.Size): the shape of the tensor before broadcasting
"""
# if the two shapes are the same, no broadcast occurs
# we directly return the current sharding spec
if list(logical_shape) == list(physical_shape):
return logical_sharding_spec
# get the number of dimensions
logical_num_dims = len(logical_shape)
physical_num_dims = len(physical_shape)