[fx]Split partition with DAG information (#2025)
* add DAG to split_module
* add comment
* add test case for DAG
* remove print

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
@@ -3,6 +3,7 @@ from torch.fx.graph_module import GraphModule
 from typing import Callable, List, Dict, Any, Optional
 from torch.fx._compatibility import compatibility
 from packaging import version
+from colossalai.fx.passes.utils import get_DAG
 import inspect
 
 
@@ -38,11 +39,11 @@ def split_module(
     m: GraphModule,
     root_m: torch.nn.Module,
     split_callback: Callable[[torch.fx.node.Node], int],
+    merge_output = False,
 ):
     """
     Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py
     Creates subgraphs out of main graph
 
     Args:
         m (GraphModule): Graph module to split
         root_m (torch.nn.Module): root nn module. Not currently used. Included
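
Note: the new merge_output flag is consumed in the -155 hunk below, where values feeding the graph's output node are recorded through the new record_output helper instead of as generic cross-partition uses. A minimal sketch of calling the extended signature; the toy module and partition policy here are illustrative assumptions, not part of the commit:

    import torch
    from torch.fx import symbolic_trace
    from torch.fx.node import Node

    from colossalai.fx.passes.split_module import split_module

    class TwoStage(torch.nn.Module):    # illustrative toy model
        def __init__(self):
            super().__init__()
            self.lin1 = torch.nn.Linear(4, 4)
            self.lin2 = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.lin2(self.lin1(x))

    model = TwoStage()
    traced = symbolic_trace(model)

    def policy(node: Node) -> int:
        # everything up to and including lin1 -> partition 0, rest -> partition 1
        return 0 if node.name == 'lin1' else 1

    # merge_output=True exercises the new record_output path
    split = split_module(traced, model, policy, merge_output=True)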
@@ -52,52 +53,40 @@ def split_module(
         that maps a given Node instance to a numeric partition identifier.
         split_module will use this function as the policy for which operations
         appear in which partitions in the output Module.
 
     Returns:
         GraphModule: the module after split.
 
     Example:
 
         This is a sample setup:
 
             import torch
             from torch.fx.symbolic_trace import symbolic_trace
             from torch.fx.graph_module import GraphModule
             from torch.fx.node import Node
             from colossalai.fx.passes.split_module import split_module
 
             class MyModule(torch.nn.Module):
                 def __init__(self):
                     super().__init__()
                     self.param = torch.nn.Parameter(torch.rand(3, 4))
                     self.linear = torch.nn.Linear(4, 5)
 
                 def forward(self, x, y):
                     z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
                     w = self.linear(y).clamp(min=0.0, max=1.0)
                     return z + w
 
             # symbolically trace model
             my_module = MyModule()
             my_module_traced = symbolic_trace(my_module)
 
             # random mod partitioning
             partition_counter = 0
             NPARTITIONS = 3
 
             def mod_partition(node: Node):
                 global partition_counter
                 partition = partition_counter % NPARTITIONS
                 partition_counter = (partition_counter + 1) % NPARTITIONS
                 return partition
 
             # split module in module with submodules
             module_with_submodules = split_module(
                 my_module_traced, my_module, mod_partition
             )
 
         Output looks like this. Original graph is broken into partitions
 
             > print(module_with_submodules)
             GraphModule(
                 (submod_0): GraphModule(
@@ -108,7 +97,6 @@ def split_module(
                 )
                 (submod_2): GraphModule()
             )
 
             def forward(self, x, y):
                 param = self.param
                 submod_0 = self.submod_0(x, param, y); x = param = y = None
@@ -119,10 +107,8 @@ def split_module(
                 getitem_3 = submod_1[1]; submod_1 = None
                 submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None
                 return submod_2
 
         Output of split module is the same as output of input traced module.
         This is an example within a test setting:
 
             > orig_out = my_module_traced(x, y)
             > submodules_out = module_with_submodules(x, y)
             > self.assertEqual(orig_out, submodules_out)
@@ -147,6 +133,29 @@ def split_module(
                 use_partition.inputs.setdefault(def_node.name)
                 if def_partition_name is not None:
                     use_partition.partitions_dependent_on.setdefault(def_partition_name)
 
+    def record_output(
+            def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]
+    ):    # noqa: B950
+        def_partition_name = getattr(def_node, "_fx_partition", None)
+        use_partition_name = getattr(use_node, "_fx_partition", None)
+        if def_partition_name != use_partition_name:
+            if def_partition_name is not None:
+                def_partition = partitions[def_partition_name]
+                def_partition.outputs.setdefault(def_node.name)
+                if use_partition_name is not None:
+                    def_partition.partition_dependents.setdefault(use_partition_name)
+
+            if use_partition_name is not None:
+                use_partition = partitions[use_partition_name]
+                use_partition.inputs.setdefault(def_node.name)
+                if def_partition_name is not None:
+                    use_partition.partitions_dependent_on.setdefault(def_partition_name)
+                use_partition.outputs.setdefault(def_node.name)
+        else:
+            if use_partition_name is not None:
+                use_partition = partitions[use_partition_name]
+                use_partition.outputs.setdefault(def_node.name)
+
     # split nodes into partitions
     for node in m.graph.nodes:
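
Like record_cross_partition_use above it, the new record_output helper uses plain dicts as ordered sets: dict.setdefault(key) registers key -> None exactly once while preserving first-seen order, which keeps each partition's inputs and outputs deterministic. A standalone sketch of the idiom (names are illustrative):

    # dict-as-ordered-set: setdefault(k) inserts k -> None once, keeping order
    outputs = {}
    for name in ['z', 'w', 'z']:
        outputs.setdefault(name)

    print(list(outputs))    # ['z', 'w'] -- duplicate collapsed, order preserved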
@@ -155,7 +164,10 @@ def split_module(
         if node.op in ["placeholder"]:
             continue
         if node.op == 'output':
-            torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None))
+            if merge_output:
+                torch.fx.graph.map_arg(node.args[0], lambda n: record_output(n, node.prev))
+            else:
+                torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None))
             continue
         partition_name = str(split_callback(node))
 
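
With merge_output set, each producer n feeding the output node is recorded against node.prev, the node immediately preceding output, so the partition containing node.prev also records those values as its outputs. Both branches rely on torch.fx.graph.map_arg, which applies a function to every Node found in a possibly nested args structure; a small runnable sketch (the traced function is an illustrative assumption):

    import torch
    from torch.fx import symbolic_trace

    def f(x, y):
        return (x + y, [x, 1])    # nested output structure

    g = symbolic_trace(f)
    out_node = next(n for n in g.graph.nodes if n.op == 'output')

    # map_arg visits every Node inside out_node.args[0] and skips the literal 1
    torch.fx.graph.map_arg(out_node.args[0], lambda n: print(n.name))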
@@ -235,10 +247,10 @@ def split_module(
     for node in m.graph.nodes:
         if node.op == 'placeholder':
             if version.parse(torch.__version__) < version.parse('1.11.0'):
-                base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
+                base_mod_env[node.name] = base_mod_graph.placeholder(node.target, type_expr=node.type)
             else:
                 default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
-                base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
+                base_mod_env[node.name] = base_mod_graph.placeholder(node.target,
                                                                      type_expr=node.type,
                                                                      default_value=default_value)
             base_mod_env[node.name].meta = node.meta.copy()
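
The version gate exists because this pass only passes default_value to Graph.placeholder on torch >= 1.11.0, and inspect.Signature.empty is the sentinel meaning "no default was captured". A small sketch of the two pieces; pick_default is a hypothetical helper mirroring the expression in the hunk:

    import inspect

    import torch
    from packaging import version

    # same boundary the pass checks before using the default_value keyword
    NEW_PLACEHOLDER_API = version.parse(torch.__version__) >= version.parse('1.11.0')

    def pick_default(args: tuple):
        # first positional arg of a placeholder node is its default, if any
        return args[0] if len(args) > 0 else inspect.Signature.empty

    print(NEW_PLACEHOLDER_API)
    print(pick_default(()) is inspect.Signature.empty)    # True: no default
    print(pick_default((3,)))                             # 3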
@@ -278,4 +290,15 @@ def split_module(
         if node.op == 'output':
             base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name]))    # noqa: B950
 
-    return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
+    for partition_name in sorted_partitions:
+        partition = partitions[partition_name]
+
+    new_gm = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
+
+    DAG = get_DAG(new_gm)
+
+    for _, submodule in new_gm.named_modules():
+        if isinstance(submodule, torch.fx.GraphModule):
+            setattr(submodule, '_DAG', DAG)
+
+    return new_gm
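
The pass now materializes the rebuilt module first, derives the partition DAG from it via get_DAG, and attaches the same DAG object as _DAG to the root GraphModule and every GraphModule submodule. A hypothetical smoke test reading the attribute back, assuming the pass handles this toy split (the model and round-robin policy are illustrative, not from the commit):

    import torch
    from torch.fx import symbolic_trace
    from torch.fx.node import Node

    from colossalai.fx.passes.split_module import split_module

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    traced = symbolic_trace(model)

    counter = 0

    def policy(node: Node) -> int:
        # round-robin partitioning, as in the docstring example
        global counter
        part = counter % 2
        counter += 1
        return part

    split = split_module(traced, model, policy)

    # every GraphModule produced by the split should now carry the shared DAG
    for name, sub in split.named_modules():
        if isinstance(sub, torch.fx.GraphModule):
            print(name or '<root>', type(getattr(sub, '_DAG', None)).__name__)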