From b0f7c8bde8d64214cd005d993ea54c9ad6e38630 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Fri, 28 Oct 2022 09:57:43 +0800 Subject: [PATCH] [autoparallel] update CommSpec to CommActions (#1768) * [autoparallel] update CommSpec to CommActions * polish code --- .../node_handler/linear_handler.py | 9 +- .../strategy/batch_norm_generator.py | 28 +- .../strategy/getitem_generator.py | 15 +- .../strategy/layer_norm_generator.py | 27 +- .../strategy/matmul_strategy_generator.py | 304 ++++++++++++------ colossalai/tensor/comm_spec.py | 4 +- .../test_node_handler/test_linear_handler.py | 2 + 7 files changed, 267 insertions(+), 122 deletions(-) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py index 62210ebe9..d1ea84b39 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py @@ -202,16 +202,17 @@ class LinearFunctionHandler(NodeHandler): mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} - if self.node.args[2] is not None: + if 'bias' in self.node.kwargs and self.node.kwargs['bias'] is not None: # check if the other operand is a parameter - if isinstance(self.node.args[2]._meta_data, torch.nn.parameter.Parameter): + if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter): data_type = OperationDataType.PARAM else: data_type = OperationDataType.ARG - physical_bias_operand = OperationData(name=str(self.node.args[2]), + physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]), type=data_type, - data=self.node.args[2]._meta_data) + data=self.node.kwargs["bias"]._meta_data) mapping['bias'] = physical_bias_operand + return mapping def post_process(self, strategy: ShardingStrategy): diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py index e648fff39..b3769ccd6 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py @@ -3,7 +3,12 @@ import operator from functools import reduce from typing import List -from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) from colossalai.tensor.shape_consistency import CollectiveCommPattern from .strategy_generator import StrategyGenerator @@ -204,12 +209,13 @@ class BatchNormStrategyGenerator(StrategyGenerator): # For SyncBN case, we don't need to do communication for weight and bias. # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation # to SyncBN operation instead of inserting a communication node. 
- output_comm_spec = self.get_communication_spec( + output_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, - logical_process_axis=mesh_dim_0) + logical_process_axis=mesh_dim_0, + comm_type=CommType.AFTER) - communication_action_mapping = {"output": output_comm_spec} + communication_action_mapping = {"output": output_comm_action} return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -238,12 +244,13 @@ class BatchNormStrategyGenerator(StrategyGenerator): # For SyncBN case, we don't need to do communication for gradients of weight and bias. # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation # to SyncBN operation instead of inserting a communication node. - output_comm_spec = self.get_communication_spec( + output_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, - logical_process_axis=[mesh_dim_0, mesh_dim_1]) + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.AFTER) - communication_action_mapping = {"output": output_comm_spec} + communication_action_mapping = {"output": output_comm_action} return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -282,12 +289,13 @@ class BatchNormStrategyGenerator(StrategyGenerator): # For SyncBN case, we don't need to do communication for gradients of weight and bias. # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation # to SyncBN operation instead of inserting a communication node. - output_comm_spec = self.get_communication_spec( + output_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, - logical_process_axis=[mesh_dim_0]) + logical_process_axis=[mesh_dim_0], + comm_type=CommType.AFTER) - communication_action_mapping = {"output": output_comm_spec} + communication_action_mapping = {"output": output_comm_action} return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py index 8b8080b75..532df083a 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py @@ -1,7 +1,12 @@ import copy from typing import List -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) from colossalai.tensor.shape_consistency import CollectiveCommPattern from .strategy_generator import FollowingStrategyGenerator @@ -83,11 +88,13 @@ class TensorStrategyGenerator(GetItemStrategyGenerator): } sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) if gather_input: - input_communication_spec = self.get_communication_spec( + input_communication_action = self.get_communication_action( sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - 
logical_process_axis=logical_process_axis) - communication_action_mapping["input"] = input_communication_spec + logical_process_axis=logical_process_axis, + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping["input"] = input_communication_action name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}' diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py index 8c7d11437..38aa41fe4 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py @@ -3,9 +3,16 @@ import operator from functools import reduce from typing import List -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) -from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding, - enumerate_all_possible_2d_sharding) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ( + enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, +) from colossalai.tensor.shape_consistency import CollectiveCommPattern from .strategy_generator import StrategyGenerator @@ -107,18 +114,20 @@ class LayerNormGenerator(StrategyGenerator): total_mesh_dim_list = total_mesh_dim_list[0] communication_action_mapping = {} - other_comm_spec = self.get_communication_spec( + other_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=total_mesh_dim_list) - communication_action_mapping["other"] = other_comm_spec + logical_process_axis=total_mesh_dim_list, + comm_type=CommType.HOOK) + communication_action_mapping["other"] = other_comm_action if self.has_bias: - bias_comm_spec = self.get_communication_spec( + bias_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=total_mesh_dim_list) - communication_action_mapping["bias"] = bias_comm_spec + logical_process_axis=total_mesh_dim_list, + comm_type=CommType.HOOK) + communication_action_mapping["bias"] = bias_comm_action strategy = self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py index be2a95098..11b883873 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py @@ -1,8 +1,14 @@ import operator +from ast import arg from functools import reduce from typing import List -from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception from 
colossalai.tensor.shape_consistency import CollectiveCommPattern @@ -77,11 +83,12 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) # get communication action - output_comm_spec = self.get_communication_spec( + output_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['output'], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, - logical_process_axis=mesh_dim) - communication_action_mapping = {"output": output_comm_spec} + logical_process_axis=mesh_dim, + comm_type=CommType.AFTER) + communication_action_mapping = {"output": output_comm_action} return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) @@ -124,15 +131,35 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) # get communication action - other_comm_spec = self.get_communication_spec( - sharding_spec=sharding_spec_mapping['other'], - communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim) - bias_comm_spec = self.get_communication_spec( - sharding_spec=sharding_spec_mapping['bias'], - communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim) - communication_action_mapping = {'other': other_comm_spec, 'bias': bias_comm_spec} + if self.is_param('other'): + other_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['other'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim, + comm_type=CommType.HOOK) + else: + other_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['other'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim, + comm_type=CommType.BEFORE, + arg_index=1) + if self.has_bias: + if self.is_param('bias'): + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim, + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim, + comm_type=CommType.BEFORE, + arg_index=2) + communication_action_mapping = {'other': other_comm_action, 'bias': bias_comm_action} + return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) @@ -227,24 +254,45 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): # set communication action communication_action_mapping = {} - input_comm_spec = self.get_communication_spec( + input_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim_1) - other_comm_spec = self.get_communication_spec( - sharding_spec_mapping["output"], - communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim_0) + logical_process_axis=mesh_dim_1, + comm_type=CommType.BEFORE, + arg_index=0) - communication_action_mapping['input'] = input_comm_spec - 
communication_action_mapping['other'] = other_comm_spec + if self.is_param('other'): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["output"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["output"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping['input'] = input_comm_action + communication_action_mapping['other'] = other_comm_action if self.has_bias: - bias_comm_spec = self.get_communication_spec( - sharding_spec_mapping["bias"], - communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim_0) - communication_action_mapping['bias'] = bias_comm_spec + if self.is_param('bias'): + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + key_for_kwarg='bias') + communication_action_mapping['bias'] = bias_comm_action return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -273,24 +321,45 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): # get communication action mapping communication_action_mapping = {} - input_comm_spec = self.get_communication_spec( - sharding_spec=sharding_spec_mapping["input"], - communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim_0) - output_comm_spec = self.get_communication_spec( + + output_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, - logical_process_axis=mesh_dim_1) + logical_process_axis=mesh_dim_1, + comm_type=CommType.AFTER) - communication_action_mapping['input'] = input_comm_spec - communication_action_mapping['output'] = output_comm_spec + if self.is_param('other'): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["output"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["output"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping['other'] = other_comm_action + communication_action_mapping['output'] = output_comm_action if self.has_bias: - bias_comm_spec = self.get_communication_spec( - sharding_spec=sharding_spec_mapping["bias"], - communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, - logical_process_axis=mesh_dim_1) - communication_action_mapping['bias'] = bias_comm_spec + if self.is_param('bias'): + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + 
else: + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + key_for_kwarg='bias') + communication_action_mapping['bias'] = bias_comm_action return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -320,16 +389,19 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): # get communication actions communication_action_mapping = {} - output_comm_spec = self.get_communication_spec( + output_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['output'], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, - logical_process_axis=mesh_dim_0) - input_comm_spec = self.get_communication_spec( + logical_process_axis=mesh_dim_0, + comm_type=CommType.AFTER) + input_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['input'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim_1) - communication_action_mapping["input"] = input_comm_spec - communication_action_mapping['output'] = output_comm_spec + logical_process_axis=mesh_dim_1, + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping["input"] = input_comm_action + communication_action_mapping['output'] = output_comm_action return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) @@ -354,12 +426,13 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): # get communication action communication_action_mapping = {} - output_comm_spec = self.get_communication_spec( + output_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['output'], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, - logical_process_axis=mesh_dim) + logical_process_axis=mesh_dim, + comm_type=CommType.AFTER) - communication_action_mapping['output'] = output_comm_spec + communication_action_mapping['output'] = output_comm_action return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) @@ -386,12 +459,14 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): # get communication actions communication_action_mapping = {} - input_comm_spec = self.get_communication_spec( + input_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['input'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim) + logical_process_axis=mesh_dim, + comm_type=CommType.BEFORE, + arg_index=0) - communication_action_mapping['input'] = input_comm_spec + communication_action_mapping['input'] = input_comm_action return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) @@ -414,18 +489,36 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): # get communication action communication_action_mapping = {} - other_comm_spec = self.get_communication_spec( - sharding_spec=sharding_spec_mapping['other'], - communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=[mesh_dim_0, mesh_dim_1]) - communication_action_mapping['other'] = other_comm_spec + if self.is_param('other'): + other_comm_action 
= self.get_communication_action( + sharding_spec=sharding_spec_mapping['other'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.HOOK) + else: + other_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['other'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=1) + communication_action_mapping['other'] = other_comm_action if self.has_bias: - bias_comm_spec = self.get_communication_spec( - sharding_spec=sharding_spec_mapping['bias'], - communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=[mesh_dim_0, mesh_dim_1]) - communication_action_mapping['bias'] = bias_comm_spec + if self.is_param('bias'): + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + key_for_kwarg='bias') + communication_action_mapping['bias'] = bias_comm_action return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) @@ -449,11 +542,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): # get communication action communication_action_mapping = {} - output_comm_spec = self.get_communication_spec( + output_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['output'], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, - logical_process_axis=[mesh_dim_0, mesh_dim_1]) - communication_action_mapping['output'] = output_comm_spec + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.AFTER) + communication_action_mapping['output'] = output_comm_action return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -480,11 +574,13 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): # get communication action communication_action_mapping = {} - input_comm_spec = self.get_communication_spec( + input_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['input'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=[mesh_dim_0, mesh_dim_1]) - communication_action_mapping['input'] = input_comm_spec + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping['input'] = input_comm_action return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -516,8 +612,13 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): [b, i, k] x [b, k, j] -> [b, i, j] The bias term is considered to have a 2D logical shape. + + Note: This class will be used to generate strategies for torch.bmm + and torch.addbmm. However, the result of torch.addbmm is not correct, + some extra runtime apply actions are required to keep numerical correctness. """ + # TODO: torch.addbmm correctness issue need to be fixed. 
def __init__(self, *args, **kwargs): self.squeeze_batch_dim = False super().__init__(*args, **kwargs) @@ -566,16 +667,16 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): self._pop_batch_dim_sharding_for_output(dim_partition_dict) sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) - print(sharding_spec_mapping) - # get communication actions communication_action_mapping = {} if self.has_bias: - bias_comm_spec = self.get_communication_spec( + bias_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['bias'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim) - communication_action_mapping['bias'] = bias_comm_spec + logical_process_axis=mesh_dim, + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping['bias'] = bias_comm_action return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) @@ -602,11 +703,13 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): # get communication actions communication_action_mapping = {} if self.has_bias: - bias_comm_spec = self.get_communication_spec( + bias_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['bias'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=[mesh_dim_0, mesh_dim_1]) - communication_action_mapping['bias'] = bias_comm_spec + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping['bias'] = bias_comm_action return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -637,18 +740,24 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): # get communication actions communication_action_mapping = {} - other_comm_spec = self.get_communication_spec( + other_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['other'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim_1) - communication_action_mapping['other'] = other_comm_spec + logical_process_axis=mesh_dim_1, + comm_type=CommType.BEFORE, + arg_index=1) + communication_action_mapping['other'] = other_comm_action if self.has_bias: - bias_comm_spec = self.get_communication_spec( + bias_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['bias'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=[mesh_dim_0, mesh_dim_1]) - communication_action_mapping['bias'] = bias_comm_spec + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping['bias'] = bias_comm_action + # for addbmm case, other is the third argument instead of second. 
+ communication_action_mapping['other'].arg_index += 1 return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -679,18 +788,23 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): # get communication actions communication_action_mapping = {} - input_comm_spec = self.get_communication_spec( + input_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['input'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim_1) - communication_action_mapping['input'] = input_comm_spec + logical_process_axis=mesh_dim_1, + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping['input'] = input_comm_action if self.has_bias: - bias_comm_spec = self.get_communication_spec( + bias_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['bias'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim_0) - communication_action_mapping['bias'] = bias_comm_spec + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE) + communication_action_mapping['bias'] = bias_comm_action + # for addbmm case, other is the second argument instead of first. + communication_action_mapping['input'].arg_index += 1 return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -719,18 +833,21 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): # get communication actions communication_action_mapping = {} - output_comm_spec = self.get_communication_spec( + output_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['output'], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, - logical_process_axis=mesh_dim_1) - communication_action_mapping['output'] = output_comm_spec + logical_process_axis=mesh_dim_1, + comm_type=CommType.AFTER) + communication_action_mapping['output'] = output_comm_action if self.has_bias: - bias_comm_spec = self.get_communication_spec( + bias_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['bias'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim_0) - communication_action_mapping['bias'] = bias_comm_spec + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping['bias'] = bias_comm_action return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -771,6 +888,5 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): # split two batch dim strategy_list.append(self.split_two_batch_dim(0, 1)) - strategy_list.append(self.split_two_batch_dim(1, 0)) return strategy_list diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py index 617057a4f..a0775d0bc 100644 --- a/colossalai/tensor/comm_spec.py +++ b/colossalai/tensor/comm_spec.py @@ -41,7 +41,7 @@ def _split(tensor, comm_spec): dim = comm_spec.shard_dim length = tensor.shape[comm_spec.shard_dim] // len(rank_list) start = length * rank_list.index(dist.get_rank()) - output = torch.narrow(tensor, dim, start, length) + output = torch.narrow(tensor, dim, start, length).contiguous() return output @@ -76,6 +76,8 @@ def _all_reduce(tensor, comm_spec): process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] for rank_list, process_group in process_groups_list: if dist.get_rank() in rank_list: + if not tensor.is_contiguous(): + 
tensor = tensor.contiguous() dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group) return tensor diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py index 290d73f5a..52284f8e5 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py @@ -11,6 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( ) from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.utils import parameterize @@ -109,6 +110,7 @@ def test_linear_module_handler(bias): assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] +@run_on_environment_flag(name='AUTO_PARALLEL') @parameterize('bias', [True, False]) def test_linear_function_handler(bias): model = nn.Linear(16, 32, bias=bias).to('meta')
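
For context, the core of this patch replaces bare CommSpec entries in communication_action_mapping with CommAction entries produced by get_communication_action(), which additionally record a comm_type (BEFORE / AFTER / HOOK) and, where needed, an arg_index or key_for_kwarg telling the runtime where to apply the communication. Below is a minimal sketch of that shape, assuming a hypothetical CommAction dataclass for illustration only (the real class lives in colossalai and is not shown in this diff):

    # Illustrative sketch only -- the CommAction/CommType layout below is an assumption
    # used for explanation, not the actual colossalai definition.
    from dataclasses import dataclass
    from enum import Enum, auto
    from typing import Optional


    class CommType(Enum):
        BEFORE = auto()   # rewrite an input argument before the op executes
        AFTER = auto()    # transform the op's output after it executes
        HOOK = auto()     # attach the communication as a gradient hook on a parameter


    @dataclass
    class CommAction:
        comm_spec: object                    # underlying CommSpec: pattern, sharding spec, mesh axis
        comm_type: CommType                  # when/where the communication is applied
        arg_index: int = -1                  # positional argument to rewrite for CommType.BEFORE
        key_for_kwarg: Optional[str] = None  # keyword argument to rewrite instead (e.g. 'bias')


    # Before this patch: the mapping stored bare CommSpec objects.
    #   communication_action_mapping['other'] = self.get_communication_spec(...)
    # After this patch: the mapping stores CommAction objects carrying placement metadata.
    #   communication_action_mapping['other'] = self.get_communication_action(
    #       sharding_spec=..., communication_pattern=..., logical_process_axis=...,
    #       comm_type=CommType.BEFORE, arg_index=1)

As the generators above show, operands that are parameters (is_param('other') or is_param('bias')) receive CommType.HOOK so their gradient all-reduce is applied via a hook, activation operands receive CommType.BEFORE with an arg_index (or key_for_kwarg for keyword operands such as 'bias'), and output all-reduces use CommType.AFTER.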