mirror of https://github.com/hpcaitech/ColossalAI.git
[autoparallel]add essential CommActions for broadcast oprands (#1793)
@@ -2,10 +2,21 @@ from enum import Enum, auto
 from typing import List
 
+import torch
 from torch.fx.node import Node
 
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommAction,
+    CommType,
+    OperationData,
+    OperationDataType,
+)
+from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
 
-__all__ = ['BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape']
+__all__ = [
+    'BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape',
+    'comm_actions_for_oprands'
+]
 
 
 class BroadcastType(Enum):
@@ -86,8 +97,11 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec
     """
     # if the two shapes are the same, no broadcast occurs
     # we directly return the current sharding spec
+
+    # record the sharding dimensions removed while converting the logical shape to the physical one
+    removed_dims = []
     if list(logical_shape) == list(physical_shape):
-        return logical_sharding_spec
+        return logical_sharding_spec, removed_dims
 
     # get the number of dimensions
     logical_num_dims = len(logical_shape)
@@ -104,7 +118,7 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec
         logical_broadcast_type = logical_dim_broadcast_info[shape_dim]
 
         if logical_broadcast_type == BroadcastType.PADDDING or logical_broadcast_type == BroadcastType.MULTIPLE:
-            pass
+            removed_dims.extend(mesh_dim)
         else:
             # get the corresponding physical dim
             physical_dim = physical_num_dims - (logical_num_dims - shape_dim)
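Note: as a reading aid for the two hunks above, here is a minimal, self-contained sketch of the bookkeeping that the updated recover_sharding_spec_for_broadcast_shape performs. It uses plain dictionaries instead of the library's ShardingSpec, and the shapes, sharding dims, and mesh axes below are made-up illustration values, not taken from the commit.

# Toy illustration only: plain dicts stand in for ShardingSpec objects.
# Logical shape (4, 8, 8) is the broadcast of physical shape (8, 8), so
# logical dim 0 is a padded dim and any sharding placed on it must be dropped.
logical_shape = (4, 8, 8)
physical_shape = (8, 8)

# tensor dim -> device-mesh axes that the dim is sharded over
logical_dim_partition = {0: [0], 2: [1]}

offset = len(logical_shape) - len(physical_shape)
physical_dim_partition = {}
removed_dims = []    # mesh axes whose sharding is lost by the broadcast
for dim, mesh_axes in logical_dim_partition.items():
    if dim < offset or physical_shape[dim - offset] != logical_shape[dim]:
        # padded dim, or a size-1 dim expanded by broadcasting: the sharding
        # cannot be mapped back to the physical tensor, so record its mesh axes
        removed_dims.extend(mesh_axes)
    else:
        physical_dim_partition[dim - offset] = mesh_axes

print(physical_dim_partition)    # {1: [1]}
print(removed_dims)              # [0]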
@@ -114,4 +128,33 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec
                                           entire_shape=physical_shape,
                                           dim_partition_dict=physical_dim_partition)
 
-    return physical_sharding_spec
+    return physical_sharding_spec, removed_dims
+
+
+def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: OperationData,
+                             sharding_spec: ShardingSpec) -> CommAction:
+    """
+    Generate communication actions for operands that lose sharding information
+    when their logical shape is converted to the physical shape.
+    """
+    if len(removed_dims) == 1:
+        # if the list has a single element, extract it to avoid using a flattened device mesh
+        removed_dims = removed_dims[0]
+    comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                         sharding_spec=sharding_spec,
+                         logical_process_axis=removed_dims)
+    if op_data.type == OperationDataType.PARAM:
+        comm_type = CommType.HOOK
+    else:
+        comm_type = CommType.BEFORE
+    arg_index = -1
+    for index, arg in enumerate(node.args):
+        if op_data.name == str(arg):
+            arg_index = index
+    assert arg_index >= 0, f'op_data should be an argument of node.'
+    comm_action = CommAction(
+        comm_spec=comm_spec,
+        comm_type=comm_type,
+        arg_index=arg_index,
+    )
+    return comm_action