diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
index 798e677eb..5b600e735 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
@@ -3,10 +3,17 @@ from typing import Dict, List, Union
 import torch
 from torch.fx.node import Node

-from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommAction,
+    CommType,
+    OperationData,
+    OperationDataType,
+    ShardingStrategy,
+)
+from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager

 from ..constants import BCAST_FUNC_OP
-from ..utils import recover_sharding_spec_for_broadcast_shape
+from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
 from .node_handler import NodeHandler
 from .registry import operator_registry
 from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator
@@ -81,6 +88,15 @@ class BinaryElementwiseHandler(NodeHandler):
             physical_shape = op_data.data.shape
             logical_shape = op_data.logical_shape
             sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
-            sharding_spec = recover_sharding_spec_for_broadcast_shape(sharding_spec, logical_shape, physical_shape)
+            sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
+                sharding_spec, logical_shape, physical_shape)
+            strategy.sharding_specs[op_data] = sharding_spec
+            if len(removed_dims) > 0:
+                comm_action = comm_actions_for_oprands(node=self.node,
+                                                       removed_dims=removed_dims,
+                                                       op_data=op_data,
+                                                       sharding_spec=sharding_spec)
+                strategy.communication_actions[op_data] = comm_action
+
         return strategy
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py
index 09016d507..9e1d958e1 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py
@@ -2,8 +2,10 @@ from typing import Dict, List, Union

 import torch

-from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
-from ..utils import recover_sharding_spec_for_broadcast_shape
+from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
+
+from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
+from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
 from .node_handler import NodeHandler
 from .registry import operator_registry
 from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator
@@ -91,7 +93,15 @@ class AddBMMFunctionHandler(NodeHandler):
         bias_physical_shape = bias_op_data.data.shape
         bias_logical_shape = bias_op_data.logical_shape
         bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
-        bias_sharding_spec = recover_sharding_spec_for_broadcast_shape(bias_sharding_spec, bias_logical_shape,
-                                                                       bias_physical_shape)
+        bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
+            bias_sharding_spec, bias_logical_shape, bias_physical_shape)
         strategy.sharding_specs[bias_op_data] = bias_sharding_spec
+
+        if len(removed_dims) > 0:
+            comm_action = comm_actions_for_oprands(node=self.node,
+                                                   removed_dims=removed_dims,
+                                                   op_data=bias_op_data,
+                                                   sharding_spec=bias_sharding_spec)
+            strategy.communication_actions[bias_op_data] = comm_action
+
         return strategy
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
index 400c69693..5bc899049 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
@@ -213,7 +213,7 @@ class Broadcaster(BmmTransform):

             tensor_shape_before_broadcast = [dim for dim in tensor_shape if dim is not None]

-            physical_sharding_spec = recover_sharding_spec_for_broadcast_shape(
+            physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
                 logical_sharding_spec=sharding_spec,
                 logical_shape=sharding_spec.entire_shape,
                 physical_shape=tensor_shape_before_broadcast)
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py
index ebcd6c453..daf81f995 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py
@@ -4,7 +4,7 @@ from typing import Dict, List

 import torch

-from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy, StrategiesVector)
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
 from ..utils import recover_sharding_spec_for_broadcast_shape
 from .node_handler import NodeHandler
 from .registry import operator_registry
@@ -81,8 +81,8 @@ class WhereHandler(NodeHandler):
             logical_sharding_spec = strategy.sharding_specs[logical_op_data_mapping[key]]
             logical_shape = logical_op_data_mapping[key].logical_shape
             physical_shape = physical_op_data_mapping[key].logical_shape
-            physical_sharding_spec = recover_sharding_spec_for_broadcast_shape(logical_sharding_spec, logical_shape,
-                                                                               physical_shape)
+            physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
+                logical_sharding_spec, logical_shape, physical_shape)
             strategy.sharding_specs.pop(logical_op_data_mapping[key])
             strategy.sharding_specs[physical_op_data_mapping[key]] = physical_sharding_spec
             strategy.name = f"{strategy.sharding_specs[physical_op_data_mapping['output']].sharding_sequence} = {strategy.sharding_specs[physical_op_data_mapping['condition']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['x']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['y']].sharding_sequence}"
diff --git a/colossalai/auto_parallel/tensor_shard/utils/__init__.py b/colossalai/auto_parallel/tensor_shard/utils/__init__.py
index 380464bcd..043147b9f 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/__init__.py
@@ -1,4 +1,10 @@
-from .broadcast import BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape
+from .broadcast import (
+    BroadcastType,
+    comm_actions_for_oprands,
+    get_broadcast_shape,
+    is_broadcastable,
+    recover_sharding_spec_for_broadcast_shape,
+)
 from .factory import generate_resharding_costs, generate_sharding_spec
 from .misc import check_sharding_spec_validity, ignore_sharding_exception
 from .sharding import (
@@ -13,5 +19,5 @@ __all__ = [
     'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
     'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
     'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
-    'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
+    'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands'
 ]
diff --git a/colossalai/auto_parallel/tensor_shard/utils/broadcast.py b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py
index 3a3753b00..28aa55132 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/broadcast.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py
@@ -2,10 +2,21 @@ from enum import Enum, auto
 from typing import List

 import torch
+from torch.fx.node import Node

+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommAction,
+    CommType,
+    OperationData,
+    OperationDataType,
+)
+from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec

