diff --git a/colossalai/auto_parallel/passes/constants.py b/colossalai/auto_parallel/passes/constants.py index b86088474..485a87492 100644 --- a/colossalai/auto_parallel/passes/constants.py +++ b/colossalai/auto_parallel/passes/constants.py @@ -6,3 +6,8 @@ OUTPUT_SAVED_MOD = [ torch.nn.ReLU, torch.nn.Softmax, ] + +# SHAPE_ARGUMENT_OPS contains node with (input, *shape) style args. +# This list could be extended if any other method has the same +# argument style as view and reshape. +SHAPE_ARGUMENT_OPS = [torch.Tensor.view, torch.Tensor.reshape, torch.reshape] diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py index ecf3f1f18..bb419be35 100644 --- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py +++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py @@ -19,6 +19,8 @@ from colossalai.tensor.comm_spec import _all_reduce from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec +from .constants import SHAPE_ARGUMENT_OPS + shape_consistency_manager = ShapeConsistencyManager() @@ -51,23 +53,16 @@ def size_processing(size: Union[int, torch.Size], return size -def _solution_annotatation(gm: torch.fx.GraphModule, - solution: List[int], - strategies_constructor: StrategiesConstructor = None): +def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], + strategies_constructor: StrategiesConstructor): """ This method is used to stick the solution strategy to the nodes and add the information required in runtime into graph as placeholder nodes. """ mod_graph = gm.graph - # TODO: In future PR, strategies_constructor should be a required argument, - # instead of optional argument. This is because we don't need to consider nodes with - # no strategy in runtime preparation pass. - if strategies_constructor is not None: - nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] - no_strategy_nodes = strategies_constructor.no_strategy_nodes - else: - nodes = tuple(mod_graph.nodes) - no_strategy_nodes = [] + + nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] + no_strategy_nodes = strategies_constructor.no_strategy_nodes # the dict to get origin sharding spec of node origin_node_sharding_spec_dict = {} @@ -97,6 +92,7 @@ def _solution_annotatation(gm: torch.fx.GraphModule, target_sharding_specs.append(target_sharding_spec) sharding_spec_convert_dict[index] = target_sharding_specs setattr(node, 'target_sharding_specs', target_sharding_specs) + # the get_attr node strategy is kind of pending strategy, which means we will change it # to the same strategy of the user node. if node.op == 'get_attr': @@ -134,7 +130,7 @@ def _solution_annotatation(gm: torch.fx.GraphModule, return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict -def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): +def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): """ In the auto parallel system, tensors may get shard on different devices, so the size of tensors need to be converted to the size of original tensor and managed by the users, such as torch.view, @@ -145,6 +141,80 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): nodes = tuple(mod_graph.nodes) node_pairs = {} + # DeviceMesh information instructs the scaling of the size value + device_mesh_info = {} + for dim, dim_size in enumerate(device_mesh.mesh_shape): + device_mesh_info[dim] = dim_size + + def _extract_target_dim(node): + ''' + A helper function to etract the target dimension from size node. + There are two usages of torch.Tensor.size: + 1. tensor.size() + 2. tensor.size(dim) + + If a target_dim is assigned, then the output will be in type of int, instead of torch.Size. + Otherwise, the output will be in type of torch.Size and this function will return None. + ''' + target_dim = None + if len(node.args) > 1: + target_dim = node.args[1] + if target_dim < 0: + target_dim += node.args[0]._meta_data.dim() + return target_dim + + def _post_processing(node, size_processing_node): + ''' + This function is used to process the dependency between the size node and its users after + inserting the size_process_node. + ''' + # store original node and processing node pair in node_pairs dictioanry + # It will be used to replace the original node with processing node in slice object + node_pairs[node] = size_processing_node + size_processing_node._meta_data = node._meta_data + if 'activation_checkpoint' in node.meta: + size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint'] + + user_list = list(node.users.keys()) + for user in user_list: + if user == size_processing_node: + continue + new_args = list(user.args) + new_kwargs = dict(user.kwargs) + # the origin node may be a positional argument or key word argument of user node + if node in new_args: + # substitute the origin node with size_processing_node + new_args[new_args.index(node)] = size_processing_node + user.args = tuple(new_args) + elif str(node) in new_kwargs: + # substitute the origin node with size_processing_node + new_kwargs[str(node)] = size_processing_node + user.kwargs = new_kwargs + + def _update_slice_object_args(slice_object): + ''' + This function is used to update the slice object argument list. + If the slice object contains the Node argument, then the size node will be replaced with + ''' + if isinstance(slice_object, slice): + start = slice_object.start + stop = slice_object.stop + step = slice_object.step + if start in node_pairs: + start = node_pairs[start] + if stop in node_pairs: + stop = node_pairs[stop] + if step in node_pairs: + step = node_pairs[step] + return slice(start, stop, step) + elif isinstance(slice_object, int): + if slice_object in node_pairs: + return node_pairs[slice_object] + else: + return slice_object + else: + raise RuntimeError(f"Unsupported slice object type: {type(slice_object)}") + for node in nodes: if node.op == 'call_method' and node.target == 'size': @@ -154,49 +224,15 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): sharding_spec = node.args[0].sharding_spec dim_partition_dict = sharding_spec.dim_partition_dict - # there are two usages of torch.Tensor.size: - # tensor.size() - # tensor.size(dim) - # if a target_dim is assigned, then the output will be - # in type of int, instead of torch.Size - target_dim = None - if len(node.args) > 1: - target_dim = node.args[1] - if target_dim < 0: - target_dim += node.args[0]._meta_data.dim() - - # DeviceMesh information instructs the scaling of the size value - device_mesh_info = {} - for dim, dim_size in enumerate(device_mesh.mesh_shape): - device_mesh_info[dim] = dim_size + target_dim = _extract_target_dim(node) + # insert size_processing node with mod_graph.inserting_after(node): size_processing_node = mod_graph.create_node('call_function', size_processing, args=(node, dim_partition_dict, device_mesh_info, target_dim, node.name)) - # store original node and processing node pair in node_pairs dictioanry - # It will be used to replace the original node with processing node in slice object - node_pairs[node] = size_processing_node - size_processing_node._meta_data = node._meta_data - if 'activation_checkpoint' in node.meta: - size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint'] - - user_list = list(node.users.keys()) - for user in user_list: - if user == size_processing_node: - continue - new_args = list(user.args) - new_kwargs = dict(user.kwargs) - # the origin node may be a positional argument or key word argument of user node - if node in new_args: - # substitute the origin node with size_processing_node - new_args[new_args.index(node)] = size_processing_node - user.args = tuple(new_args) - elif str(node) in new_kwargs: - # substitute the origin node with size_processing_node - new_kwargs[str(node)] = size_processing_node - user.kwargs = new_kwargs + _post_processing(node, size_processing_node) if node.op == 'call_function' and node.target == operator.getitem: @@ -217,14 +253,7 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): # In this pass, we need process the last two cases because # node arguments may potentially appear in these cases. if isinstance(getitem_index, slice): - new_start, new_stop, new_step = getitem_index.start, getitem_index.stop, getitem_index.step - if getitem_index.start in node_pairs: - new_start = node_pairs[getitem_index.start] - elif getitem_index.stop in node_pairs: - new_stop = node_pairs[getitem_index.stop] - elif getitem_index.step in node_pairs: - new_step = node_pairs[getitem_index.step] - new_slice_item = slice(new_start, new_stop, new_step) + new_slice_item = _update_slice_object_args(getitem_index) new_args = (node.args[0], new_slice_item) node.args = new_args @@ -237,16 +266,7 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): if slice_item is None: new_slice_items.append(None) continue - - new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step - - if slice_item.start in node_pairs: - new_start = node_pairs[slice_item.start] - elif slice_item.stop in node_pairs: - new_stop = node_pairs[slice_item.stop] - elif slice_item.step in node_pairs: - new_step = node_pairs[slice_item.step] - new_slice_item = slice(new_start, new_stop, new_step) + new_slice_item = _update_slice_object_args(slice_item) new_slice_items.append(new_slice_item) new_args = (node.args[0], tuple(new_slice_items)) @@ -255,104 +275,109 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): return gm -def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): +def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): """ This pass will process node args to adapt the distributed tensor layout. """ mod_graph = gm.graph nodes = tuple(mod_graph.nodes) + def _extract_info_from_sharding_spec(sharding_spec): + ''' + This function is used to extract the dim_partition_dict and device_mesh from + sharding spec instance or a list of sharding spec. + ''' + if isinstance(sharding_spec, ShardingSpec): + dim_partition_dict = sharding_spec.dim_partition_dict + device_mesh = sharding_spec.device_mesh + return dim_partition_dict, device_mesh + if sharding_spec is None: + return None, None + assert isinstance(sharding_spec, + (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None' + + device_mesh = sharding_spec[0].device_mesh + dim_partition_dict = [] + for element in sharding_spec: + dim_partition_dict.append(_extract_info_from_sharding_spec(element)) + return dim_partition_dict, sharding_spec + + def _process_node_arguments(node): + new_args = [] + for arg in node.args: + # There are two args style: + # 1. (input, *shape) + # 2. (input, shape) + # We will extract the elements from shape and add them into the new_args + # Finally, the args style of new_args will be unified to (input, *shape) + if isinstance(arg, Node): + if isinstance(arg._meta_data, (tuple, list)): + new_args.extend(arg._meta_data) + elif isinstance(arg._meta_data, int): + new_args.append(arg._meta_data) + else: + new_args.append(arg) + else: + assert isinstance(arg, + (int, tuple, list)), 'The argument in view node should be either type of Node or int.' + if isinstance(arg, (tuple, list)): + new_args.extend(arg) + else: + new_args.append(arg) + return new_args + + def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node): + new_args = _process_node_arguments(node) + if node.op == 'call_method': + args_to_process = list(new_args[1:]) + else: + args_to_process = list(new_args) + for dim, shard_dims in dim_partition_dict.items(): + total_shard_size = 1 + for shard_dim in shard_dims: + total_shard_size *= device_mesh.shape[shard_dim] + + # we will skip the dim with -1 value + if args_to_process[dim] == -1: + continue + else: + # TODO: add assertion here to make sure the dim size is divisible by total_shard_size + args_to_process[dim] //= total_shard_size + + args_to_process = tuple(args_to_process) + + if node.op == 'call_method': + new_args = (new_args[0],) + args_to_process + else: + new_args = args_to_process + + node.args = new_args + + def _filter_node_with_shape_args(node): + if node.op == 'call_method': + target = getattr(node.args[0]._meta_data.__class__, node.target) + elif node.op == 'call_function': + target = node.target + else: + target = None + + if target in SHAPE_ARGUMENT_OPS: + return True + return False + for node in nodes: # skip the placeholder node added in _solution_annotation pass if not hasattr(node, 'sharding_spec'): continue - def _process_sharding_spec(sharding_spec): - if isinstance(sharding_spec, ShardingSpec): - dim_partition_dict = sharding_spec.dim_partition_dict - device_mesh = sharding_spec.device_mesh - return dim_partition_dict, device_mesh - if sharding_spec is None: - return None, None - assert isinstance(sharding_spec, - (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None' - - device_mesh = sharding_spec[0].device_mesh - dim_partition_dict = [] - for element in sharding_spec: - dim_partition_dict.append(_process_sharding_spec(element)) - return dim_partition_dict, sharding_spec - - output_dim_partition_dict, device_mesh = _process_sharding_spec(node.sharding_spec) - new_args = [] - - if node.op == 'call_method': - method = getattr(node.args[0]._meta_data.__class__, node.target) - # process the node with (input, *shape) style args - if method in (torch.Tensor.view, torch.Tensor.reshape): - - for arg in node.args: - if isinstance(arg, Node): - if isinstance(arg._meta_data, (int, tuple, list)): - new_args.append(arg._meta_data) - else: - new_args.append(arg) - else: - assert isinstance( - arg, (int, tuple, list)), 'The argument in view node should be either type of Node or int.' - new_args.append(arg) - - for dim, shard_dims in output_dim_partition_dict.items(): - total_shard_size = 1 - for shard_dim in shard_dims: - total_shard_size *= device_mesh.shape[shard_dim] - # There are two ways to use torch.view: - # 1. torch.view(input, *shape) - # 2. torch.view(input, shape) - if isinstance(new_args[1], int): - # we will skip the dim with -1 value - if new_args[dim + 1] == -1: - continue - else: - new_args[dim + 1] //= total_shard_size - else: - new_args[1] = list(new_args[1]) - # we will skip the dim with -1 value - if new_args[1][dim] == -1: - continue - else: - new_args[1][dim] //= total_shard_size - node.args = tuple(new_args) - - elif node.op == 'call_function': - target = node.target - # process the node with (input, torch.Size) style args - if target in (torch.reshape,): - for arg in node.args: - if isinstance(arg, Node): - if isinstance(arg._meta_data, (tuple, list)): - new_args.append(list(arg._meta_data)) - else: - new_args.append(arg) - else: - assert isinstance( - arg, (tuple, list)), 'The argument in reshape node should be either type of Node or tuple.' - new_args.append(list(arg)) - - for dim, shard_dims in output_dim_partition_dict.items(): - # we will skip the dim with -1 value - if new_args[1][dim] == -1: - continue - total_shard_size = 1 - for shard_dim in shard_dims: - total_shard_size *= device_mesh.shape[shard_dim] - new_args[1][dim] //= total_shard_size - node.args = tuple(new_args) + output_dim_partition_dict, device_mesh = _extract_info_from_sharding_spec(node.sharding_spec) + if _filter_node_with_shape_args(node): + _scale_args_adapt_sharding_spec(output_dim_partition_dict, device_mesh, node) return gm -def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, overlap=False): +def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, overlap=False): """ Apply the sharding action to the module parameters and buffers following the instructions of solver solution. @@ -361,6 +386,49 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o nodes = tuple(mod_graph.nodes) # This stream is created for overlaping the communication and computation. reduction_stream = torch.cuda.Stream() + + def _add_hook_for_grad_communication(node, param): + + comm_actions = node.best_strategy.communication_actions + + def _filter_param_to_hook(node, op_data, comm_action): + if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == param.name and comm_action.comm_type == CommType.HOOK: + return True + if node.op == 'get_attr' and isinstance( + node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK: + return True + return False + + for operation_data, comm_action in comm_actions.items(): + comm_spec_to_use = comm_action.comm_spec + # register hook to the parameters + if _filter_param_to_hook(node, operation_data, comm_action): + + def wrapper(param, comm_spec, stream, overlap): + + def hook_fn(grad): + if overlap: + with torch.cuda.stream(stream): + _all_reduce(grad, comm_spec, async_op=True) + else: + _all_reduce(grad, comm_spec, async_op=False) + + param.register_hook(hook_fn) + + wrapper(param, comm_spec_to_use, reduction_stream, overlap=overlap) + + def _shard_param(param, target_sharding_spec): + # apply the sharding spec of parameters + if target_sharding_spec.dim_partition_dict != {}: + origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {}) + setattr(param, 'sharding_spec', origin_sharding_spec) + # TODO: build a ColoParamter class to manager the distributed parameters + # we could use .data here, because all the operations just happen before the real training + # loop, so we don't need to track these operations in the autograd graph. + param = torch.nn.Parameter( + shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec, + target_sharding_spec).detach().clone()) + for node in nodes: if node.op == 'call_module': target_module = node.graph.owning_module.get_submodule(node.target) @@ -370,36 +438,10 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o setattr(target_module, 'processed', True) for name, param in target_module.named_parameters(): target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name) - # apply the sharding spec of parameters - if target_sharding_spec.dim_partition_dict != {}: - origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {}) - setattr(param, 'sharding_spec', origin_sharding_spec) - # TODO: build a ColoParamter class to manager the distributed parameters - # we could use .data here, because all the operations just happen before the real training - # loop, so we don't need to track these operations in the autograd graph. - param = torch.nn.Parameter( - shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec, - target_sharding_spec).detach().clone()) + _shard_param(param, target_sharding_spec) setattr(target_module, name, param) - comm_actions = node.best_strategy.communication_actions - for operation_data, comm_action in comm_actions.items(): - comm_spec_to_use = comm_action.comm_spec - # register hook to the parameters - if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK: - - def wrapper(param, comm_spec, stream, overlap): - - def hook_fn(grad): - if overlap: - with torch.cuda.stream(stream): - _all_reduce(grad, comm_spec, async_op=True) - else: - _all_reduce(grad, comm_spec, async_op=False) - - param.register_hook(hook_fn) - - wrapper(param, comm_spec_to_use, reduction_stream, overlap=overlap) + _add_hook_for_grad_communication(node, param) sharded_buffer_dict = {} # apply the sharding spec of buffers @@ -427,37 +469,12 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o target = getattr(target_module, atoms[-1]) target_sharding_spec = node.sharding_spec - if target_sharding_spec.dim_partition_dict != {}: - origin_sharding_spec = ShardingSpec(device_mesh, target.shape, {}) - setattr(target, 'sharding_spec', origin_sharding_spec) - # TODO: build a ColoParamter class to manager the distributed parameters - # we could use .data here, because all the operations just happen before the real training - # loop, so we don't need to track these operations in the autograd graph. - target = torch.nn.Parameter( - shape_consistency_manager.apply_for_autoparallel_runtime(target.data, target.sharding_spec, - target_sharding_spec).detach().clone()) + _shard_param(target, target_sharding_spec) assert hasattr(target_module, atoms[-1]) setattr(target_module, atoms[-1], target) + _add_hook_for_grad_communication(node, target) - comm_actions = node.best_strategy.communication_actions - for operation_data, comm_action in comm_actions.items(): - comm_spec_to_use = comm_action.comm_spec - # register hook to the parameters - if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK: - - def wrapper(param, comm_spec, stream, overlap): - - def hook_fn(grad): - if overlap: - with torch.cuda.stream(stream): - _all_reduce(grad, comm_spec, async_op=True) - else: - _all_reduce(grad, comm_spec, async_op=False) - - param.register_hook(hook_fn) - - wrapper(target, comm_spec_to_use, reduction_stream, overlap=overlap) return gm @@ -471,14 +488,14 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule): def runtime_preparation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh: DeviceMesh, - strategies_constructor: StrategiesConstructor = None, + strategies_constructor: StrategiesConstructor, overlap=False): - gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation( + gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotatation_pass( gm, solution, strategies_constructor) - gm = _size_value_converting(gm, device_mesh) - gm = _node_args_converting(gm, device_mesh) + gm = size_value_converting_pass(gm, device_mesh) + gm = node_args_converting_pass(gm, device_mesh) # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed. # gm = implicit_comm_action_apply(gm) - gm = _module_params_sharding(gm, device_mesh, overlap=overlap) + gm = module_params_sharding_pass(gm, device_mesh, overlap=overlap) return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict diff --git a/tests/test_auto_parallel/test_pass/test_node_converting_pass.py b/tests/test_auto_parallel/test_pass/test_node_converting_pass.py new file mode 100644 index 000000000..d0d107610 --- /dev/null +++ b/tests/test_auto_parallel/test_pass/test_node_converting_pass.py @@ -0,0 +1,54 @@ +import torch +import torch.nn.functional as F + +from colossalai.auto_parallel.passes.runtime_preparation_pass import node_args_converting_pass +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.tracer import ColoTracer +from colossalai.tensor.sharding_spec import ShardingSpec + + +class TestModule(torch.nn.Module): + + def forward(self, x): + x = x.view(4, 4, 2) + return x + + +def insert_narrow(gm, x_node): + graph = gm.graph + with graph.inserting_after(x_node): + shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={}) + view_node = list(x_node.users.keys())[0] + new_args = list(view_node.args) + new_args[0] = shard_node + view_node.args = tuple(new_args) + return gm + + +def test_node_args_converting_pass(): + model = TestModule() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + meta_args = {'x': torch.rand(4, 8).to('meta')} + input = torch.rand(4, 8) + tracer = ColoTracer() + graph = tracer.trace(root=model, meta_args=meta_args) + + x_node = list(graph.nodes)[0] + view_node = list(graph.nodes)[1] + sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]}) + setattr(x_node, 'sharding_spec', sharding_spec) + setattr(view_node, 'sharding_spec', sharding_spec) + + gm = ColoGraphModule(model, graph) + gm = node_args_converting_pass(gm, device_mesh) + gm = insert_narrow(gm, x_node) + gm.recompile() + output = gm(input) + assert output.shape == torch.Size([2, 4, 2]) + + +if __name__ == '__main__': + test_node_args_converting_pass() diff --git a/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py b/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py new file mode 100644 index 000000000..349483008 --- /dev/null +++ b/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py @@ -0,0 +1,65 @@ +import torch +import torch.nn.functional as F + +from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.tracer import ColoTracer +from colossalai.tensor.sharding_spec import ShardingSpec + + +class TestModule(torch.nn.Module): + + def forward(self, x): + size = x.size() + return size + + +def insert_narrow(gm, x_node): + graph = gm.graph + with graph.inserting_after(x_node): + shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={}) + size_node = list(x_node.users.keys())[0] + size_node.args = (shard_node,) + return gm + + +def recover_narrow(gm, narrow_node): + graph = gm.graph + size_node = list(graph.nodes)[2] + x_node = narrow_node.args[0] + size_node.args = (x_node,) + graph.erase_node(narrow_node) + return gm + + +def test_size_value_converting_pass(): + model = TestModule() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + meta_args = {'x': torch.rand(4, 8).to('meta')} + input = torch.rand(4, 8) + tracer = ColoTracer() + graph = tracer.trace(root=model, meta_args=meta_args) + + x_node = list(graph.nodes)[0] + x_sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]}) + setattr(x_node, 'sharding_spec', x_sharding_spec) + gm = ColoGraphModule(model, graph) + gm = insert_narrow(gm, x_node) + gm.recompile() + size = gm(input) + assert size == torch.Size([2, 8]) + + narrow_node = list(gm.graph.nodes)[1] + gm = recover_narrow(gm, narrow_node) + gm = size_value_converting_pass(gm, device_mesh) + gm = insert_narrow(gm, x_node) + gm.recompile() + size = gm(input) + assert size == torch.Size([4, 8]) + + +if __name__ == '__main__': + test_size_value_converting_pass() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py index 3d268ea43..18afacf56 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py @@ -1,12 +1,9 @@ -from faulthandler import disable from functools import partial -from xml.dom import WrongDocumentErr import pytest import torch import torch.multiprocessing as mp import torch.nn as nn -from typing_extensions import Self from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import (