[fx]Split partition with DAG information (#2025)

* add DAG to split_module

* add comment

* add test case for DAG

* remove print

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Author: Ziyue Jiang
Date: 2022-11-25 17:42:48 +08:00 (committed by GitHub)
parent ea0f6b8df9
commit 632753abbc
4 changed files with 326 additions and 28 deletions


@@ -2,7 +2,7 @@ import torch
from typing import Dict, Set
from torch.fx.node import Node, map_arg
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
def get_comm_size(prev_partition, next_partition):
    """
@@ -32,7 +32,6 @@ def get_comm_size(prev_partition, next_partition):
def get_leaf(graph: Graph):
    """
    Given a graph, return the leaf nodes of this graph.
    Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output`` nodes from the fx graph,
    we get a normal DAG. Leaf nodes in this context mean leaf nodes in that DAG.
    """
@@ -57,7 +56,6 @@ def is_leaf(graph: Graph, node: Node):
def get_top(graph: Graph):
    """
    Given a graph, return the top nodes of this graph.
    Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output`` nodes from the fx graph,
    we get a normal DAG. Top nodes in this context mean nodes at BFS level 0 in that DAG.
    """
@@ -100,7 +98,6 @@ def get_all_consumers(graph: Graph, node: Node):
def assign_bfs_level_to_nodes(graph: Graph):
    """
    Given a graph, assign a BFS level to each node of this graph, excluding ``placeholder`` and ``output`` nodes.
    Example:
        class MLP(torch.nn.Module):
            def __init__(self, dim: int):
@@ -110,8 +107,6 @@ def assign_bfs_level_to_nodes(graph: Graph):
                self.linear3 = torch.nn.Linear(dim, dim)
                self.linear4 = torch.nn.Linear(dim, dim)
                self.linear5 = torch.nn.Linear(dim, dim)
            def forward(self, x):
                l1 = self.linear1(x)
                l2 = self.linear2(x)
@@ -165,10 +160,8 @@ def assign_bfs_level_to_nodes(graph: Graph):
def get_node_module(node) -> torch.nn.Module:
    """
    Find the module associated with the given node.
    Args:
        node (torch.fx.Node): a torch.fx.Node object in the fx computation graph
    Returns:
        torch.nn.Module: the module associated with the given node
    """
@@ -177,3 +170,169 @@ def get_node_module(node) -> torch.nn.Module:
    assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}'
    module = node.graph.owning_module.get_submodule(node.target)
    return module

def find_def_in_partition(node, partitions, input_partitions=None, direct=False):
    # find the def among the model inputs
    if input_partitions is not None:
        for placeholder in input_partitions:
            if placeholder.name == node.name:
                return 'MODEL_INPUT'

    # find a direct def: the node is a partition node itself
    if direct:
        for partition in partitions:
            if node == partition:
                return partition.name
    # find a def behind a getitem call: the node slices a multi-output
    # partition, so its def is the partition whose users include it
    else:
        for partition in partitions:
            if node in partition.users.keys():
                return partition.name

    print(f'No def of {node.name} found in any partition')
    return None

def find_user_in_partition(node, partitions, output_partitions=None, direct=False):
    user_partition_names = []

    # find direct users: the user is a partition node itself
    if direct:
        for partition in partitions:
            if node == partition:
                user_partition_names.append(partition.name)
    # find users behind getitem calls: the node is a getitem slice, and its
    # users are the partitions that take it as an argument
    else:
        for partition in partitions:
            if node in partition.args:
                user_partition_names.append(partition.name)

    # check whether the node flows into the model output
    is_output = False

    def find_output(def_node, output_node):
        nonlocal is_output
        if def_node == output_node:
            is_output = True

    if output_partitions is not None:
        output_node = output_partitions[0]
        torch.fx.graph.map_arg(output_node.args[0], lambda n: find_output(node, n))
    if is_output:
        user_partition_names.append('MODEL_OUTPUT')

    if len(user_partition_names) > 0:
        return user_partition_names

    print(f'No user of {node.name} found in any partition')
    return None
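
# Illustration (not part of this commit): whether a lookup is ``direct`` depends on how
# fx materializes partition results. A multi-output partition is read through getitem
# nodes, while a single-output partition is read directly:
#
#     sub1 = submod_1(x)    # 'submod_1' returns a tuple
#     a = sub1[0]           # node 'getitem' -> find_def_in_partition(a, partitions) == 'submod_1'
#     b = submod_0(x)       # single value   -> find_def_in_partition(b, partitions, direct=True) == 'submod_0'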

def get_partition_depends(partition, partitions, input_partitions=None, output_partitions=None):
    # e.g. Partition2: {input: {Partition0: [sub1_1], Partition1: [sub2_0]}, output: {Output: [sub3_0]}}
    input = {}
    output = {}

    # map every input arg of this partition to the partition (or model input) defining it
    for offset, arg in enumerate(partition.args):
        def_partition_name = None
        if not arg.name.startswith('getitem'):
            def_partition_name = find_def_in_partition(arg, partitions, input_partitions, direct=True)
        else:
            def_partition_name = find_def_in_partition(arg, partitions, input_partitions, direct=False)
        if def_partition_name is None:
            continue
        if def_partition_name not in input:
            input[def_partition_name] = []
        input[def_partition_name].append(offset)

    # map every output of this partition to the partitions (or the model output) consuming it
    offset = -1
    for user in partition.users.keys():
        user_partition_names = None
        if input_partitions is None or not user.name.startswith('getitem'):
            user_partition_names = find_user_in_partition(user, partitions, output_partitions, direct=True)
            offset = 0
        else:
            user_partition_names = find_user_in_partition(user, partitions, output_partitions, direct=False)
            offset += 1
        if user_partition_names is None:
            continue
        for user_partition_name in user_partition_names:
            if user_partition_name not in output:
                output[user_partition_name] = []
            output[user_partition_name].append(offset)

    return input, output, offset + 1

# A DAG looks like the following example.
# Each int in a list is the offset of the corresponding partition's input or output arg.
# {
# 'input_partition': {
# 'input_ids': {
# 'input': {},
# 'output': {'submod_0': [0], 'submod_1': [1]},
# 'output_len': 0},
# 'attention_mask': {
# 'input': {},
# 'output': {'submod_2': [0]},
# 'output_len': 0}},
# 'submod_0': {
# 'input': {'MODEL_INPUT': [0]},
# 'output': {'submod_1': [0], 'submod_2': [0, 1]},
# 'output_len': 2},
# 'submod_1': {
# 'input': {'submod_0': [0], 'MODEL_INPUT': [1]},
# 'output': {'submod_2': [0]},
# 'output_len': 1},
# 'submod_2': {
# 'input': {'MODEL_INPUT': [0], 'submod_0': [1, 2]},
# 'output': {'submod_3': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
# 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
# 22, 23, 24]},
# 'output_len': 25},
# 'submod_3': {
# 'input': {'submod_2': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
# 12, 13, 14, 15, 16, 17, 18, 19, 20,
# 21, 22, 23, 24]},
# 'output': {'MODEL_OUTPUT': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
# 11, 12, 13, 14, 15, 16, 17, 18, 19,
# 20, 21, 22, 23, 24]},
# 'output_len': 25},
# 'output_partition': {
# 'input': {'logits': 'submod_3', 'past_key_values': (('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
# ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
# ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
# ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
# ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
# ('submod_3', 'submod_3'), ('submod_3', 'submod_3'))},
# 'output': {}, 'output_len': 0}
# }
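#
# Illustration (hypothetical consumer, not part of this commit): a pipeline
# scheduler could read inter-partition data flow straight from this dict, e.g.:
#
#     deps = DAG['submod_1']['input']    # {'submod_0': [0], 'MODEL_INPUT': [1]}
#     for producer, arg_offsets in deps.items():
#         ...    # arg_offsets index into submod_1's argument list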
# TODO(jiangziyue) Define a Class for DAG.
def get_DAG(gm: GraphModule):
    DAG = {}
    input_partitions = []
    partitions = []
    output_partitions = []

    # collect placeholders, submod_* call nodes and the output node
    for node in gm.graph.nodes:
        if node.op == 'placeholder':
            input_partitions.append(node)
        elif node.name.startswith('submod_'):
            partitions.append(node)
        elif node.op == 'output':
            output_partitions.append(node)

    # model inputs only produce data
    for partition in input_partitions:
        DAG_node = {'input': {}, 'output': {}, 'output_len': 1}
        _, output, _ = get_partition_depends(partition, partitions, None, output_partitions)
        DAG_node['output'] = output
        if 'input_partition' not in DAG:
            DAG['input_partition'] = {}
        DAG['input_partition'][partition.name] = DAG_node

    # submodule partitions both consume and produce data
    for partition in partitions:
        DAG_node = {'input': {}, 'output': {}}
        DAG_node['input'], DAG_node['output'], DAG_node['output_len'] = get_partition_depends(
            partition, partitions, input_partitions, output_partitions)
        DAG[partition.name] = DAG_node

    # the model output only consumes data
    for partition in output_partitions:
        DAG_node = {'input': {}, 'output': {}, 'output_len': 0}
        DAG_node['input'] = torch.fx.graph.map_arg(partition.args[0],
                                                   lambda n: find_def_in_partition(n, partitions, input_partitions))
        DAG['output_partition'] = DAG_node

    return DAG
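
For context, a minimal sketch of how get_DAG might be exercised, assuming it is importable from this module. The TwoHead model and split_policy below are illustrative assumptions, not the test case added by this commit:

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module

class TwoHead(torch.nn.Module):
    # toy model: one shared layer feeding two output heads

    def __init__(self, dim: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(dim, dim)
        self.head1 = torch.nn.Linear(dim, dim)
        self.head2 = torch.nn.Linear(dim, dim)

    def forward(self, x):
        y = self.linear1(x)
        return self.head1(y), self.head2(y)

model = TwoHead(16)
gm = symbolic_trace(model)

# partition policy: linear1 -> submod_0, both heads -> submod_1
def split_policy(node):
    return 0 if node.target == 'linear1' else 1

split_gm = split_module(gm, model, split_policy)
DAG = get_DAG(split_gm)
# DAG['submod_0']['output'] -> {'submod_1': [0]}
# DAG['submod_1']['output'] -> {'MODEL_OUTPUT': [0, 1]}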