diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
index c762bdca7..92916118b 100644
--- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -1,5 +1,6 @@
+import operator
 from copy import deepcopy
-from typing import List
+from typing import Dict, List, Union
 
 import torch
 from torch.fx import symbolic_trace
@@ -20,6 +21,35 @@ from colossalai.tensor.sharding_spec import ShardingSpec
 
 shape_consistency_manager = ShapeConsistencyManager()
 
 
+def size_processing(size: Union[int, torch.Size],
+                    dim_partition_dict: Dict[int, List[int]],
+                    device_mesh_info: Dict[int, int],
+                    target_dim: int = None,
+                    node_name: str = None):
+    """
+    This method will be invoked at runtime to convert the value of a size node according to the distributed information.
+    """
+    if target_dim is not None:
+        assert isinstance(size, int)
+        if target_dim in dim_partition_dict:
+            total_shard_size = 1
+            for shard_dim in dim_partition_dict[target_dim]:
+                total_shard_size *= device_mesh_info[shard_dim]
+            size = size * total_shard_size
+
+    else:
+        size = list(size)
+        for dim, dim_size in enumerate(size):
+            if dim in dim_partition_dict:
+                total_shard_size = 1
+                for shard_dim in dim_partition_dict[dim]:
+                    total_shard_size *= device_mesh_info[shard_dim]
+                size[dim] = dim_size * total_shard_size
+        size = torch.Size(size)
+
+    return size
+
+
 def _solution_annotatation(gm: torch.fx.GraphModule,
                            solution: List[int],
                            strategies_constructor: StrategiesConstructor = None):
@@ -103,6 +133,119 @@ def _solution_annotatation(gm: torch.fx.GraphModule,
     return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
 
 
+def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
+    """
+    In the auto parallel system, tensors may be sharded across different devices, so the size of a
+    sharded tensor has to be converted back to the size of the original (global) tensor before being
+    consumed by its users, such as torch.view, torch.reshape, etc. These nodes carry enough
+    information, e.g. the input and output sharding_spec, to decide how to convert the size value.
+    """
+    mod_graph = gm.graph
+    nodes = tuple(mod_graph.nodes)
+    node_pairs = {}
+
+    for node in nodes:
+
+        if node.op == 'call_method' and node.target == 'size':
+            # extract useful information from the size node:
+            # dim_partition_dict tells us on which dimensions the
+            # size value should be enlarged.
+            sharding_spec = node.args[0].sharding_spec
+            dim_partition_dict = sharding_spec.dim_partition_dict
+
+            # there are two usages of torch.Tensor.size:
+            #    tensor.size()
+            #    tensor.size(dim)
+            # if a target_dim is assigned, the output will be
+            # an int instead of a torch.Size
+            target_dim = None
+            if len(node.args) > 1:
+                target_dim = node.args[1]
+                if target_dim < 0:
+                    target_dim += node.args[0]._meta_data.dim()
+
+            # the DeviceMesh information determines how to scale the size value
+            device_mesh_info = {}
+            for dim, dim_size in enumerate(device_mesh.mesh_shape):
+                device_mesh_info[dim] = dim_size
+
+            with mod_graph.inserting_after(node):
+                size_processing_node = mod_graph.create_node('call_function',
+                                                             size_processing,
+                                                             args=(node, dim_partition_dict, device_mesh_info,
+                                                                   target_dim, node.name))
+                # store the original node and its processing node as a pair in the node_pairs dictionary;
+                # it will be used to replace the original node with the processing node inside slice objects
+                node_pairs[node] = size_processing_node
+                size_processing_node._meta_data = node._meta_data
+
+            user_list = list(node.users.keys())
+            for user in user_list:
+                if user == size_processing_node:
+                    continue
+                new_args = list(user.args)
+                new_kwargs = dict(user.kwargs)
+                # the origin node may be a positional or keyword argument of the user node
+                if node in new_args:
+                    # substitute the origin node with size_processing_node
+                    new_args[new_args.index(node)] = size_processing_node
+                    user.args = tuple(new_args)
+                elif str(node) in new_kwargs:
+                    # substitute the origin node with size_processing_node
+                    new_kwargs[str(node)] = size_processing_node
+                    user.kwargs = new_kwargs
+
+        if node.op == 'call_function' and node.target == operator.getitem:
+
+            getitem_index = node.args[1]
+            # slice objects are quite special in the torch.fx graph:
+            # on the one hand, we treat a slice object the same as an int,
+            # so we do not create a node for it; on the other hand,
+            # a slice object may take an fx.Node as its argument, and that user
+            # relationship cannot be tracked in the fx graph.
+            # Therefore, we record the node pairs in this pass and use them
+            # to replace the original node arguments inside a slice object if
+            # they have been processed in the pass above.
+
+            # There are three main usages of operator.getitem:
+            #    getitem(input, int)
+            #    getitem(input, slice)
+            #    getitem(input, Tuple[slice])
+            # In this pass, we only need to process the last two cases, because
+            # node arguments may potentially appear in them.
+            if isinstance(getitem_index, slice):
+                new_start, new_stop, new_step = getitem_index.start, getitem_index.stop, getitem_index.step
+                if getitem_index.start in node_pairs:
+                    new_start = node_pairs[getitem_index.start]
+                elif getitem_index.stop in node_pairs:
+                    new_stop = node_pairs[getitem_index.stop]
+                elif getitem_index.step in node_pairs:
+                    new_step = node_pairs[getitem_index.step]
+                new_slice_item = slice(new_start, new_stop, new_step)
+                new_args = (node.args[0], new_slice_item)
+                node.args = new_args
+
+            elif isinstance(getitem_index, (tuple, list)):
+                assert isinstance(getitem_index[0], slice)
+                new_slice_items = []
+
+                for slice_item in getitem_index:
+                    new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step
+                    if slice_item.start in node_pairs:
+                        new_start = node_pairs[slice_item.start]
+                    elif slice_item.stop in node_pairs:
+                        new_stop = node_pairs[slice_item.stop]
+                    elif slice_item.step in node_pairs:
+                        new_step = node_pairs[slice_item.step]
+                    new_slice_item = slice(new_start, new_stop, new_step)
+                    new_slice_items.append(new_slice_item)
+
+                new_args = (node.args[0], tuple(new_slice_items))
+                node.args = new_args
+
+    return gm
+
+
 def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
     """
     This pass will process node args to adapt the distributed tensor layout.
     """
@@ -138,6 +281,7 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
             method = getattr(node.args[0]._meta_data.__class__, node.target)
             # process the node with (input, *shape) style args
             if method in (torch.Tensor.view, torch.Tensor.reshape):
+
                 for arg in node.args:
                     if isinstance(arg, Node):
                         if isinstance(arg._meta_data, (int, tuple, list)):
@@ -157,10 +301,18 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                     # 1. torch.view(input, *shape)
                     # 2. torch.view(input, shape)
                     if isinstance(new_args[1], int):
-                        new_args[dim + 1] //= total_shard_size
+                        # we will skip the dim with -1 value
+                        if new_args[dim + 1] == -1:
+                            continue
+                        else:
+                            new_args[dim + 1] //= total_shard_size
                     else:
                         new_args[1] = list(new_args[1])
-                        new_args[1][dim] //= total_shard_size
+                        # we will skip the dim with -1 value
+                        if new_args[1][dim] == -1:
+                            continue
+                        else:
+                            new_args[1][dim] //= total_shard_size
                 node.args = tuple(new_args)
 
         elif node.op == 'call_function':
@@ -298,6 +450,7 @@ def runtime_preparation_pass(gm: torch.fx.GraphModule,
                              strategies_constructor: StrategiesConstructor = None):
     gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
         gm, solution, strategies_constructor)
+    gm = _size_value_converting(gm, device_mesh)
     gm = _node_args_converting(gm, device_mesh)
     # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
     # gm = implicit_comm_action_apply(gm)
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
index 659edf548..d8e3ce6a5 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
@@ -28,8 +28,9 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
     # switch the dimensions of the transposed weight sharding_spec
     sharding_spec = strategy.get_sharding_spec_by_name(weight_name)
     op_data = strategy.get_op_data_by_name(weight_name)
-    assert op_data.logical_shape != op_data.data.shape, \
-        "Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same"
+    assert op_data.logical_shape[0] == op_data.data.shape[1] and \
+        op_data.logical_shape[1] == op_data.data.shape[0], \
+        "Expected the logical shape of the linear operator's weight to be equal to its transposed physical shape"
     dim_size = len(op_data.logical_shape)
     transpose_partition_dim(sharding_spec, 0, dim_size - 1)
     return strategy
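
As a reading aid (not part of the patch), the snippet below is a minimal, self-contained sketch of the rescaling that size_processing performs at runtime. The 2x4 device mesh and the sharding layout are made-up assumptions; the loop simply mirrors the logic of the function added above.

import torch

# hypothetical 2x4 device mesh: mesh dim 0 has 2 devices, mesh dim 1 has 4 devices
device_mesh_info = {0: 2, 1: 4}
# assumed layout: tensor dim 0 sharded over mesh dim 0, tensor dim 1 over mesh dim 1
dim_partition_dict = {0: [0], 1: [1]}

# size reported by one local shard of a tensor whose global shape is (8, 16)
local_size = torch.Size([4, 4])

global_size = list(local_size)
for dim, dim_size in enumerate(global_size):
    if dim in dim_partition_dict:
        total_shard_size = 1
        for shard_dim in dim_partition_dict[dim]:
            total_shard_size *= device_mesh_info[shard_dim]
        global_size[dim] = dim_size * total_shard_size

print(torch.Size(global_size))    # torch.Size([8, 16])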
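
The slice special-casing can also be seen directly with vanilla torch.fx. The sketch below is my own example, not part of the patch: it traces a toy module that slices a tensor by a value derived from tensor.size(), and the resulting getitem node carries that value inside a plain slice object rather than as a top-level argument, which is the situation the node_pairs bookkeeping above deals with.

import operator

import torch
from torch.fx import symbolic_trace


class SliceByHalf(torch.nn.Module):
    """Toy module: slice a tensor by half of its first dimension."""

    def forward(self, x):
        half = x.size(0) // 2
        return x[:half]


gm = symbolic_trace(SliceByHalf())
for node in gm.graph.nodes:
    if node.op == 'call_function' and node.target == operator.getitem:
        # node.args[1] is a slice whose stop field is an fx.Node (the floordiv result),
        # not a plain int
        print(node.args[1], type(node.args[1]))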
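
The -1 handling added to _node_args_converting can be illustrated with a toy shape computation. The shapes and shard layout below are assumptions chosen only for the example: an explicitly sharded dimension of a view/reshape target is divided by its shard count, while a -1 dimension, which PyTorch infers, must be left untouched.

import torch

x_global = torch.randn(4, 8, 64)    # global tensor (assumed shape)
target_shape = [4, -1, 64]          # user-facing (global) view/reshape target
shard_info = {2: 2}                 # assume dim 2 is sharded across 2 devices

# divide only the explicitly sharded dims; keep -1 so PyTorch can still infer it
local_target = [
    s if s == -1 or d not in shard_info else s // shard_info[d]
    for d, s in enumerate(target_shape)
]

x_local = x_global[:, :, :32]       # the shard held by one device, shape (4, 8, 32)
print(local_target)                              # [4, -1, 32]
print(x_local.reshape(*local_target).shape)      # torch.Size([4, 8, 32])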
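
Finally, the tightened assertion in linear_handler.py relies on the fact that nn.Linear stores its weight with shape (out_features, in_features), while the logical shape used when the weight is treated as a plain matmul operand is (in_features, out_features), i.e. the transposed physical shape. A quick standalone check (my own illustration, not code from the patch):

import torch

linear = torch.nn.Linear(in_features=16, out_features=32)
physical_shape = linear.weight.shape                                     # torch.Size([32, 16])
logical_shape = torch.Size((linear.in_features, linear.out_features))    # torch.Size([16, 32])

# mirrors the condition asserted in _update_sharding_spec_for_transposed_weight_for_linear
assert logical_shape[0] == physical_shape[1] and logical_shape[1] == physical_shape[0]
print(physical_shape, logical_shape)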