[fx]Split partition with DAG information (#2025)

* add DAG to split_module

* add comment

* add test case for DAG

* remove print

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Ziyue Jiang 2022-11-25 17:42:48 +08:00 committed by GitHub
parent ea0f6b8df9
commit 632753abbc
4 changed files with 326 additions and 28 deletions
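
In practice the pieces added here fit together roughly as follows. This is a minimal sketch, not code from the commit: it reuses my_module, my_module_traced and mod_partition from the split_module docstring example reproduced in the first diff below, and only shows where the new merge_output flag and the _DAG attribute appear.

    import torch
    from colossalai.fx.passes.split_module import split_module

    # merge_output=True makes nodes that feed the model output go through the new
    # record_output helper instead of record_cross_partition_use.
    module_with_submodules = split_module(my_module_traced, my_module, mod_partition, merge_output=True)

    # split_module now attaches the DAG built by get_DAG to every GraphModule it produces.
    for name, submod in module_with_submodules.named_modules():
        if isinstance(submod, torch.fx.GraphModule) and hasattr(submod, '_DAG'):
            print(name or '<root>', list(submod._DAG.keys()))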

View File

@@ -3,6 +3,7 @@ from torch.fx.graph_module import GraphModule
from typing import Callable, List, Dict, Any, Optional
from torch.fx._compatibility import compatibility
from packaging import version
from colossalai.fx.passes.utils import get_DAG
import inspect
@@ -38,11 +39,11 @@ def split_module(
    m: GraphModule,
    root_m: torch.nn.Module,
    split_callback: Callable[[torch.fx.node.Node], int],
    merge_output = False,
):
    """
    Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py
    Creates subgraphs out of main graph
    Args:
        m (GraphModule): Graph module to split
        root_m (torch.nn.Module): root nn module. Not currently used. Included
@@ -52,52 +53,40 @@ def split_module(
            that maps a given Node instance to a numeric partition identifier.
            split_module will use this function as the policy for which operations
            appear in which partitions in the output Module.
    Returns:
        GraphModule: the module after split.
    Example:
        This is a sample setup:
            import torch
            from torch.fx.symbolic_trace import symbolic_trace
            from torch.fx.graph_module import GraphModule
            from torch.fx.node import Node
            from colossalai.fx.passes.split_module import split_module
            class MyModule(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.param = torch.nn.Parameter(torch.rand(3, 4))
                    self.linear = torch.nn.Linear(4, 5)
                def forward(self, x, y):
                    z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
                    w = self.linear(y).clamp(min=0.0, max=1.0)
                    return z + w
            # symbolically trace model
            my_module = MyModule()
            my_module_traced = symbolic_trace(my_module)
            # random mod partitioning
            partition_counter = 0
            NPARTITIONS = 3
            def mod_partition(node: Node):
                global partition_counter
                partition = partition_counter % NPARTITIONS
                partition_counter = (partition_counter + 1) % NPARTITIONS
                return partition
            # split module in module with submodules
            module_with_submodules = split_module(
                my_module_traced, my_module, mod_partition
            )
        Output looks like this. Original graph is broken into partitions
            > print(module_with_submodules)
            GraphModule(
                (submod_0): GraphModule(
@@ -108,7 +97,6 @@ def split_module(
                )
                (submod_2): GraphModule()
            )
            def forward(self, x, y):
                param = self.param
                submod_0 = self.submod_0(x, param, y); x = param = y = None
@@ -119,10 +107,8 @@ def split_module(
                getitem_3 = submod_1[1]; submod_1 = None
                submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None
                return submod_2
        Output of split module is the same as output of input traced module.
        This is an example within a test setting:
            > orig_out = my_module_traced(x, y)
            > submodules_out = module_with_submodules(x, y)
            > self.assertEqual(orig_out, submodules_out)
@@ -148,6 +134,29 @@ def split_module(
            if def_partition_name is not None:
                use_partition.partitions_dependent_on.setdefault(def_partition_name)
    def record_output(
        def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]
    ): # noqa: B950
        def_partition_name = getattr(def_node, "_fx_partition", None)
        use_partition_name = getattr(use_node, "_fx_partition", None)
        if def_partition_name != use_partition_name:
            if def_partition_name is not None:
                def_partition = partitions[def_partition_name]
                def_partition.outputs.setdefault(def_node.name)
                if use_partition_name is not None:
                    def_partition.partition_dependents.setdefault(use_partition_name)

            if use_partition_name is not None:
                use_partition = partitions[use_partition_name]
                use_partition.inputs.setdefault(def_node.name)
                if def_partition_name is not None:
                    use_partition.partitions_dependent_on.setdefault(def_partition_name)
                use_partition.outputs.setdefault(def_node.name)
        else:
            if use_partition_name is not None:
                use_partition = partitions[use_partition_name]
                use_partition.outputs.setdefault(def_node.name)
    # split nodes into partitions
    for node in m.graph.nodes:
        orig_nodes[node.name] = node
@@ -155,7 +164,10 @@ def split_module(
        if node.op in ["placeholder"]:
            continue
        if node.op == 'output':
            if merge_output:
                torch.fx.graph.map_arg(node.args[0], lambda n: record_output(n, node.prev))
            else:
                torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None))
            continue
        partition_name = str(split_callback(node))
@@ -235,10 +247,10 @@ def split_module(
    for node in m.graph.nodes:
        if node.op == 'placeholder':
            if version.parse(torch.__version__) < version.parse('1.11.0'):
                base_mod_env[node.name] = base_mod_graph.placeholder(node.target, type_expr=node.type)
            else:
                default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
                base_mod_env[node.name] = base_mod_graph.placeholder(node.target,
                                                                     type_expr=node.type,
                                                                     default_value=default_value)
            base_mod_env[node.name].meta = node.meta.copy()
@@ -278,4 +290,15 @@ def split_module(
        if node.op == 'output':
            base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950

    for partition_name in sorted_partitions:
        partition = partitions[partition_name]

    new_gm = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)

    DAG = get_DAG(new_gm)

    for _, submodule in new_gm.named_modules():
        if isinstance(submodule, torch.fx.GraphModule):
            setattr(submodule, '_DAG', DAG)

    return new_gm

View File

@@ -2,7 +2,7 @@ import torch
from typing import Dict, Set
from torch.fx.node import Node, map_arg
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule

def get_comm_size(prev_partition, next_partition):
    """
@@ -32,7 +32,6 @@ def get_comm_size(prev_partition, next_partition):
def get_leaf(graph: Graph):
    """
    Given a graph, return leaf nodes of this graph.
    Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output`` nodes from fx graph,
    we will get a normal DAG. Leaf nodes in this context means leaf nodes in that DAG.
    """
@@ -57,7 +56,6 @@ def is_leaf(graph: Graph, node: Node):
def get_top(graph: Graph):
    """
    Given a graph, return top nodes of this graph.
    Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output`` nodes from fx graph,
    we will get a normal DAG. Top nodes in this context means nodes with BFS level 0 in that DAG.
    """
@@ -100,7 +98,6 @@ def get_all_consumers(graph: Graph, node: Node):
def assign_bfs_level_to_nodes(graph: Graph):
    """
    Given a graph, assign bfs level to each node of this graph excluding ``placeholder`` and ``output`` nodes.
    Example:
        class MLP(torch.nn.Module):
            def __init__(self, dim: int):
@@ -110,8 +107,6 @@ def assign_bfs_level_to_nodes(graph: Graph):
                self.linear3 = torch.nn.Linear(dim, dim)
                self.linear4 = torch.nn.Linear(dim, dim)
                self.linear5 = torch.nn.Linear(dim, dim)
            def forward(self, x):
                l1 = self.linear1(x)
                l2 = self.linear2(x)
@@ -165,10 +160,8 @@ def assign_bfs_level_to_nodes(graph: Graph):
def get_node_module(node) -> torch.nn.Module:
    """
    Find the module associated with the given node.
    Args:
        node (torch.fx.Node): a torch.fx.Node object in the fx computation graph
    Returns:
        torch.nn.Module: the module associated with the given node
    """
@@ -177,3 +170,169 @@ def get_node_module(node) -> torch.nn.Module:
    assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}'
    module = node.graph.owning_module.get_submodule(node.target)
    return module
def find_def_in_partition(node, partitions, input_partitions=None, direct=False):
    # find def in input
    if input_partitions is not None:
        for placeholder in input_partitions:
            if placeholder.name == node.name:
                return 'MODEL_INPUT'

    # find direct def
    if direct:
        for partition in partitions:
            if node == partition:
                return partition.name
    # find def with getitem call
    else:
        for partition in partitions:
            if node in partition.users.keys():
                return partition.name

    print(f'Not found def in partition {node.name}')
    return None
def find_user_in_partition(node, partitions, output_partitions=None, direct=False):
    user_partition_names = []
    # find direct user
    if direct:
        for partition in partitions:
            if node == partition:
                user_partition_names.append(partition.name)
    # find user with getitem call
    else:
        for partition in partitions:
            if node in partition.args:
                user_partition_names.append(partition.name)

    is_output = False

    def find_output(def_node, output_node):
        nonlocal is_output
        if def_node == output_node:
            is_output = True

    if output_partitions is not None:
        output_node = output_partitions[0]
        torch.fx.graph.map_arg(output_node.args[0], lambda n: find_output(node, n))

    if is_output:
        user_partition_names.append('MODEL_OUTPUT')

    if len(user_partition_names) > 0:
        return user_partition_names

    print(f'Not found user in partition {node.name}')
    return None
def get_partition_depends(partition, partitions, input_partitions=None, output_partitions=None):
    # e.g. Partition2: {input: {Partition0: [sub1_1], Partition1: [sub2_0]}, output:{Output: [sub3_0]}},
    input = {}
    output = {}

    for offset, arg in enumerate(partition.args):
        def_partition_name = None
        if not arg.name.startswith('getitem'):
            def_partition_name = find_def_in_partition(arg, partitions, input_partitions, direct=True)
        else:
            def_partition_name = find_def_in_partition(arg, partitions, input_partitions, direct=False)
        if def_partition_name is None:
            continue
        if def_partition_name not in input:
            input[def_partition_name] = []
        input[def_partition_name].append(offset)

    offset = -1
    for user in partition.users.keys():
        user_partition_names = None
        if input_partitions is None or not user.name.startswith('getitem'):
            user_partition_names = find_user_in_partition(user, partitions, output_partitions, direct=True)
            offset = 0
        else:
            user_partition_names = find_user_in_partition(user, partitions, output_partitions, direct=False)
            offset += 1
        if user_partition_names is None:
            continue
        for user_partition_name in user_partition_names:
            if user_partition_name not in output:
                output[user_partition_name] = []
            output[user_partition_name].append(offset)

    return input, output, offset+1
# DAG just looks like following case.
# the int in every list represents the offset of the partition's input arg or output arg.
# {
# 'input_partition': {
# 'input_ids': {
# 'input': {},
# 'output': {'submod_0': [0], 'submod_1': [1]},
# 'output_len': 0},
# 'attention_mask': {
# 'input': {},
# 'output': {'submod_2': [0]},
# 'output_len': 0}},
# 'submod_0': {
# 'input': {'MODEL_INPUT': [0]},
# 'output': {'submod_1': [0], 'submod_2': [0, 1]},
# 'output_len': 2},
# 'submod_1': {
# 'input': {'submod_0': [0], 'MODEL_INPUT': [1]},
# 'output': {'submod_2': [0]},
# 'output_len': 1},
# 'submod_2': {
# 'input': {'MODEL_INPUT': [0], 'submod_0': [1, 2]},
# 'output': {'submod_3': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
# 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
# 22, 23, 24]},
# 'output_len': 25},
# 'submod_3': {
# 'input': {'submod_2': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
# 12, 13, 14, 15, 16, 17, 18, 19, 20,
# 21, 22, 23, 24]},
# 'output': {'MODEL_OUTPUT': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
# 11, 12, 13, 14, 15, 16, 17, 18, 19,
# 20, 21, 22, 23, 24]},
# 'output_len': 25},
# 'output_partition': {
# 'input': {'logits': 'submod_3', 'past_key_values': (('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
# ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
# ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
# ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
# ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
# ('submod_3', 'submod_3'), ('submod_3', 'submod_3'))},
# 'output': {}, 'output_len': 0}
# }
# TODO(jiangziyue) Define a Class for DAG.
def get_DAG(gm: GraphModule):
    DAG = {}
    input_partitions = []
    partitions = []
    output_partitions = []
    for node in gm.graph.nodes:
        if node.op == 'placeholder':
            input_partitions.append(node)
        elif node.name.startswith('submod_'):
            partitions.append(node)
        elif node.op == 'output':
            output_partitions.append(node)

    for partition in input_partitions:
        DAG_node = {'input': {}, 'output': {}, 'output_len': 1}
        _, output, _ = get_partition_depends(partition, partitions, None, output_partitions)
        DAG_node['output'] = output
        if 'input_partition' not in DAG:
            DAG['input_partition'] = {}
        DAG['input_partition'][partition.name] = DAG_node

    for partition in partitions:
        DAG_node = {'input': {}, 'output': {}}
        DAG_node['input'], DAG_node['output'], DAG_node['output_len'] = get_partition_depends(partition, partitions, input_partitions, output_partitions)
        DAG[partition.name] = DAG_node

    for partition in output_partitions:
        DAG_node = {'input': {}, 'output': {}, 'output_len': 0}
        DAG_node['input'] = torch.fx.graph.map_arg(partition.args[0], lambda n: find_def_in_partition(n, partitions, input_partitions))
        DAG['output_partition'] = DAG_node

    return DAG
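
To make the layout in the comment block above concrete, a dictionary produced by get_DAG can be walked like this. describe_DAG is a hypothetical helper added here for illustration only; it is not part of this commit.

    def describe_DAG(DAG):
        # model inputs: which partitions consume each placeholder, and at which argument offsets
        for name, entry in DAG.get('input_partition', {}).items():
            for consumer, offsets in entry['output'].items():
                print(f'model input {name} -> {consumer} (arg offsets {offsets})')
        # submodule partitions: where their arguments come from and where their outputs go
        for name, entry in DAG.items():
            if name in ('input_partition', 'output_partition'):
                continue
            for producer, offsets in entry['input'].items():
                print(f'{name} takes arg offsets {offsets} from {producer}')
            for consumer, offsets in entry['output'].items():
                print(f'{name} sends output offsets {offsets} to {consumer}')
            print(f'{name} has output_len = {entry["output_len"]}')

    # usage: describe_DAG(split_submodule._DAG)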

View File

@@ -0,0 +1,85 @@
import torch
from torch.fx import GraphModule
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from colossalai.fx import ColoTracer
import random
import numpy as np
MANUAL_SEED = 0
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)
def split_model_and_get_DAG(model, data_gen):
    model.eval()

    # generate input sample
    kwargs = data_gen()

    # get origin output and rng state
    cpu_rng_state = torch.get_rng_state()
    output = model(**kwargs)

    # tracing model
    tracer = ColoTracer()
    try:
        meta_args = {k: v.to('meta') for k, v in kwargs.items()}
        graph = tracer.trace(root=model, meta_args=meta_args)
    except Exception as e:
        raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # apply transform passes
    annotated_model = balanced_split_pass(gm, 2)
    top_module, split_submodules = split_with_split_nodes_pass(annotated_model)

    return top_module, split_submodules[0]._DAG
def check_input(input, input_node, top_module):
    for user in input_node.users.keys():
        partition_name = user.name
        assert partition_name in input['output']
def check_submod(submod_partition, node, top_module):
    for arg in node.args:
        input_part_name = None
        if arg.op == 'placeholder':
            input_part_name = 'MODEL_INPUT'
        elif not arg.name.startswith('getitem'):
            input_part_name = arg.name
        else:
            input_part_name = arg.args[0].name
        assert input_part_name in submod_partition['input']

    for user in node.users:
        output_part_names = []
        if user.op == 'output':
            output_part_names.append('MODEL_OUTPUT')
        elif not user.name.startswith('getitem'):
            output_part_names.append(user.name)
        else:
            for n in user.users:
                if n.op == 'output':
                    output_part_names.append('MODEL_OUTPUT')
                else:
                    output_part_names.append(n.name)

        for output_part_name in output_part_names:
            assert output_part_name in submod_partition['output']
def check_DAG(top_module, DAG):
    assert 'input_partition' in DAG
    input_partition = DAG['input_partition']

    for node in top_module.graph.nodes:
        # check input
        if node.op == 'placeholder':
            assert node.name in input_partition
            input = input_partition[node.name]
            check_input(input, node, top_module)
        elif node.op == 'call_module':
            assert node.name in DAG
            submod_partition = DAG[node.name]
            check_submod(submod_partition, node, top_module)

View File

@@ -0,0 +1,31 @@
import pytest
import torch
import transformers
from dag_utils import split_model_and_get_DAG, check_DAG
BATCH_SIZE = 1
SEQ_LENGHT = 16
@pytest.mark.skip('balance split v2 is not ready')
def test_opt():
    MODEL_LIST = [
        transformers.OPTModel,
        #transformers.OPTForCausalLM,
    ]

    config = transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4)

    def data_gen():
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        return kwargs

    for model_cls in MODEL_LIST:
        model = model_cls(config=config)
        top_mod, DAG = split_model_and_get_DAG(model, data_gen)
        check_DAG(top_mod, DAG)


if __name__ == '__main__':
    test_opt()
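
Since the pytest skip marker is in place, the test only runs when the file is executed directly. When debugging, dumping the DAG and comparing it against the layout documented in the utils diff above can help; the snippet below is an illustrative addition, assuming dag_utils.py is importable from the working directory as in this test.

    import pprint
    import torch
    import transformers
    from dag_utils import split_model_and_get_DAG

    def _data_gen():
        input_ids = torch.zeros((1, 16), dtype=torch.int64)
        attention_mask = torch.zeros((1, 16), dtype=torch.int64)
        return dict(input_ids=input_ids, attention_mask=attention_mask)

    _config = transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4)
    _, _DAG = split_model_and_get_DAG(transformers.OPTModel(config=_config), _data_gen)
    pprint.pprint(_DAG)    # keys: 'input_partition', 'submod_0', ..., 'output_partition'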