mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[autoparallel] added binary elementwise node handler (#1758)
* [autoparallel] added binary elementwise node handler * polish code
This commit is contained in:
@@ -54,6 +54,11 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
|
||||
logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
|
||||
physical_shape (torch.Size): the shape of the tensor before broadcasting
|
||||
"""
|
||||
# if the two shapes are the same, no broadcast occurs
|
||||
# we directly return the current sharding spec
|
||||
if list(logical_shape) == list(physical_shape):
|
||||
return logical_sharding_spec
|
||||
|
||||
# get the number of dimensions
|
||||
logical_num_dims = len(logical_shape)
|
||||
physical_num_dims = len(physical_shape)
|
||||
|
Reference in New Issue
Block a user