From 52bc2dc271b4a0046cbe840875d298b56e49aa19 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Wed, 27 Jul 2022 13:40:54 +0800
Subject: [PATCH] [fx] update split module pass and add customized policy
 (#1373)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4500af6a9220ef7fe4d3c7b1daebd4c.

* [fx]update split module pass and add customized policy
---
 .../fx/passes/adding_split_node_pass.py      |   2 +
 colossalai/fx/passes/passes_for_gpt2_test.py | 103 ++++++++++++++----
 2 files changed, 85 insertions(+), 20 deletions(-)

diff --git a/colossalai/fx/passes/adding_split_node_pass.py b/colossalai/fx/passes/adding_split_node_pass.py
index e2ea6ec70..4013d79f7 100644
--- a/colossalai/fx/passes/adding_split_node_pass.py
+++ b/colossalai/fx/passes/adding_split_node_pass.py
@@ -61,6 +61,8 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
     for node in mod_graph.nodes:
         if pp_size <= 1:
             break
+        if 'pipe_split' in node.name:
+            continue
         accumulate_node_size += node.node_size
         if accumulate_node_size >= partition_size:
             accumulate_node_size = 0
diff --git a/colossalai/fx/passes/passes_for_gpt2_test.py b/colossalai/fx/passes/passes_for_gpt2_test.py
index d51a4de59..93aa7fb99 100644
--- a/colossalai/fx/passes/passes_for_gpt2_test.py
+++ b/colossalai/fx/passes/passes_for_gpt2_test.py
@@ -5,11 +5,45 @@ from torch.fx._compatibility import compatibility
 from packaging import version
 from colossalai.fx.passes.meta_info_prop import TensorMetadata
 import inspect
+from typing import List
 from colossalai.fx.passes.split_module import Partition
-from colossalai.fx.passes.adding_split_node_pass import pipe_split
+from colossalai.fx.passes.adding_split_node_pass import pipe_split, balanced_split_pass
 from torch.fx.node import Node
 
 
+def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, partition_list: List[int]):
+    '''
+    This pass is only used for the GPT-2 performance test; it may be moved into adding_split_node_pass.py and will be deprecated in the future.
+    '''
+    mod_graph = gm.graph
+    valid_children_size = 0
+    valid_children = []
+    for node in mod_graph.nodes:
+        if node.op == "call_module":
+            valid_children_size += 1
+            valid_children.append(node.target)
+    if valid_children_size < pp_size:
+        # If there are not enough valid children to shard, fall back to the balanced policy instead of the uniform policy.
+        return balanced_split_pass(gm, pp_size)
+    accumulate_layer_amount = 0
+    list_of_part = partition_list
+    part_index = 0
+    for node in mod_graph.nodes:
+        if pp_size <= 1:
+            break
+        if node.op == "call_module":
+            if node.target in valid_children:
+                accumulate_layer_amount += 1
+            if accumulate_layer_amount == list_of_part[part_index]:
+                part_index += 1
+                pp_size -= 1
+                with mod_graph.inserting_after(node):
+                    split_node = mod_graph.create_node('call_function', pipe_split)
+
+    gm.recompile()
+    return gm
+
+
 def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule):
     '''
     This pass will be used in gpt2 test, only a part of changes may be added into
@@ -25,21 +59,40 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
         gm.recompile()
         return gm
 
-    def eliminate_unused_outputs(gm, next_partition_placeholders):
+    def refill_outputs_and_placeholders(gm, next_partition_placeholders):
         '''
         This method is used to eliminate the outputs in previous partition which is unused in next partition.
+        The split module pass treats partitions as a DAG, but pipeline parallelism needs to treat them as a singly linked list.
+        The difference is that if an output of partition 0 is an input argument of partition 3, the DAG will not pass it
+        through partitions 1 and 2, whereas the singly linked list requires us to do so.
         '''
+        output_type = None
+        output_args = []
+        non_output_list = []
+        new_placeholder_list = []
         for node in gm.graph.nodes:
             if node.op == 'output':
                 output_type = node.args[0].__class__
-                output_args = list(node.args[0])
+                output_args.extend(list(node.args[0]))
                 for n in node.args[0]:
-                    if n.name not in next_partition_placeholders:
+                    if next_partition_placeholders and n not in next_partition_placeholders:
                         output_args.remove(n)
                 gm.graph.erase_node(node)
+            else:
+                non_output_list.append(node.name)
+        for node in next_partition_placeholders:
+            if node not in output_args:
+                output_args.append(node)
+        for node in output_args:
+            if node.name not in non_output_list:
+                gm.graph.placeholder(node.name)
+
+        for node in gm.graph.nodes:
+            if node.op == 'placeholder':
+                new_placeholder_list.append(node)
         gm.graph.output(output_type(output_args))
         gm.recompile()
-        return gm
+        return gm, new_placeholder_list
 
     def split_callback(n: torch.fx.Node):
         nonlocal part_idx
@@ -64,13 +117,22 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
         placeholder_dict[submodule] = []
         for node in submodule.graph.nodes:
             if node.op == 'placeholder':
-                placeholder_dict[submodule].append(node.name)
-
+                placeholder_dict[submodule].append(node)
+    output_dict = {}
+    for submodule in submodules:
+        output_dict[submodule] = []
+        for node in submodule.graph.nodes:
+            if node.op == 'output':
+                output_dict[submodule].append(node.name)
+    submodules.reverse()
     for index, submodule in enumerate(submodules):
-        if index >= len(submodules) - 1:
-            break
-        submodule = eliminate_unused_outputs(submodule, placeholder_dict[submodules[index + 1]])
+        if index == 0:
+            placeholder_list = []
+        else:
+            placeholder_list = placeholder_dict[submodules[index - 1]]
+        submodule, placeholder_dict[submodule] = refill_outputs_and_placeholders(submodule, placeholder_list)
         submodule.recompile()
+    split_mod.recompile()
 
     return split_mod, split_submodules
 
@@ -118,7 +180,7 @@ def split_module_for_gpt2_test(
         _gen_all_ancestors_set(node)
 
         for n in list(all_ancestors):
-            if n.op != 'placeholder':
+            if n.op != 'placeholder' and n._fx_partition > partition_name:
                 n._fx_partition = partition_name
 
     def record_cross_partition_use(def_node: torch.fx.node.Node,
@@ -126,14 +188,14 @@ def split_module_for_gpt2_test(
         def_partition_name = getattr(def_node, '_fx_partition', None)
         use_partition_name = getattr(use_node, '_fx_partition', None)
         if def_partition_name != use_partition_name:
-            if 'tensor_meta' in def_node.meta:
-                if not _node_with_all_tensor_element(def_node.meta['tensor_meta']):
-                    _move_all_ancestors_into_partition(use_node, def_partition_name)
-                    node_process_list.extend(use_node.all_input_nodes)
-                    node_process_list.extend(list(use_node.users))
-                    node_process_list.append(use_node)
+            # if 'tensor_meta' in def_node.meta:
+            #     if not _node_with_all_tensor_element(def_node.meta['tensor_meta']):
+            #         _move_all_ancestors_into_partition(use_node, def_partition_name)
+            #         node_process_list.extend(use_node.all_input_nodes)
+            #         node_process_list.extend(list(use_node.users))
+            #         node_process_list.append(use_node)
 
-                    return
+            #         return
 
             if def_partition_name is not None:
                 def_partition = partitions[def_partition_name]
@@ -231,10 +293,11 @@ def split_module_for_gpt2_test(
             new_node = partition.graph.create_node(op=node.op,
                                                    target=target,
                                                    args=gathered_args,
-                                                   kwargs=gathered_kwargs)
+                                                   kwargs=gathered_kwargs,
+                                                   name=node.name)
             new_node.meta = node.meta.copy()
             partition.environment[node] = new_node
-
+    assert 'add_85' in orig_nodes
     # Set up values to construct base module
     base_mod_env: Dict[str, torch.fx.node.Node] = {}
     base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
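
A minimal usage sketch of the two passes this patch touches, for orientation only. It assumes `traced_gm` is a torch.fx.GraphModule already traced from a GPT-2 model by the test harness (tracing is outside this patch), and the partition_list values below are purely illustrative, not taken from the patch:

    # Sketch only: wiring customized_split_pass_for_gpt2 together with
    # split_with_split_nodes_pass_for_gp2_test, roughly as the gpt2 test would.
    from colossalai.fx.passes.passes_for_gpt2_test import (customized_split_pass_for_gpt2,
                                                           split_with_split_nodes_pass_for_gp2_test)

    pp_size = 4
    # Cumulative call_module counts at which pipe_split nodes are inserted; if the graph
    # has fewer call_module nodes than pp_size, the pass falls back to balanced_split_pass.
    partition_list = [6, 12, 18]    # illustrative values

    annotated_gm = customized_split_pass_for_gpt2(traced_gm, pp_size, partition_list)    # traced_gm assumed
    split_mod, split_submodules = split_with_split_nodes_pass_for_gp2_test(annotated_gm)

    # Each pipeline stage runs one submodule; refill_outputs_and_placeholders re-exposes
    # the outputs of stage i as placeholders of stage i+1 so values flow stage to stage.
    for name, submodule in split_mod.named_children():
        print(name, sum(p.numel() for p in submodule.parameters()))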