diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py index 24c2e3758..d58f95a36 100644 --- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py +++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py @@ -37,30 +37,6 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]): origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name( str(node)) - # experimental pass for torch.Tensor.view - # Arguments of view op will be divided in the sharded dimensions. - for node in nodes: - if node.op == 'call_method' and getattr(node.args[0]._meta_data.__class__, node.target) in (torch.Tensor.view,): - output_dim_partition_dict = node.sharding_spec.dim_partition_dict - device_mesh = node.sharding_spec.device_mesh - new_args = [] - for arg in node.args: - if isinstance(arg, Node): - if isinstance(arg._meta_data, int): - new_args.append(arg._meta_data) - else: - new_args.append(arg) - else: - assert isinstance(arg, int), 'The argument in view node should be either type of Node or int.' - new_args.append(arg) - - for dim, shard_dims in output_dim_partition_dict.items(): - total_shard_size = 1 - for shard_dim in shard_dims: - total_shard_size *= device_mesh.shape[shard_dim] - new_args[dim + 1] //= total_shard_size - node.args = tuple(new_args) - # the dict to get input sharding specs of user node sharding_spec_convert_dict = {} # the dict to record comm actions of nodes @@ -113,7 +89,74 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]): return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict -def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh): +def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): + """ + This pass will process node args to adapt the distributed tensor layout. + """ + mod_graph = gm.graph + nodes = tuple(mod_graph.nodes) + + for node in nodes: + # skip the placeholder node added in _solution_annotation pass + if not hasattr(node, 'sharding_spec'): + continue + output_dim_partition_dict = node.sharding_spec.dim_partition_dict + device_mesh = node.sharding_spec.device_mesh + new_args = [] + + if node.op == 'call_method': + method = getattr(node.args[0]._meta_data.__class__, node.target) + # process the node with (input, *shape) style args + if method in (torch.Tensor.view, torch.Tensor.reshape): + for arg in node.args: + if isinstance(arg, Node): + if isinstance(arg._meta_data, int): + new_args.append(arg._meta_data) + else: + new_args.append(arg) + else: + assert isinstance(arg, int), 'The argument in view node should be either type of Node or int.' + new_args.append(arg) + + for dim, shard_dims in output_dim_partition_dict.items(): + # we will skip the dim with -1 value + if new_args[dim + 1] == -1: + continue + total_shard_size = 1 + for shard_dim in shard_dims: + total_shard_size *= device_mesh.shape[shard_dim] + new_args[dim + 1] //= total_shard_size + node.args = tuple(new_args) + + elif node.op == 'call_function': + target = node.target + # process the node with (input, torch.Size) style args + if target in (torch.reshape,): + for arg in node.args: + if isinstance(arg, Node): + if isinstance(arg._meta_data, (tuple, list)): + new_args.append(list(arg._meta_data)) + else: + new_args.append(arg) + else: + assert isinstance( + arg, (tuple, list)), 'The argument in reshape node should be either type of Node or tuple.' 
+                        new_args.append(list(arg))
+
+                for dim, shard_dims in output_dim_partition_dict.items():
+                    # we will skip the dim with -1 value
+                    if new_args[1][dim] == -1:
+                        continue
+                    total_shard_size = 1
+                    for shard_dim in shard_dims:
+                        total_shard_size *= device_mesh.shape[shard_dim]
+                    new_args[1][dim] //= total_shard_size
+                node.args = tuple(new_args)
+
+    return gm
+
+
+def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
     """
     Apply the sharding action to the module parameters and buffers following the instructions of solver
     solution.
@@ -216,6 +259,7 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
 def runtime_preparation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh: DeviceMesh):
     gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
         gm, solution)
+    gm = _node_args_converting(gm, device_mesh)
     # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
     # gm = implicit_comm_action_apply(gm)
     gm = _module_params_sharding(gm, device_mesh)
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
index ab0063dd1..5aff06c6a 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
@@ -3,6 +3,7 @@ from .batch_norm_handler import BatchNormModuleHandler
 from .binary_elementwise_handler import BinaryElementwiseHandler
 from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
 from .conv_handler import ConvFunctionHandler, ConvModuleHandler
+from .experimental import PermuteHandler, ViewHandler
 from .getatrr_handler import GetattrHandler
 from .getitem_handler import GetItemHandler
 from .layer_norm_handler import LayerNormModuleHandler
@@ -21,5 +22,5 @@ __all__ = [
     'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
     'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
     'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
-    'GetItemHandler', 'GetattrHandler'
+    'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler'
 ]
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py
index 7f644c0e1..225206419 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py
@@ -1,4 +1,8 @@
-from .view_generator import ViewGenerator
+from .permute_handler import PermuteHandler
+from .reshape_generator import PermuteGenerator, TransposeGenerator, ViewGenerator
+from .transpose_handler import TransposeHandler
 from .view_handler import ViewHandler
 
-__all__ = ['ViewGenerator', 'ViewHandler']
+__all__ = [
+    'ViewGenerator', 'ViewHandler', 'PermuteGenerator', 'PermuteHandler', 'TransposeGenerator', 'TransposeHandler'
+]
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/permute_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/permute_handler.py
new file mode 100644
index 000000000..6d625e153
--- /dev/null
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/permute_handler.py
@@ -0,0 +1,76 @@
+from typing import Dict, List
+
+import torch
+
+from ...sharding_strategy import OperationData, OperationDataType
+from ..node_handler import NodeHandler
+from ..registry import operator_registry
+from ..strategy import StrategyGenerator
+from .reshape_generator import PermuteGenerator
+
+__all__ = ['PermuteHandler']
+
+
+@operator_registry.register(torch.Tensor.permute)
+@operator_registry.register(torch.permute)
+class PermuteHandler(NodeHandler):
+    """
+    A PermuteHandler which deals with the sharding strategies for torch.permute.
+    """
+
+    def get_strategy_generator(self) -> List[StrategyGenerator]:
+        op_data_mapping = self.get_operation_data_mapping()
+        generators = []
+        generators.append(PermuteGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
+        return generators
+
+    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
+        # check if the input operand is a parameter
+        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
+            data_type = OperationDataType.PARAM
+        else:
+            data_type = OperationDataType.ARG
+
+        input_data = self.node.args[0]._meta_data
+        physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
+
+        permute_dims = []
+        if self.node.op == 'call_method':
+            # torch.Tensor.permute (input, *dims)
+            for arg in self.node.args:
+                if isinstance(arg, torch.fx.Node):
+                    if isinstance(arg._meta_data, int):
+                        permute_dims.append(arg._meta_data)
+                else:
+                    assert isinstance(arg, int), 'The argument in permute node should be either type of Node or int.'
+                    permute_dims.append(arg)
+        else:
+            # torch.permute (input, dims)
+            for arg in self.node.args:
+                if isinstance(arg, torch.fx.Node):
+                    if isinstance(arg._meta_data, (tuple, list)):
+                        permute_dims.extend(arg._meta_data)
+                else:
+                    assert isinstance(
+                        arg,
+                        (tuple, list)), 'The argument in permute node should be type of Node, Tuple[int] or List[int].'
+ permute_dims.extend(arg) + + num_dims = self.node._meta_data.dim() + for i in range(num_dims): + # recover negative value to positive + if permute_dims[i] < 0: + permute_dims[i] += num_dims + + physical_shape_operand = OperationData(name='permute_dims', type=OperationDataType.ARG, data=list(permute_dims)) + + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = { + "input": physical_input_operand, + "permute_dims": physical_shape_operand, + "output": physical_output_operand + } + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/reshape_generator.py similarity index 59% rename from colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_generator.py rename to colossalai/auto_parallel/tensor_shard/node_handler/experimental/reshape_generator.py index 21439fac0..1d1be2c5e 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/reshape_generator.py @@ -17,12 +17,12 @@ from colossalai.auto_parallel.tensor_shard.utils import ( from colossalai.tensor.shape_consistency import CollectiveCommPattern from colossalai.tensor.sharding_spec import ShardingSpec -__all__ = ['ViewGenerator'] +__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator'] -class ViewGenerator(FollowingStrategyGenerator): +class ReshapeGenerator(FollowingStrategyGenerator): """ - ViewGenerator which deals with the sharding strategies of view op. + ReshapeGenerator is the base class for all the reshape operation. """ def validate(self) -> bool: @@ -61,6 +61,15 @@ class ViewGenerator(FollowingStrategyGenerator): memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost + def collate_strategies(self) -> List[ShardingStrategy]: + return super().collate_strategies() + + +class ViewGenerator(ReshapeGenerator): + """ + ViewGenerator deals with the sharding strategies of view op. + """ + def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] for index, strategy in enumerate(self.predecessor_node.strategies_vector): @@ -136,3 +145,85 @@ class ViewGenerator(FollowingStrategyGenerator): strategy_list.append(strategy) return strategy_list + + +class PermuteGenerator(ReshapeGenerator): + """ + PermuteGenerator deals with the sharding strategies of permute op. 
+ """ + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + + permute_dims = self.op_data['permute_dims'].data + dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict + dim_partition_dict_for_output = {} + for dim_index, permute_dim in enumerate(permute_dims): + if permute_dim in dim_partition_dict_for_input: + dim_partition_dict_for_output[dim_index] = dim_partition_dict_for_input[permute_dim] + + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list + + +class TransposeGenerator(ReshapeGenerator): + """ + TransposeGenerator deals with the sharding strategies of permute op. + """ + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict + dim_partition_dict_for_output = {} + + transpose_dims = self.op_data['transpose_dims'].data + dim_0 = transpose_dims[0] + dim_1 = transpose_dims[1] + for dim, sharded_dims in dim_partition_dict_for_input.items(): + if dim == dim_0: + dim_partition_dict_for_output[dim_1] = dim_partition_dict_for_input[dim_0] + elif dim == dim_1: + dim_partition_dict_for_output[dim_0] = dim_partition_dict_for_input[dim_1] + else: + dim_partition_dict_for_output[dim] = sharded_dims + + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. 
+ name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/transpose_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/transpose_handler.py new file mode 100644 index 000000000..3c7336a93 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/transpose_handler.py @@ -0,0 +1,65 @@ +from typing import Dict, List + +import torch + +from ...sharding_strategy import OperationData, OperationDataType +from ..node_handler import NodeHandler +from ..registry import operator_registry +from ..strategy import StrategyGenerator +from .reshape_generator import TransposeGenerator + +__all__ = ['TransposeHandler'] + + +@operator_registry.register(torch.Tensor.transpose) +@operator_registry.register(torch.transpose) +class TransposeHandler(NodeHandler): + """ + A TransposeHandler which deals with the sharding strategies for torch.permute or torch.transpose. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(TransposeGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) + + transpose_dims = [] + # torch.transpose (input, dim0, dim1) + for arg in self.node.args: + if isinstance(arg, torch.fx.Node): + if isinstance(arg._meta_data, int): + transpose_dims.append(arg._meta_data) + else: + transpose_dims.append(arg) + + num_dims = self.node._meta_data.dim() + for i in range(2): + # recover negative value to positive + if transpose_dims[i] < 0: + transpose_dims[i] += num_dims + + physical_shape_operand = OperationData(name='transpose_dims', + type=OperationDataType.ARG, + data=list(transpose_dims)) + + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = { + "input": physical_input_operand, + "transpose_dims": physical_shape_operand, + "output": physical_output_operand + } + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py index bab4e0d76..6be634593 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py @@ -6,11 +6,13 @@ from ...sharding_strategy import OperationData, OperationDataType from ..node_handler import NodeHandler from ..registry import operator_registry from ..strategy import StrategyGenerator -from .view_generator import ViewGenerator +from .reshape_generator import ViewGenerator __all__ = ['ViewHandler'] +@operator_registry.register(torch.Tensor.reshape) 
+@operator_registry.register(torch.reshape) @operator_registry.register(torch.Tensor.view) class ViewHandler(NodeHandler): """ diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py index 43ea265d7..82319c52d 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py @@ -10,13 +10,9 @@ from .strategy import ReshapeGenerator, StrategyGenerator __all__ = ['ReshapeHandler'] -@operator_registry.register(torch.reshape) @operator_registry.register(torch.Tensor.split) @operator_registry.register(torch.split) @operator_registry.register(torch.flatten) -@operator_registry.register(torch.Tensor.transpose) -@operator_registry.register(torch.Tensor.permute) -@operator_registry.register(torch.Tensor.view) @operator_registry.register(torch.nn.AdaptiveAvgPool2d) class ReshapeHandler(NodeHandler): """ diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py index c8539d38d..3c9e0fd56 100644 --- a/colossalai/tensor/comm_spec.py +++ b/colossalai/tensor/comm_spec.py @@ -23,6 +23,8 @@ def _all_gather(tensor, comm_spec): torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]) ] + # without this contiguous operation, the all gather may get some unexpected results. + tensor = tensor.contiguous() dist.all_gather(tensor_list, tensor, group=process_group) output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() return output diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py new file mode 100644 index 000000000..c695b8843 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py @@ -0,0 +1,339 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.experimental import PermuteHandler, TransposeHandler +from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class ConvReshapeModel(nn.Module): + + def __init__(self, reshape_dims, call_function): + super().__init__() + self.reshape_dims = reshape_dims + self.call_function = call_function + + def forward(self, input, other): + conv_node = nn.functional.conv2d(input, other, bias=None) + # permute_node = torch.permute(conv_node, self.permute_dims) + if self.call_function == torch.permute: + permute_node = self.call_function(conv_node, 
self.reshape_dims) + else: + permute_node = self.call_function(conv_node, *self.reshape_dims) + return permute_node + + +class LinearReshapeModel(nn.Module): + + def __init__(self, reshape_dims, call_function): + super().__init__() + self.reshape_dims = reshape_dims + self.call_function = call_function + + def forward(self, input, other): + linear_node = nn.functional.linear(input, other, bias=None) + # permute_node = torch.permute(linear_node, self.tgt_shape) + if self.call_function == torch.permute: + permute_node = self.call_function(linear_node, self.reshape_dims) + else: + permute_node = self.call_function(linear_node, *self.reshape_dims) + return permute_node + + +def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + if call_function == torch.permute: + reshape_dims = reshape_dims[0] + elif call_function == torch.transpose: + reshape_dims = reshape_dims[1] + model = model_cls(reshape_dims, call_function).cuda() + + if model_cls.__name__ == 'ConvReshapeModel': + input = torch.rand(8, 8, 66, 66).to('cuda') + other = torch.rand(16, 8, 3, 3).to('cuda') + # index of conv node in computation graph + node_index = 2 + # total number of conv strategies + strategy_number = 16 + if model_cls.__name__ == 'LinearReshapeModel': + input = torch.rand(8, 16, 64, 32).to('cuda') + other = torch.rand(64, 32).to('cuda') + # index of linear node in computation graph + node_index = 2 + # total number of linear strategies + strategy_number = 23 + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=['input', 'other'], + node_type='following') + tracer = ColoTracer() + if model_cls.__name__ == 'ConvReshapeModel': + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {bias: None}) + # %permute : [#users=1] = call_function[target=torch.permute](args = (%conv2d, (0, 2, 1, 3)), kwargs = {}) + # return permute + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(8, 8, 66, 66).to('meta'), + "other": torch.rand(16, 8, 3, 3).to('meta'), + }) + + if model_cls.__name__ == 'LinearReshapeModel': + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) + # %permute : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {}) + # return permute + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(8, 16, 64, 32).to('meta'), + "other": torch.rand(64, 32).to('meta'), + }) + + gm = ColoGraphModule(model, graph) + + previous_mod_node = list(graph.nodes)[2] + reshape_node = list(graph.nodes)[3] + view_strategies_vector = StrategiesVector(reshape_node) + previous_strategies_vector = StrategiesVector(previous_mod_node) + + # build handler + if model_cls.__name__ == 'ConvReshapeModel': + + conv_handler = ConvFunctionHandler(node=previous_mod_node, + 
device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + conv_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + if model_cls.__name__ == 'LinearReshapeModel': + assert len(previous_strategies_vector) == 0 + linear_handler = LinearFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + linear_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + if call_function == torch.permute: + reshape_handler = PermuteHandler(node=reshape_node, + device_mesh=device_mesh, + strategies_vector=view_strategies_vector) + else: + reshape_handler = TransposeHandler(node=reshape_node, + device_mesh=device_mesh, + strategies_vector=view_strategies_vector) + + reshape_handler.register_strategy(compute_resharding_cost=False) + + # check operation data mapping + mapping = reshape_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + if model_cls.__name__ == 'ConvReshapeModel': + assert mapping['input'].name == "conv2d" + else: + assert mapping['input'].name == "linear" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + + if call_function == torch.permute: + assert mapping['output'].name == "permute" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.permute(torch.rand(8, 16, 64, 64), reshape_dims).shape + assert mapping['output'].type == OperationDataType.OUTPUT + else: + assert mapping['output'].name == "transpose" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.transpose(torch.rand(8, 16, 64, 64), *reshape_dims).shape + assert mapping['output'].type == OperationDataType.OUTPUT + + # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. 
+ assert len(view_strategies_vector) == len(previous_strategies_vector) + strategy_name_list = [strategy.name for strategy in view_strategies_vector] + if rank == 0: + for name in strategy_name_list: + print(name) + if model_cls.__name__ == 'ConvReshapeModel': + + if reshape_dims in ((0, 2, 1, 3), (1, 2)): + assert '[S0, S1, R, R] -> [S0, R, S1, R]_0' in strategy_name_list + assert '[S1, S0, R, R] -> [S1, R, S0, R]_1' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_2' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_3' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_4' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_5' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_10' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_11' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, S01, R]_15' in strategy_name_list + + if reshape_dims == (2, 0, 1, 3): + assert '[S0, S1, R, R] -> [R, S0, S1, R]_0' in strategy_name_list + assert '[S1, S0, R, R] -> [R, S1, S0, R]_1' in strategy_name_list + assert '[S0, R, R, R] -> [R, S0, R, R]_2' in strategy_name_list + assert '[S1, R, R, R] -> [R, S1, R, R]_3' in strategy_name_list + assert '[S0, R, R, R] -> [R, S0, R, R]_4' in strategy_name_list + assert '[S1, R, R, R] -> [R, S1, R, R]_5' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_10' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_11' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list + assert '[S01, R, R, R] -> [R, S01, R, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, S01, R]_15' in strategy_name_list + + if reshape_dims == (1, 3): + assert '[S0, S1, R, R] -> [S0, R, R, S1]_0' in strategy_name_list + assert '[S1, S0, R, R] -> [S1, R, R, S0]_1' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_2' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_3' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_4' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_5' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, R, S1]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, R, S0]_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, R, S0]_10' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, R, S1]_11' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, R, S01]_15' in 
strategy_name_list + + if model_cls.__name__ == 'LinearReshapeModel': + + if reshape_dims == ((0, 2, 1, 3), (1, 2)): + assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list + assert '[R, S0, R, S1] -> [R, R, S0, S1]_1' in strategy_name_list + assert '[R, R, S0, S1] -> [R, S0, R, S1]_2' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list + assert '[R, S1, R, S0] -> [R, R, S1, S0]_4' in strategy_name_list + assert '[R, R, S1, S0] -> [R, S1, R, S0]_5' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list + assert '[R, R, S0, R] -> [R, S0, R, R]_8' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_10' in strategy_name_list + assert '[R, R, S1, R] -> [R, S1, R, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, S01, R]_19' in strategy_name_list + assert '[R, R, S01, R] -> [R, S01, R, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list + + if reshape_dims == (2, 0, 1, 3): + assert '[S0, R, R, S1] -> [R, S0, R, S1]_0' in strategy_name_list + assert '[R, S0, R, S1] -> [R, R, S0, S1]_1' in strategy_name_list + assert '[R, R, S0, S1] -> [S0, R, R, S1]_2' in strategy_name_list + assert '[S1, R, R, S0] -> [R, S1, R, S0]_3' in strategy_name_list + assert '[R, S1, R, S0] -> [R, R, S1, S0]_4' in strategy_name_list + assert '[R, R, S1, S0] -> [S1, R, R, S0]_5' in strategy_name_list + assert '[S0, R, R, R] -> [R, S0, R, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list + assert '[R, R, S0, R] -> [S0, R, R, R]_8' in strategy_name_list + assert '[S1, R, R, R] -> [R, S1, R, R]_9' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_10' in strategy_name_list + assert '[R, R, S1, R] -> [S1, R, R, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list + assert '[S01, R, R, R] -> [R, S01, R, R]_18' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, S01, R]_19' in strategy_name_list + assert '[R, R, S01, R] -> [S01, R, R, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list + + if reshape_dims == (1, 3): + assert '[S0, R, R, S1] -> [S0, S1, R, R]_0' in strategy_name_list + assert '[R, S0, R, S1] -> [R, S1, R, S0]_1' in strategy_name_list + assert '[R, R, S0, S1] -> [R, S1, S0, R]_2' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, S0, R, R]_3' in strategy_name_list + 
assert '[R, S1, R, S0] -> [R, S0, R, S1]_4' in strategy_name_list + assert '[R, R, S1, S0] -> [R, S0, S1, R]_5' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, R, S0]_7' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, R, S1]_10' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, S1, R, R]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, S0, R, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, S0, R, R]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, S1, R, R]_17' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, R, S01]_19' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, S01, R, R]_22' in strategy_name_list + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +@parameterize('call_function', [torch.permute, torch.transpose]) +@parameterize('reshape_dims', [((0, 2, 1, 3), (1, 2)), ((2, 0, 1, 3), (1, 3))]) +@parameterize('model_cls', [ConvReshapeModel, LinearReshapeModel]) +def test_view_handler(call_function, reshape_dims, model_cls): + world_size = 4 + run_func = partial(check_view_handler, + call_function=call_function, + reshape_dims=reshape_dims, + model_cls=model_cls, + world_size=world_size, + port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_view_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py index 16f9fa63d..08a702789 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py @@ -84,7 +84,7 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port): # return view graph = tracer.trace(model, meta_args={ - "input": torch.rand(8, 16, 66, 66).to('meta'), + "input": torch.rand(8, 8, 66, 66).to('meta'), "other": torch.rand(16, 8, 3, 3).to('meta'), })
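
For orientation, the following is a minimal, standalone sketch (not part of the patch) of the shape-argument rewriting that _node_args_converting applies to view/reshape nodes: each requested dimension is divided by the product of the device-mesh extents it is sharded over, and dimensions given as -1 are left untouched. The helper name, mesh shape and sharding dicts below are illustrative assumptions, not values taken from the patch.

# Illustrative sketch only: mirrors the arithmetic in _node_args_converting.
from typing import Dict, List, Tuple


def shard_view_shape(view_shape: Tuple[int, ...],
                     dim_partition_dict: Dict[int, List[int]],
                     mesh_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    new_shape = list(view_shape)
    for dim, shard_axes in dim_partition_dict.items():
        if new_shape[dim] == -1:
            # -1 lets PyTorch infer this dim, so it needs no rescaling
            continue
        total_shard_size = 1
        for axis in shard_axes:
            total_shard_size *= mesh_shape[axis]
        new_shape[dim] //= total_shard_size
    return tuple(new_shape)


if __name__ == '__main__':
    # a (2, 2) mesh with the first output dim sharded over both mesh axes (S01)
    assert shard_view_shape((32, 4, 64, 16, 4), {0: [0, 1]}, (2, 2)) == (8, 4, 64, 16, 4)
    # -1 dims are skipped even if nominally sharded
    assert shard_view_shape((-1, 4, 64, 16, 4), {0: [0, 1]}, (2, 2)) == (-1, 4, 64, 16, 4)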
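Similarly, a small standalone sketch (again illustrative, not part of the patch) of how PermuteGenerator.collate_strategies remaps a dim_partition_dict through a permutation: output dimension i inherits the sharding of input dimension permute_dims[i].

# Illustrative sketch only: mirrors the remapping loop in PermuteGenerator.collate_strategies.
from typing import Dict, List


def permute_dim_partition(dim_partition_dict: Dict[int, List[int]],
                          permute_dims: List[int]) -> Dict[int, List[int]]:
    output_partition: Dict[int, List[int]] = {}
    for dim_index, permute_dim in enumerate(permute_dims):
        if permute_dim in dim_partition_dict:
            output_partition[dim_index] = dim_partition_dict[permute_dim]
    return output_partition


if __name__ == '__main__':
    # [S0, S1, R, R] permuted with dims (0, 2, 1, 3) becomes [S0, R, S1, R],
    # matching the strategy names checked in the new test file.
    assert permute_dim_partition({0: [0], 1: [1]}, [0, 2, 1, 3]) == {0: [0], 2: [1]}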