[autoparallel]add essential CommActions for broadcast oprands (#1793)

Author: YuliangLiu0306
Date: 2022-11-04 18:36:42 +08:00
Committed by: GitHub
Parent: 05ce3d369f
Commit: e34e850a4c
9 changed files with 102 additions and 24 deletions


@@ -2,10 +2,21 @@ from enum import Enum, auto
 from typing import List
 
 import torch
 from torch.fx.node import Node
 
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommAction,
+    CommType,
+    OperationData,
+    OperationDataType,
+)
+from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
 
-__all__ = ['BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape']
+__all__ = [
+    'BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape',
+    'comm_actions_for_oprands'
+]
 
 
 class BroadcastType(Enum):
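
The new imports pull in the communication machinery used below; in particular, CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD names a pattern that is a no-op in the forward pass and all-reduces the gradient in the backward pass. A minimal toy sketch of that behavior (illustration only, not the ColossalAI implementation):

import torch
import torch.distributed as dist


class IdentityFwdAllReduceBwd(torch.autograd.Function):
    """Toy illustration: identity in forward, gradient all-reduce in backward."""

    @staticmethod
    def forward(ctx, tensor, process_group):
        ctx.process_group = process_group
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # each rank holds only a partial gradient for a broadcast operand,
        # so the pieces are summed across the relevant process group
        dist.all_reduce(grad_output, group=ctx.process_group)
        # one gradient per forward input; the process_group input gets None
        return grad_output, None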
@@ -86,8 +97,11 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
     """
     # if the two shapes are the same, no broadcast occurs
    # we directly return the current sharding spec
+    # record the sharding dimensions removed while converting the logical shape to the physical one
+    removed_dims = []
     if list(logical_shape) == list(physical_shape):
-        return logical_sharding_spec
+        return logical_sharding_spec, removed_dims
 
     # get the number of dimensions
     logical_num_dims = len(logical_shape)
@@ -104,7 +118,7 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
         logical_broadcast_type = logical_dim_broadcast_info[shape_dim]
 
         if logical_broadcast_type == BroadcastType.PADDDING or logical_broadcast_type == BroadcastType.MULTIPLE:
-            pass
+            removed_dims.extend(mesh_dim)
         else:
             # get the corresponding physical dim
             physical_dim = physical_num_dims - (logical_num_dims - shape_dim)
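
With this change, mesh axes that shard a dimension existing only in the broadcast (logical) shape are recorded in removed_dims rather than silently dropped. A self-contained toy sketch of that bookkeeping, using plain dicts in place of ShardingSpec (the helper name and shapes here are hypothetical):

def recover_dim_partition(logical_shape, physical_shape, dim_partition):
    # dim_partition maps a logical tensor dim -> list of device-mesh axes
    removed_dims = []
    physical_partition = {}
    logical_ndims = len(logical_shape)
    physical_ndims = len(physical_shape)
    for logical_dim, mesh_axes in dim_partition.items():
        physical_dim = physical_ndims - (logical_ndims - logical_dim)
        if physical_dim < 0 or physical_shape[physical_dim] != logical_shape[logical_dim]:
            # the dimension was created (padding) or expanded (multiple) by
            # the broadcast: its shard cannot live on the physical tensor
            removed_dims.extend(mesh_axes)
        else:
            physical_partition[physical_dim] = mesh_axes
    return physical_partition, removed_dims


# logical (4, 8) broadcast from physical (8,); dim 0 sharded on mesh axis 0
print(recover_dim_partition((4, 8), (8,), {0: [0], 1: [1]}))
# -> ({0: [1]}, [0])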
@@ -114,4 +128,33 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
                                           entire_shape=physical_shape,
                                           dim_partition_dict=physical_dim_partition)
 
-    return physical_sharding_spec
+    return physical_sharding_spec, removed_dims
+
+
+def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: OperationData,
+                             sharding_spec: ShardingSpec) -> CommAction:
+    """
+    This method generates communication actions for operands which lose sharding information
+    during the conversion from logical shape to physical shape.
+    """
+    if len(removed_dims) == 1:
+        # if the list has a single element, extract it to avoid using a flattened device mesh
+        removed_dims = removed_dims[0]
+    comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                         sharding_spec=sharding_spec,
+                         logical_process_axis=removed_dims)
+    if op_data.type == OperationDataType.PARAM:
+        comm_type = CommType.HOOK
+    else:
+        comm_type = CommType.BEFORE
+    arg_index = -1
+    for index, arg in enumerate(node.args):
+        if op_data.name == str(arg):
+            arg_index = index
+    assert arg_index >= 0, 'op_data should be an argument of node.'
+    comm_action = CommAction(
+        comm_spec=comm_spec,
+        comm_type=comm_type,
+        arg_index=arg_index,
+    )
+    return comm_action
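
Taken together, a broadcast-aware operator handler would first recover the physical sharding spec (collecting removed_dims) and then, whenever sharding information was lost, attach the all-reduce communication action to the chosen strategy. A hedged sketch of that wiring; the helper name, its parameters, the exact import path, and the strategy bookkeeping are assumptions for illustration, not code from this commit:

from colossalai.auto_parallel.tensor_shard.utils.broadcast import (
    comm_actions_for_oprands,
    recover_sharding_spec_for_broadcast_shape,
)


def attach_broadcast_comm_action(node, op_data, logical_sharding_spec,
                                 logical_shape, physical_shape, strategy):
    # hypothetical helper: shrink the logical (broadcast) sharding spec back
    # to the operand's real shape, collecting the mesh axes that were dropped
    physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
        logical_sharding_spec, logical_shape, physical_shape)
    strategy.sharding_specs[op_data] = physical_sharding_spec

    if len(removed_dims) > 0:
        # the operand lost sharding information, so its gradient needs an
        # all-reduce in the backward pass (IDENTITY_FWD_ALLREDUCE_BWD)
        comm_action = comm_actions_for_oprands(node=node,
                                               removed_dims=removed_dims,
                                               op_data=op_data,
                                               sharding_spec=logical_sharding_spec)
        strategy.communication_actions[op_data] = comm_action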