[autoparallel] added addbmm handler (#1751)

Frank Lee
2022-10-21 18:55:48 +08:00
committed by GitHub
parent 980ed21723
commit 262652c8bc
8 changed files with 353 additions and 35 deletions
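For context on what the new handler must cover: `torch.addbmm` reduces a batched matrix product over the batch dimension and adds a bias-like `input` tensor that may itself be broadcast to the result shape, which is why this commit also touches the broadcast utilities shown below. A minimal sketch of the operator's semantics (shapes chosen purely for illustration):

```python
import torch

# torch.addbmm(input, batch1, batch2) sums the per-batch matmuls and adds
# `input`; `input` only needs to be broadcastable with the (n, p) result,
# which is the broadcasting case the handler has to recover sharding for.
inp = torch.rand(3, 5)           # broadcastable with (n, p) = (3, 5)
batch1 = torch.rand(10, 3, 4)    # (b, n, m)
batch2 = torch.rand(10, 4, 5)    # (b, m, p)

out = torch.addbmm(inp, batch1, batch2)

# equivalent computation: sum of per-batch matmuls, plus the input tensor
reference = inp + torch.bmm(batch1, batch2).sum(dim=0)
assert torch.allclose(out, reference, atol=1e-4)
print(out.shape)                 # torch.Size([3, 5])
```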


@@ -1,6 +1,8 @@
import torch
from enum import Enum, auto
from typing import List
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = ['BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape']
@@ -56,6 +58,9 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
logical_num_dims = len(logical_shape)
physical_num_dims = len(physical_shape)
assert logical_num_dims >= physical_num_dims, \
'The logical shape has fewer dimensions than the physical shape, so this tensor cannot be the result of broadcasting!'
# track the dim and its broadcasting type
logical_dim_broadcast_info = {}
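The dim-tracking this hunk begins can be illustrated standalone. A minimal sketch, assuming numpy-style right-aligned broadcasting rules; `map_logical_to_physical_dims` is a hypothetical helper for illustration, not the library's API:

```python
def map_logical_to_physical_dims(logical_shape, physical_shape):
    """Map each logical dim index to the physical dim it came from,
    or to None if the dim was created or expanded by broadcasting."""
    offset = len(logical_shape) - len(physical_shape)
    mapping = {}
    for logical_dim, size in enumerate(logical_shape):
        physical_dim = logical_dim - offset
        if physical_dim < 0:
            mapping[logical_dim] = None            # dim added by broadcasting
        elif physical_shape[physical_dim] == 1 and size != 1:
            mapping[logical_dim] = None            # size-1 dim expanded
        else:
            mapping[logical_dim] = physical_dim    # dim carried over unchanged
    return mapping

# e.g. broadcasting (4,) to (2, 3, 4): only the last logical dim is physical
assert map_logical_to_physical_dims([2, 3, 4], [4]) == {0: None, 1: None, 2: 0}
```

The point of tracking this per-dim info is that a sharding placed on a logical dim with no physical counterpart cannot be kept when recovering the physical sharding spec.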