[fx]Split partition with DAG information (#2025)

* add DAG to split_module

* add comment

* add test case for DAG

* remove print

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
This commit is contained in:
Ziyue Jiang
2022-11-25 17:42:48 +08:00
committed by GitHub
parent ea0f6b8df9
commit 632753abbc
4 changed files with 326 additions and 28 deletions

View File

@@ -3,6 +3,7 @@ from torch.fx.graph_module import GraphModule
from typing import Callable, List, Dict, Any, Optional
from torch.fx._compatibility import compatibility
from packaging import version
from colossalai.fx.passes.utils import get_DAG
import inspect
@@ -38,11 +39,11 @@ def split_module(
m: GraphModule,
root_m: torch.nn.Module,
split_callback: Callable[[torch.fx.node.Node], int],
merge_output = False,
):
"""
Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py
Creates subgraphs out of main graph
Args:
m (GraphModule): Graph module to split
root_m (torch.nn.Module): root nn module. Not currently used. Included
@@ -52,52 +53,40 @@ def split_module(
that maps a given Node instance to a numeric partition identifier.
split_module will use this function as the policy for which operations
appear in which partitions in the output Module.
Returns:
GraphModule: the module after split.
Example:
This is a sample setup:
import torch
from torch.fx.symbolic_trace import symbolic_trace
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from colossalai.fx.passes.split_module import split_module
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x, y):
z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
w = self.linear(y).clamp(min=0.0, max=1.0)
return z + w
# symbolically trace model
my_module = MyModule()
my_module_traced = symbolic_trace(my_module)
# random mod partitioning
partition_counter = 0
NPARTITIONS = 3
def mod_partition(node: Node):
global partition_counter
partition = partition_counter % NPARTITIONS
partition_counter = (partition_counter + 1) % NPARTITIONS
return partition
# split module in module with submodules
module_with_submodules = split_module(
my_module_traced, my_module, mod_partition
)
Output looks like this. Original graph is broken into partitions
> print(module_with_submodules)
GraphModule(
(submod_0): GraphModule(
@@ -108,7 +97,6 @@ def split_module(
)
(submod_2): GraphModule()
)
def forward(self, x, y):
param = self.param
submod_0 = self.submod_0(x, param, y); x = param = y = None
@@ -119,10 +107,8 @@ def split_module(
getitem_3 = submod_1[1]; submod_1 = None
submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None
return submod_2
Output of split module is the same as output of input traced module.
This is an example within a test setting:
> orig_out = my_module_traced(x, y)
> submodules_out = module_with_submodules(x, y)
> self.assertEqual(orig_out, submodules_out)
@@ -147,6 +133,29 @@ def split_module(
use_partition.inputs.setdefault(def_node.name)
if def_partition_name is not None:
use_partition.partitions_dependent_on.setdefault(def_partition_name)
def record_output(
def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]
): # noqa: B950
def_partition_name = getattr(def_node, "_fx_partition", None)
use_partition_name = getattr(use_node, "_fx_partition", None)
if def_partition_name != use_partition_name:
if def_partition_name is not None:
def_partition = partitions[def_partition_name]
def_partition.outputs.setdefault(def_node.name)
if use_partition_name is not None:
def_partition.partition_dependents.setdefault(use_partition_name)
if use_partition_name is not None:
use_partition = partitions[use_partition_name]
use_partition.inputs.setdefault(def_node.name)
if def_partition_name is not None:
use_partition.partitions_dependent_on.setdefault(def_partition_name)
use_partition.outputs.setdefault(def_node.name)
else:
if use_partition_name is not None:
use_partition = partitions[use_partition_name]
use_partition.outputs.setdefault(def_node.name)
# split nodes into parititons
for node in m.graph.nodes:
@@ -155,7 +164,10 @@ def split_module(
if node.op in ["placeholder"]:
continue
if node.op == 'output':
torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None))
if merge_output:
torch.fx.graph.map_arg(node.args[0], lambda n: record_output(n, node.prev))
else:
torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None))
continue
partition_name = str(split_callback(node))
@@ -235,10 +247,10 @@ def split_module(
for node in m.graph.nodes:
if node.op == 'placeholder':
if version.parse(torch.__version__) < version.parse('1.11.0'):
base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
base_mod_env[node.name] = base_mod_graph.placeholder(node.target, type_expr=node.type)
else:
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
base_mod_env[node.name] = base_mod_graph.placeholder(node.target,
type_expr=node.type,
default_value=default_value)
base_mod_env[node.name].meta = node.meta.copy()
@@ -278,4 +290,15 @@ def split_module(
if node.op == 'output':
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
for partition_name in sorted_partitions:
partition = partitions[partition_name]
new_gm = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
DAG = get_DAG(new_gm)
for _, submodule in new_gm.named_modules():
if isinstance(submodule, torch.fx.GraphModule):
setattr(submodule, '_DAG', DAG)
return new_gm