diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py
index 3c232f131..43ea265d7 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py
@@ -12,6 +12,7 @@ __all__ = ['ReshapeHandler']
 
 @operator_registry.register(torch.reshape)
 @operator_registry.register(torch.Tensor.split)
+@operator_registry.register(torch.split)
 @operator_registry.register(torch.flatten)
 @operator_registry.register(torch.Tensor.transpose)
 @operator_registry.register(torch.Tensor.permute)
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py
index 86f332d84..1f3812429 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py
@@ -220,7 +220,9 @@ class BatchNormStrategyGenerator(StrategyGenerator):
                                                            logical_process_axis=mesh_dim_0,
                                                            comm_type=CommType.IMPLICIT)
 
-        communication_action_mapping = {"output": output_comm_action}
+        # TODO: this temporary solution adds no communication cost; the action
+        # above should be registered once the SyncBN replace pass is completed.
+        communication_action_mapping = {}
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -256,7 +258,9 @@ class BatchNormStrategyGenerator(StrategyGenerator):
                                                            logical_process_axis=[mesh_dim_0, mesh_dim_1],
                                                            comm_type=CommType.IMPLICIT)
 
-        communication_action_mapping = {"output": output_comm_action}
+        # TODO: this temporary solution adds no communication cost; the action
+        # above should be registered once the SyncBN replace pass is completed.
+        communication_action_mapping = {}
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -302,7 +306,9 @@ class BatchNormStrategyGenerator(StrategyGenerator):
                                                            logical_process_axis=[mesh_dim_0],
                                                            comm_type=CommType.IMPLICIT)
 
-        communication_action_mapping = {"output": output_comm_action}
+        # TODO: this temporary solution adds no communication cost; the action
+        # above should be registered once the SyncBN replace pass is completed.
+        communication_action_mapping = {}
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py
index 532df083a..2795c8544 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py
@@ -69,7 +69,7 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
 
     def collate_strategies(self) -> List[ShardingStrategy]:
         strategy_list = []
-        for strategy in self.predecessor_node.strategies_vector:
+        for index, strategy in enumerate(self.predecessor_node.strategies_vector):
             dim_partition_dict_mapping = {}
             communication_action_mapping = {}
             dim_partition_dict_for_input = strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict
@@ -96,7 +96,7 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
                     arg_index=0)
                communication_action_mapping["input"] = input_communication_action
 
-            name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
+            name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}_{index}'
 
             strategy = self.get_sharding_strategy(name=name,
                                                   sharding_spec_mapping=sharding_spec_mapping,
@@ -121,7 +121,7 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
         strategy_list = []
         index = self.op_data["index"].data
 
-        for strategy in self.predecessor_node.strategies_vector:
+        for strategy_index, strategy in enumerate(self.predecessor_node.strategies_vector):
             # the sharding spec for input in this case is a tuple of ShardingSpec.
             sharding_spec_for_input = strategy.output_sharding_specs[self.op_data["input"]]
             dim_partition_dict_for_output = sharding_spec_for_input[index].dim_partition_dict
@@ -132,8 +132,11 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
             }
             sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
             sharding_spec_mapping["input"] = sharding_spec_for_input
-
-            name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
+            input_sharding_info = f"get the {index} element from ("
+            for sharding_spec in sharding_spec_for_input:
+                input_sharding_info += f'{sharding_spec.sharding_sequence}, '
+            input_sharding_info += ")"
+            name = f'{sharding_spec_mapping["output"].sharding_sequence} = {input_sharding_info}_{strategy_index}'
 
             strategy = self.get_sharding_strategy(name=name,
                                                   sharding_spec_mapping=sharding_spec_mapping,
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py
index 95c8e2efa..fa941f2cc 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py
@@ -1,9 +1,12 @@
 import copy
 from typing import List
 
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
-from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
-                                                         enumerate_all_possible_2d_sharding)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
+from colossalai.auto_parallel.tensor_shard.utils import (
+    enumerate_all_possible_1d_sharding,
+    enumerate_all_possible_2d_sharding,
+    ignore_sharding_exception,
+)
 
 from .strategy_generator import StrategyGenerator
 
@@ -50,6 +53,7 @@ class WhereGenerator(StrategyGenerator):
         memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
         strategy.memory_cost = memory_cost
 