-__all__ = ['BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape']
+__all__ = [
+    'BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape',
+    'comm_actions_for_oprands'
+]


 class BroadcastType(Enum):
@@ -86,8 +97,11 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
     """
     # if the two shapes are the same, no broadcast occurs
     # we directly return the current sharding spec
+
+    # record the sharding dimensions that are removed when converting the logical shape to the physical one
+    removed_dims = []
     if list(logical_shape) == list(physical_shape):
-        return logical_sharding_spec
+        return logical_sharding_spec, removed_dims

     # get the number of dimensions
     logical_num_dims = len(logical_shape)
@@ -104,7 +118,7 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
         logical_broadcast_type = logical_dim_broadcast_info[shape_dim]

         if logical_broadcast_type == BroadcastType.PADDDING or logical_broadcast_type == BroadcastType.MULTIPLE:
-            pass
+            removed_dims.extend(mesh_dim)
         else:
             # get the corresponding physical dim
             physical_dim = physical_num_dims - (logical_num_dims - shape_dim)
@@ -114,4 +128,33 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
                                           entire_shape=physical_shape,
                                           dim_partition_dict=physical_dim_partition)

-    return physical_sharding_spec
+    return physical_sharding_spec, removed_dims
+
+
+def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: OperationData,
+                             sharding_spec: ShardingSpec) -> CommAction:
+    """
+    This method is used to generate communication actions for operands which lose information
+    during the conversion from logical shape to physical shape.
+    """
+    if len(removed_dims) == 1:
+        # if the list length is 1, extract the element from the list to avoid using a flattened device mesh
+        removed_dims = removed_dims[0]
+    comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                         sharding_spec=sharding_spec,
+                         logical_process_axis=removed_dims)
+    if op_data.type == OperationDataType.PARAM:
+        comm_type = CommType.HOOK
+    else:
+        comm_type = CommType.BEFORE
+    arg_index = -1
+    for index, arg in enumerate(node.args):
+        if op_data.name == str(arg):
+            arg_index = index
+    assert arg_index >= 0, f'op_data should be an argument of node.'
+    comm_action = CommAction(
+        comm_spec=comm_spec,
+        comm_type=comm_type,
+        arg_index=arg_index,
+    )
+    return comm_action
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py
index fb8f46b5e..21695f6b5 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py
@@ -39,8 +39,8 @@ class BiasAdditionConv(BiasAdditionModule):
         This method is used to reshape the bias node in order to make bias and
         output of non-bias convolution broadcastable.
         """
-        bias_shape = [1] * dimensions
-        bias_shape[1] = -1
+        bias_shape = [1] * (dimensions - 1)
+        bias_shape[0] = -1
         bias_reshape_node_kind = 'call_method'
         bias_reshape_node_target = 'view'
         bias_reshape_node_args = (self.bias_proxy, bias_shape)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py b/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py
index 4c35e7de5..560758749 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py
@@ -1,7 +1,10 @@
 import torch

-from colossalai.auto_parallel.tensor_shard.utils import (get_broadcast_shape, is_broadcastable,
-                                                         recover_sharding_spec_for_broadcast_shape)
+from colossalai.auto_parallel.tensor_shard.utils import (
+    get_broadcast_shape,
+    is_broadcastable,
+    recover_sharding_spec_for_broadcast_shape,
+)
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.sharding_spec import ShardingSpec
@@ -51,8 +54,8 @@ def test_recover_sharding_spec_for_broadcast_shape():
                                                     1: [1]
                                                 },
                                                 entire_shape=broadcast_shape)
-    physical_sharding_spec_for_x1 = recover_sharding_spec_for_broadcast_shape(logical_sharding_spec_for_x1,
-                                                                              broadcast_shape, x1.shape)
+    physical_sharding_spec_for_x1, removed_dims = recover_sharding_spec_for_broadcast_shape(
+        logical_sharding_spec_for_x1, broadcast_shape, x1.shape)
     print(physical_sharding_spec_for_x1)

     assert physical_sharding_spec_for_x1.entire_shape == x1.shape
diff --git a/tests/test_fx/test_tracer/test_bias_addition_module.py b/tests/test_fx/test_tracer/test_bias_addition_module.py
index fbb7d1f3f..afa30a217 100644
--- a/tests/test_fx/test_tracer/test_bias_addition_module.py
+++ b/tests/test_fx/test_tracer/test_bias_addition_module.py
@@ -105,7 +105,7 @@ def test_conv_module():
     assert weight_node._meta_data.shape == (6, 3, 2, 2)
     assert bias_node._meta_data.shape == (6,)
     assert conv_node._meta_data.shape == (4, 6, 63, 63)
-    assert view_node._meta_data.shape == (1, 6, 1, 1)
+    assert view_node._meta_data.shape == (6, 1, 1)
     assert add_node._meta_data.shape == (4, 6, 63, 63)