From 632753abbc62f15638f8595d0822e138c02fd684 Mon Sep 17 00:00:00 2001
From: Ziyue Jiang
Date: Fri, 25 Nov 2022 17:42:48 +0800
Subject: [PATCH] [fx]Split partition with DAG information (#2025)

* add DAG to split_module

* add comment

* add test case for DAG

* remove print

Co-authored-by: Ziyue Jiang
---
 colossalai/fx/passes/split_module.py     |  63 +++++--
 colossalai/fx/passes/utils.py            | 175 +++++++++++++++++-
 .../test_pipeline/test_DAG/dag_utils.py  |  85 +++++++++
 .../test_pipeline/test_DAG/test_dag.py   |  31 ++++
 4 files changed, 326 insertions(+), 28 deletions(-)
 create mode 100644 tests/test_fx/test_pipeline/test_DAG/dag_utils.py
 create mode 100644 tests/test_fx/test_pipeline/test_DAG/test_dag.py

diff --git a/colossalai/fx/passes/split_module.py b/colossalai/fx/passes/split_module.py
index 8671855f4..48a76660d 100644
--- a/colossalai/fx/passes/split_module.py
+++ b/colossalai/fx/passes/split_module.py
@@ -3,6 +3,7 @@ from torch.fx.graph_module import GraphModule
 from typing import Callable, List, Dict, Any, Optional
 from torch.fx._compatibility import compatibility
 from packaging import version
+from colossalai.fx.passes.utils import get_DAG
 
 import inspect
 
@@ -38,11 +39,11 @@ def split_module(
     m: GraphModule,
     root_m: torch.nn.Module,
     split_callback: Callable[[torch.fx.node.Node], int],
+    merge_output = False,
 ):
     """
     Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py
     Creates subgraphs out of main graph
-
     Args:
         m (GraphModule): Graph module to split
         root_m (torch.nn.Module): root nn module. Not currently used. Included
@@ -52,52 +53,40 @@ def split_module(
             that maps a given Node instance to a numeric partition identifier.
             split_module will use this function as the policy for which operations
             appear in which partitions in the output Module.
-
     Returns:
         GraphModule: the module after split.
-
     Example:
-
         This is a sample setup:
-
             import torch
             from torch.fx.symbolic_trace import symbolic_trace
             from torch.fx.graph_module import GraphModule
             from torch.fx.node import Node
             from colossalai.fx.passes.split_module import split_module
-
             class MyModule(torch.nn.Module):
                 def __init__(self):
                     super().__init__()
                     self.param = torch.nn.Parameter(torch.rand(3, 4))
                     self.linear = torch.nn.Linear(4, 5)
-
                 def forward(self, x, y):
                     z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
                     w = self.linear(y).clamp(min=0.0, max=1.0)
                     return z + w
-
             # symbolically trace model
             my_module = MyModule()
             my_module_traced = symbolic_trace(my_module)
-
             # random mod partitioning
             partition_counter = 0
             NPARTITIONS = 3
-
             def mod_partition(node: Node):
                 global partition_counter
                 partition = partition_counter % NPARTITIONS
                 partition_counter = (partition_counter + 1) % NPARTITIONS
                 return partition
-
             # split module in module with submodules
             module_with_submodules = split_module(
                 my_module_traced, my_module, mod_partition
             )
-
         Output looks like this. Original graph is broken into partitions
-
             > print(module_with_submodules)
             GraphModule(
                 (submod_0): GraphModule(
                     (linear): Linear(in_features=4, out_features=5, bias=True)
                 )
                 (submod_1): GraphModule(
                     (linear): Linear(in_features=4, out_features=5, bias=True)
                 )
                 (submod_2): GraphModule()
             )
-
             def forward(self, x, y):
                 param = self.param
                 submod_0 = self.submod_0(x, param, y);  x = param = y = None
                 getitem = submod_0[0]
                 getitem_1 = submod_0[1]
                 getitem_2 = submod_0[2];  submod_0 = None
                 submod_1 = self.submod_1(getitem, getitem_1, getitem_2);  getitem = getitem_1 = getitem_2 = None
                 getitem_3 = submod_1[1];  submod_1 = None
                 submod_2 = self.submod_2(getitem_2, getitem_3);  getitem_2 = getitem_3 = None
                 return submod_2
-
         Output of split module is the same as output of input traced module.
        This is an example within a test setting:
-
            > orig_out = my_module_traced(x, y)
            > submodules_out = module_with_submodules(x, y)
            > self.assertEqual(orig_out, submodules_out)
@@ -147,6 +133,29 @@ def split_module(
                use_partition.inputs.setdefault(def_node.name)
                if def_partition_name is not None:
                    use_partition.partitions_dependent_on.setdefault(def_partition_name)
+
+    def record_output(
+            def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]
+    ):    # noqa: B950
+        def_partition_name = getattr(def_node, "_fx_partition", None)
+        use_partition_name = getattr(use_node, "_fx_partition", None)
+        if def_partition_name != use_partition_name:
+            if def_partition_name is not None:
+                def_partition = partitions[def_partition_name]
+                def_partition.outputs.setdefault(def_node.name)
+                if use_partition_name is not None:
+                    def_partition.partition_dependents.setdefault(use_partition_name)
+
+            if use_partition_name is not None:
+                use_partition = partitions[use_partition_name]
+                use_partition.inputs.setdefault(def_node.name)
+                if def_partition_name is not None:
+                    use_partition.partitions_dependent_on.setdefault(def_partition_name)
+                use_partition.outputs.setdefault(def_node.name)
+        else:
+            if use_partition_name is not None:
+                use_partition = partitions[use_partition_name]
+                use_partition.outputs.setdefault(def_node.name)
 
     # split nodes into parititons
     for node in m.graph.nodes:
@@ -155,7 +164,10 @@ def split_module(
         if node.op in ["placeholder"]:
             continue
         if node.op == 'output':
-            torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None))
+            if merge_output:
+                torch.fx.graph.map_arg(node.args[0], lambda n: record_output(n, node.prev))
+            else:
+                torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None))
             continue
 
         partition_name = str(split_callback(node))
@@ -235,10 +247,10 @@ def split_module(
     for node in m.graph.nodes:
         if node.op == 'placeholder':
             if version.parse(torch.__version__) < version.parse('1.11.0'):
-                base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
+                base_mod_env[node.name] = base_mod_graph.placeholder(node.target, type_expr=node.type)
             else:
                 default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
-                base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
+                base_mod_env[node.name] = base_mod_graph.placeholder(node.target,
                                                                      type_expr=node.type,
                                                                      default_value=default_value)
             base_mod_env[node.name].meta = node.meta.copy()
@@ -278,4 +290,15 @@ def split_module(
         if node.op == 'output':
             base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name]))    # noqa: B950
 
-    return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
+    for partition_name in sorted_partitions:
+        partition = partitions[partition_name]
+
+    new_gm = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
+
+    DAG = get_DAG(new_gm)
+
+    for _, submodule in new_gm.named_modules():
+        if isinstance(submodule, torch.fx.GraphModule):
+            setattr(submodule, '_DAG', DAG)
+
+    return new_gm
diff --git a/colossalai/fx/passes/utils.py b/colossalai/fx/passes/utils.py
index 842c9d52e..b4d3d2086 100644
--- a/colossalai/fx/passes/utils.py
+++ b/colossalai/fx/passes/utils.py
@@ -2,7 +2,7 @@ import torch
 from typing import Dict, Set
 from torch.fx.node import Node, map_arg
 from torch.fx.graph import Graph
-
+from torch.fx.graph_module import GraphModule
 
 def get_comm_size(prev_partition, next_partition):
     """
@@ -32,7 +32,6 @@ def get_leaf(graph: Graph):
     """
     Given a graph, return leaf nodes of this graph.
-
     Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output`` nodes from fx graph,
     we will get a normal DAG. Leaf nodes in this context means leaf nodes in that DAG.
     """
@@ -57,7 +56,6 @@ def is_leaf(graph: Graph, node: Node):
 def get_top(graph: Graph):
     """
     Given a graph, return top nodes of this graph.
-
     Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output`` nodes from fx graph,
     we will get a normal DAG. Top nodes in this context means nodes with BFS level 0 in that DAG.
     """
@@ -100,7 +98,6 @@ def get_all_consumers(graph: Graph, node: Node):
 def assign_bfs_level_to_nodes(graph: Graph):
     """
     Give a graph, assign bfs level to each node of this graph excluding ``placeholder`` and ``output`` nodes.
-
     Example:
         class MLP(torch.nn.Module):
             def __init__(self, dim: int):
                 super().__init__()
                 self.linear1 = torch.nn.Linear(dim, dim)
                 self.linear2 = torch.nn.Linear(dim, dim)
@@ -110,8 +107,6 @@ def assign_bfs_level_to_nodes(graph: Graph):
                 self.linear3 = torch.nn.Linear(dim, dim)
                 self.linear4 = torch.nn.Linear(dim, dim)
                 self.linear5 = torch.nn.Linear(dim, dim)
-
-
             def forward(self, x):
                 l1 = self.linear1(x)
                 l2 = self.linear2(x)
@@ -165,10 +160,8 @@ def assign_bfs_level_to_nodes(graph: Graph):
 def get_node_module(node) -> torch.nn.Module:
     """
     Find the module associated with the given node.
-
     Args:
         node (torch.fx.Node): a torch.fx.Node object in the fx computation graph
-
     Returns:
         torch.nn.Module: the module associated with the given node
     """
@@ -177,3 +170,169 @@ def get_node_module(node) -> torch.nn.Module:
     assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}'
     module = node.graph.owning_module.get_submodule(node.target)
     return module
+
+def find_def_in_partition(node, partitions, input_partitions=None, direct=False):
+    # find def in input
+    if input_partitions is not None:
+        for placeholder in input_partitions:
+            if placeholder.name == node.name:
+                return 'MODEL_INPUT'
+
+    # find direct def
+    if direct:
+        for partition in partitions:
+            if node == partition:
+                return partition.name
+    # find def with getitem call
+    else:
+        for partition in partitions:
+            if node in partition.users.keys():
+                return partition.name
+
+    print(f'Not found def in partition {node.name}')
+    return None
+
+def find_user_in_partition(node, partitions, output_partitions=None, direct=False):
+    user_partition_names = []
+    # find direct user
+    if direct:
+        for partition in partitions:
+            if node == partition:
+                user_partition_names.append(partition.name)
+    # find user with getitem call
+    else:
+        for partition in partitions:
+            if node in partition.args:
+                user_partition_names.append(partition.name)
+
+    is_output = False
+    def find_output(def_node, output_node):
+        nonlocal is_output
+        if def_node == output_node:
+            is_output = True
+
+    if output_partitions is not None:
+        output_node = output_partitions[0]
+        torch.fx.graph.map_arg(output_node.args[0], lambda n: find_output(node, n))
+
+    if is_output:
+        user_partition_names.append('MODEL_OUTPUT')
+
+    if len(user_partition_names) > 0:
+        return user_partition_names
+
+    print(f'Not found user in partition {node.name}')
+    return None
+
+def get_partition_depends(partition, partitions, input_partitions=None, output_partitions=None):
+    # e.g. Partition2: {input: {Partition0: [sub1_1], Partition1: [sub2_0]}, output:{Output: [sub3_0]}},
+    input = {}
+    output = {}
+
+    for offset, arg in enumerate(partition.args):
+        def_partition_name = None
+        if not arg.name.startswith('getitem'):
+            def_partition_name = find_def_in_partition(arg, partitions, input_partitions, direct=True)
+        else:
+            def_partition_name = find_def_in_partition(arg, partitions, input_partitions, direct=False)
+        if def_partition_name is None:
+            continue
+        if def_partition_name not in input:
+            input[def_partition_name] = []
+        input[def_partition_name].append(offset)
+
+    offset = -1
+    for user in partition.users.keys():
+        user_partition_names = None
+        if input_partitions is None or not user.name.startswith('getitem'):
+            user_partition_names = find_user_in_partition(user, partitions, output_partitions, direct=True)
+            offset = 0
+        else:
+            user_partition_names = find_user_in_partition(user, partitions, output_partitions, direct=False)
+            offset += 1
+        if user_partition_names is None:
+            continue
+        for user_partition_name in user_partition_names:
+            if user_partition_name not in output:
+                output[user_partition_name] = []
+            output[user_partition_name].append(offset)
+
+    return input, output, offset+1
+
+# DAG just looks like following case.
+# the int in every list represents the offset of the partition's input arg or output arg.
+# {
+#     'input_partition': {
+#         'input_ids': {
+#             'input': {},
+#             'output': {'submod_0': [0], 'submod_1': [1]},
+#             'output_len': 0},
+#         'attention_mask': {
+#             'input': {},
+#             'output': {'submod_2': [0]},
+#             'output_len': 0}},
+#     'submod_0': {
+#         'input': {'MODEL_INPUT': [0]},
+#         'output': {'submod_1': [0], 'submod_2': [0, 1]},
+#         'output_len': 2},
+#     'submod_1': {
+#         'input': {'submod_0': [0], 'MODEL_INPUT': [1]},
+#         'output': {'submod_2': [0]},
+#         'output_len': 1},
+#     'submod_2': {
+#         'input': {'MODEL_INPUT': [0], 'submod_0': [1, 2]},
+#         'output': {'submod_3': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+#                                 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
+#                                 22, 23, 24]},
+#         'output_len': 25},
+#     'submod_3': {
+#         'input': {'submod_2': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+#                                12, 13, 14, 15, 16, 17, 18, 19, 20,
+#                                21, 22, 23, 24]},
+#         'output': {'MODEL_OUTPUT': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
+#                                     11, 12, 13, 14, 15, 16, 17, 18, 19,
+#                                     20, 21, 22, 23, 24]},
+#         'output_len': 25},
+#     'output_partition': {
+#         'input': {'logits': 'submod_3', 'past_key_values': (('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
+#                                                             ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
+#                                                             ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
+#                                                             ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
+#                                                             ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
+#                                                             ('submod_3', 'submod_3'), ('submod_3', 'submod_3'))},
+#         'output': {}, 'output_len': 0}
+# }
+
+# TODO(jiangziyue) Define a Class for DAG.
+def get_DAG(gm: GraphModule):
+    DAG = {}
+    input_partitions = []
+    partitions = []
+    output_partitions = []
+    for node in gm.graph.nodes:
+        if node.op == 'placeholder':
+            input_partitions.append(node)
+        elif node.name.startswith('submod_'):
+            partitions.append(node)
+        elif node.op == 'output':
+            output_partitions.append(node)
+
+    for partition in input_partitions:
+        DAG_node = {'input': {}, 'output': {}, 'output_len': 1}
+        _, output, _ = get_partition_depends(partition, partitions, None, output_partitions)
+        DAG_node['output'] = output
+        if 'input_partition' not in DAG:
+            DAG['input_partition'] = {}
+        DAG['input_partition'][partition.name] = DAG_node
+
+    for partition in partitions:
+        DAG_node = {'input': {}, 'output': {}}
+        DAG_node['input'], DAG_node['output'], DAG_node['output_len'] = get_partition_depends(partition, partitions, input_partitions, output_partitions)
+        DAG[partition.name] = DAG_node
+
+    for partition in output_partitions:
+        DAG_node = {'input': {}, 'output': {}, 'output_len': 0}
+        DAG_node['input'] = torch.fx.graph.map_arg(partition.args[0], lambda n: find_def_in_partition(n, partitions, input_partitions))
+        DAG['output_partition'] = DAG_node
+
+    return DAG
\ No newline at end of file
diff --git a/tests/test_fx/test_pipeline/test_DAG/dag_utils.py b/tests/test_fx/test_pipeline/test_DAG/dag_utils.py
new file mode 100644
index 000000000..104296fb1
--- /dev/null
+++ b/tests/test_fx/test_pipeline/test_DAG/dag_utils.py
@@ -0,0 +1,85 @@
+import torch
+from torch.fx import GraphModule
+from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
+from colossalai.fx import ColoTracer
+import random
+import numpy as np
+
+MANUAL_SEED = 0
+random.seed(MANUAL_SEED)
+np.random.seed(MANUAL_SEED)
+torch.manual_seed(MANUAL_SEED)
+
+def split_model_and_get_DAG(model, data_gen):
+    model.eval()
+
+    # generate input sample
+    kwargs = data_gen()
+
+    # get origin output and rng state
+    cpu_rng_state = torch.get_rng_state()
+    output = model(**kwargs)
+
+    # tracing model
+    tracer = ColoTracer()
+    try:
+        meta_args = {k: v.to('meta') for k, v in kwargs.items()}
+        graph = tracer.trace(root=model, meta_args=meta_args)
+    except Exception as e:
+        raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm.recompile()
+
+    # apply transform passes
+    annotated_model = balanced_split_pass(gm, 2)
+    top_module, split_submodules = split_with_split_nodes_pass(annotated_model)
+
+    return top_module, split_submodules[0]._DAG
+
+def check_input(input, input_node, top_module):
+    for user in input_node.users.keys():
+        partition_name = user.name
+        assert partition_name in input['output']
+
+def check_submod(submod_partition, node, top_module):
+    for arg in node.args:
+        input_part_name = None
+        if arg.op == 'placeholder':
+            input_part_name = 'MODEL_INPUT'
+        elif not arg.name.startswith('getitem'):
+            input_part_name = arg.name
+        else:
+            input_part_name = arg.args[0].name
+        assert input_part_name in submod_partition['input']
+
+    for user in node.users:
+        output_part_names = []
+        if user.op == 'output':
+            output_part_names.append('MODEL_OUTPUT')
+        elif not user.name.startswith('getitem'):
+            output_part_names.append(user.name)
+        else:
+            for n in user.users:
+                if n.op == 'output':
+                    output_part_names.append('MODEL_OUTPUT')
+                else:
+                    output_part_names.append(n.name)
+
+        for output_part_name in output_part_names:
+            assert output_part_name in submod_partition['output']
+
+def check_DAG(top_module, DAG):
+    assert 'input_partition' in DAG
+    input_partition = DAG['input_partition']
+
+    for node in top_module.graph.nodes:
+        # check input
+        if node.op == 'placeholder':
+            assert node.name in input_partition
+            input = input_partition[node.name]
+            check_input(input, node, top_module)
+        elif node.op == 'call_module':
+            assert node.name in DAG
+            submod_partition = DAG[node.name]
+            check_submod(submod_partition, node, top_module)
+    
\ No newline at end of file
diff --git a/tests/test_fx/test_pipeline/test_DAG/test_dag.py b/tests/test_fx/test_pipeline/test_DAG/test_dag.py
new file mode 100644
index 000000000..7f7caa36e
--- /dev/null
+++ b/tests/test_fx/test_pipeline/test_DAG/test_dag.py
@@ -0,0 +1,31 @@
+import pytest
+import torch
+import transformers
+from dag_utils import split_model_and_get_DAG, check_DAG
+
+BATCH_SIZE = 1
+SEQ_LENGTH = 16
+
+
+@pytest.mark.skip('balance split v2 is not ready')
+def test_opt():
+    MODEL_LIST = [
+        transformers.OPTModel,
+        #transformers.OPTForCausalLM,
+    ]
+
+    config = transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4)
+
+    def data_gen():
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
+        return kwargs
+
+    for model_cls in MODEL_LIST:
+        model = model_cls(config=config)
+        top_mod, DAG = split_model_and_get_DAG(model, data_gen)
+        check_DAG(top_mod, DAG)
+
+if __name__ == '__main__':
+    test_opt()
\ No newline at end of file
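
The DAG that get_DAG() builds (and that split_module attaches to every resulting GraphModule as `_DAG`) is a plain dict in the layout documented in the comment block added to utils.py. Below is a minimal, self-contained sketch of how such a dict can be walked; the abbreviated DAG literal is a hypothetical example in that layout, not output produced by this patch.

# Sketch only: iterate a DAG dict in the documented layout and report the wiring
# between model inputs, submodule partitions, and the model output.
DAG = {
    'input_partition': {
        'input_ids': {'input': {}, 'output': {'submod_0': [0]}, 'output_len': 0},
        'attention_mask': {'input': {}, 'output': {'submod_1': [1]}, 'output_len': 0},
    },
    'submod_0': {'input': {'MODEL_INPUT': [0]}, 'output': {'submod_1': [0]}, 'output_len': 1},
    'submod_1': {'input': {'submod_0': [0], 'MODEL_INPUT': [1]}, 'output': {'MODEL_OUTPUT': [0]}, 'output_len': 1},
}

def describe_dag(dag):
    # Model inputs: which partitions consume each placeholder, and at which arg offset.
    for name, entry in dag.get('input_partition', {}).items():
        for consumer, offsets in entry['output'].items():
            print(f"model input '{name}' feeds {consumer} at arg offset(s) {offsets}")
    # Submodule partitions: where each positional input comes from, where outputs go.
    for name, entry in dag.items():
        if name in ('input_partition', 'output_partition'):
            continue
        for producer, offsets in entry['input'].items():
            print(f"{name} takes arg offset(s) {offsets} from {producer}")
        for consumer, offsets in entry['output'].items():
            print(f"{name} sends output offset(s) {offsets} to {consumer}")
        print(f"{name} produces {entry['output_len']} output value(s)")

describe_dag(DAG)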
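For end-to-end context, the new tests obtain the DAG roughly as sketched below (condensed from dag_utils.py and test_dag.py). This assumes the same ColoTracer, balanced_split_pass and split_with_split_nodes_pass APIs the tests import; since the test itself is skipped pending balanced split v2, treat this as an illustration of the intended flow rather than a verified recipe.

import torch
import transformers
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass

# Build a small OPT model, trace it with meta inputs, annotate a balanced
# two-stage split, then split it; split_module attaches the DAG to every
# resulting GraphModule as `_DAG`.
config = transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4)
model = transformers.OPTModel(config=config)

kwargs = dict(input_ids=torch.zeros((1, 16), dtype=torch.int64),
              attention_mask=torch.zeros((1, 16), dtype=torch.int64))
tracer = ColoTracer()
graph = tracer.trace(root=model, meta_args={k: v.to('meta') for k, v in kwargs.items()})
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()

annotated = balanced_split_pass(gm, 2)
top_module, split_submodules = split_with_split_nodes_pass(annotated)

DAG = split_submodules[0]._DAG
print(DAG.keys())    # e.g. 'input_partition', 'submod_0', ..., 'output_partition'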