+    @ignore_sharding_exception
     def _generate_strategy_with_dim_partition(self, dim_partition):
         dim_partition_dict_mapping = {
             "condition": dim_partition,
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py
index 334528019..cee43f2d0 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py
@@ -14,6 +14,11 @@ __all__ = ['UnaryElementwiseHandler']
 @operator_registry.register(torch.Tensor.type)
 @operator_registry.register(torch.abs)
 @operator_registry.register(torch.nn.ReLU)
+# TODO: softmax needs to be relocated
+@operator_registry.register(torch.nn.functional.softmax)
+@operator_registry.register(torch.nn.modules.dropout.Dropout)
+@operator_registry.register(torch.Tensor.contiguous)
+@operator_registry.register(torch.nn.functional.dropout)
 class UnaryElementwiseHandler(NodeHandler):
     """
     A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py
index daf81f995..6de2aaafd 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py
@@ -57,24 +57,6 @@ class WhereHandler(NodeHandler):
         logical_operand.logical_shape = target_shape
         return logical_operand
 
-    def register_strategy(self, compute_resharding_cost: bool = False) -> StrategiesVector:
-        """
-        Register different sharding strategies for the current node.
-        """
-        strategy_generators = self.get_strategy_generator()
-
-        for generator in strategy_generators:
-            strategies = generator.generate()
-            strategies_vector = map(self.post_process, strategies)
-            # compute the resharding costs based on the previous node
-            # strategies if specified
-            if compute_resharding_cost:
-                strategies = list(map(self.update_resharding_cost, strategies))
-            self.strategies_vector.extend(strategies)
-
-        self.strategies_vector = list(strategies_vector)
-        return self.strategies_vector
-
     def post_process(self, strategy: ShardingStrategy):
         logical_op_data_mapping, physical_op_data_mapping = self.get_operation_data_mapping()
         for key in logical_op_data_mapping.keys():
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py
index 5f7c469bc..4e01ed243 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py
@@ -3,6 +3,8 @@ import torch.nn as nn
 
 from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlacehodlerHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
@@ -10,7 +12,7 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
 
-class GetItemModel(nn.Module):
+class GetItemFromTensorModel(nn.Module):
 
     def __init__(self):
         super().__init__()
@@ -21,8 +23,8 @@ class GetItemModel(nn.Module):
         return x
 
 
-def test_getitem_function_handler():
-    model = GetItemModel()
+def test_getitem_from_tensor_handler():
+    model = GetItemFromTensorModel()
     tracer = ColoTracer()
     # graph():
     #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@@ -83,5 +85,83 @@ def test_getitem_function_handler():
     assert len(getitem_strategies_vector) == len(conv_strategies_vector)
 
 
+class GetItemFromTupleModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input):
+        split_node = torch.split(input, 2, 0)
+        x = split_node[1]
+        return x
+
+
+def test_getitem_from_tuple_handler():
+    model = GetItemFromTupleModel()
+    tracer = ColoTracer()
+    # graph():
+    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
+    #     %split : [#users=1] = call_function[target=torch.functional.split](args = (%input_1, 2), kwargs = {dim: 0})
+    #     %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {})
+    #     return getitem
+    graph = tracer.trace(model, meta_args={
+        "input": torch.rand(4, 4, 64, 64).to('meta'),
+    })
+    gm = ColoGraphModule(model, graph)
+    physical_mesh_id = torch.arange(0, 4)
+
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    input_node = list(graph.nodes)[0]
+    split_node = list(graph.nodes)[1]
+    getitem_node = list(graph.nodes)[2]
+    input_strategies_vector = StrategiesVector(input_node)
+    getitem_strategies_vector = StrategiesVector(getitem_node)
+    split_strategies_vector = StrategiesVector(split_node)
+
+    # build handler
+    input_handler = PlacehodlerHandler(
+        node=input_node,
+        device_mesh=device_mesh,
+        strategies_vector=input_strategies_vector,
+        placeholder_option='replicated',
+    )
+    input_handler.register_strategy(compute_resharding_cost=False)
+    setattr(input_node, 'strategies_vector', input_strategies_vector)
+    split_handler = ReshapeHandler(node=split_node, device_mesh=device_mesh, strategies_vector=split_strategies_vector)
+    split_handler.register_strategy(compute_resharding_cost=False)
+    setattr(split_node, 'strategies_vector', split_strategies_vector)
+    getitem_handler = GetItemHandler(node=getitem_node,
+                                     device_mesh=device_mesh,
+                                     strategies_vector=getitem_strategies_vector)
+    getitem_handler.register_strategy(compute_resharding_cost=False)
+    setattr(getitem_node, 'strategies_vector', getitem_strategies_vector)
+
+    # check operation data mapping
+    mapping = getitem_handler.get_operation_data_mapping()
+
+    for name, op_data in mapping.items():
+        op_data: OperationData
+        # make sure they have valid values
+        assert op_data.data is not None
+
+    assert mapping['input'].name == "split"
+    assert mapping['input'].type == OperationDataType.ARG
+    assert mapping['input'].logical_shape == (torch.Size([2, 4, 64, 64]), torch.Size([2, 4, 64, 64]))
+
+    assert mapping['index'].name == "index"
+    assert isinstance(mapping['index'].data, int)
+    assert mapping['index'].type == OperationDataType.ARG
+
+    assert mapping['output'].name == "getitem"
+    assert mapping['output'].data.is_meta
+    assert mapping['output'].data.shape == torch.Size([2, 4, 64, 64])
+    assert mapping['output'].type == OperationDataType.OUTPUT
+
+    # getitem is a following-strategy handler, so its number of strategies equals that of its predecessor node.
+    assert len(getitem_strategies_vector) == len(split_strategies_vector)
+
+
 if __name__ == '__main__':
-    test_getitem_function_handler()
+    test_getitem_from_tensor_handler()
+    test_getitem_from_tuple_handler()