Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 02:51:59 +00:00
[autoparallel] added matmul handler (#1763)
* [autoparallel] added matmul handler
* polish code
@@ -44,21 +44,7 @@ def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]:
     return dims[::-1]
 
 
-def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
-                                              physical_shape: torch.Size) -> ShardingSpec:
-    """
-    This function computes the sharding spec for the physical shape of a broadcast tensor.
-
-    Args:
-        logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor
-        logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
-        physical_shape (torch.Size): the shape of the tensor before broadcasting
-    """
-    # if the two shapes are the same, no broadcast occurs
-    # we directly return the current sharding spec
-    if list(logical_shape) == list(physical_shape):
-        return logical_sharding_spec
-
+def get_broadcast_dim_info(logical_shape, physical_shape):
     # get the number of dimensions
     logical_num_dims = len(logical_shape)
     physical_num_dims = len(physical_shape)
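For readers skimming the diff, here is a minimal, self-contained sketch of the idea behind get_broadcast_dim_info: align the logical (post-broadcast) shape and the physical (pre-broadcast) shape from the trailing dimension and classify every logical dimension. The names below (DimKind, broadcast_dim_info) are illustrative stand-ins, not ColossalAI's actual API; the commit uses its own BroadcastType enum.

from enum import Enum
from typing import Dict, Tuple

class DimKind(Enum):
    # illustrative labels only; the commit defines its own BroadcastType enum
    PADDING = "padding"    # dim exists only in the broadcast (logical) shape
    MULTIPLE = "multiple"  # dim was size 1 before broadcasting and got expanded
    EQUAL = "equal"        # dim is untouched by broadcasting

def broadcast_dim_info(logical_shape: Tuple[int, ...],
                       physical_shape: Tuple[int, ...]) -> Dict[int, DimKind]:
    """Classify each logical dim by how broadcasting produced it."""
    shift = len(logical_shape) - len(physical_shape)
    info = {}
    for logical_dim_idx in range(len(logical_shape)):
        physical_dim_idx = logical_dim_idx - shift
        if physical_dim_idx < 0:
            # no counterpart in the physical shape: a padded leading dim
            info[logical_dim_idx] = DimKind.PADDING
        elif physical_shape[physical_dim_idx] == logical_shape[logical_dim_idx]:
            info[logical_dim_idx] = DimKind.EQUAL
        else:
            # the physical dim must be 1 and was expanded
            info[logical_dim_idx] = DimKind.MULTIPLE
    return info

# e.g. a tensor of physical shape (8, 1) broadcast to logical shape (4, 8, 6):
# dim 0 is padding, dim 1 is unchanged, dim 2 was expanded from size 1
assert broadcast_dim_info((4, 8, 6), (8, 1)) == {
    0: DimKind.PADDING, 1: DimKind.EQUAL, 2: DimKind.MULTIPLE
}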
@@ -85,6 +71,31 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
         else:
             logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDDING
 
+    return logical_dim_broadcast_info
+
+
+def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
+                                              physical_shape: torch.Size) -> ShardingSpec:
+    """
+    This function computes the sharding spec for the physical shape of a broadcast tensor.
+
+    Args:
+        logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor
+        logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
+        physical_shape (torch.Size): the shape of the tensor before broadcasting
+    """
+    # if the two shapes are the same, no broadcast occurs
+    # we directly return the current sharding spec
+    if list(logical_shape) == list(physical_shape):
+        return logical_sharding_spec
+
+    # get the number of dimensions
+    logical_num_dims = len(logical_shape)
+    physical_num_dims = len(physical_shape)
+
+    # get the broadcast info
+    logical_dim_broadcast_info = get_broadcast_dim_info(logical_shape, physical_shape)
+
+    # generate the sharding spec for the physical shape
     physical_dim_partition = {}
     logical_dim_partition = logical_sharding_spec.dim_partition_dict