[autoparallel] added addbmm handler (#1751)
@@ -1,6 +1,8 @@
-import torch
 from enum import Enum, auto
 from typing import List
+
+import torch
+
 from colossalai.tensor.sharding_spec import ShardingSpec
 
 __all__ = ['BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape']
@@ -56,6 +58,9 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec
     logical_num_dims = len(logical_shape)
     physical_num_dims = len(physical_shape)
 
+    assert logical_num_dims >= physical_num_dims, \
+        'The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!'
+
     # track the dim and its broadcasting type
     logical_dim_broadcast_info = {}
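For context on the new assert: in torch.addbmm the 2D bias is broadcast against the (n, p) result of the reduced batched matmul, so the logical (post-broadcast) shape always has at least as many dimensions as the physical operand. Below is a minimal sketch of the NumPy-style broadcast rule this relies on; it is not the library's implementation of get_broadcast_shape, and the name broadcast_shape_sketch and the example shapes are hypothetical.

import torch

def broadcast_shape_sketch(shape1, shape2):
    # Sketch only, not ColossalAI's get_broadcast_shape. Pad the shorter shape
    # with leading 1s, then merge dimension by dimension: two dims are
    # compatible when they are equal or one of them is 1.
    ndim = max(len(shape1), len(shape2))
    s1 = [1] * (ndim - len(shape1)) + list(shape1)
    s2 = [1] * (ndim - len(shape2)) + list(shape2)
    merged = []
    for d1, d2 in zip(s1, s2):
        assert d1 == d2 or d1 == 1 or d2 == 1, \
            f'{shape1} and {shape2} are not broadcastable'
        merged.append(max(d1, d2))
    return merged

# The logical shape never has fewer dims than the physical one, which is
# exactly the condition the new assert enforces before recovering the
# physical sharding spec from the logical (broadcast) one.
bias = torch.rand(8)                                    # physical shape: [8]
logical = broadcast_shape_sketch(list(bias.shape), [4, 8])
print(logical)                                          # [4, 8]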