[fx]Split partition with DAG information (#2025)

* add DAG to split_module
* add comment
* add test case for DAG
* remove print

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>

parent ea0f6b8df9
commit 632753abbc
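
For orientation, here is a minimal sketch of how the pass changed by this commit is expected to be used, pieced together from the diff and tests below. The traced module `gm`, the original `model`, and the `my_policy` callback are placeholder names; only the `merge_output` flag and the attached `_DAG` attribute come from this commit:

    from colossalai.fx.passes.split_module import split_module

    # split the traced GraphModule into submod_* partitions;
    # merge_output=True records the final model outputs inside the
    # producing partition (see record_output below) instead of treating
    # them as plain cross-partition uses
    split_gm = split_module(gm, model, my_policy, merge_output=True)

    # every split GraphModule now carries the same DAG dictionary, keyed by
    # 'input_partition', 'submod_0', ..., 'output_partition' (assuming the
    # policy produced a partition named submod_0)
    dag = split_gm.submod_0._DAG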
colossalai/fx/passes/split_module.py

@@ -3,6 +3,7 @@ from torch.fx.graph_module import GraphModule
 from typing import Callable, List, Dict, Any, Optional
 from torch.fx._compatibility import compatibility
 from packaging import version
+from colossalai.fx.passes.utils import get_DAG
 import inspect


@@ -38,11 +39,11 @@ def split_module(
     m: GraphModule,
     root_m: torch.nn.Module,
     split_callback: Callable[[torch.fx.node.Node], int],
+    merge_output = False,
 ):
     """
     Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py
     Creates subgraphs out of main graph

     Args:
         m (GraphModule): Graph module to split
         root_m (torch.nn.Module): root nn module. Not currently used. Included
@@ -52,52 +53,40 @@ def split_module(
         that maps a given Node instance to a numeric partition identifier.
         split_module will use this function as the policy for which operations
         appear in which partitions in the output Module.

     Returns:
         GraphModule: the module after split.

     Example:

         This is a sample setup:

             import torch
             from torch.fx.symbolic_trace import symbolic_trace
             from torch.fx.graph_module import GraphModule
             from torch.fx.node import Node
             from colossalai.fx.passes.split_module import split_module

             class MyModule(torch.nn.Module):
                 def __init__(self):
                     super().__init__()
                     self.param = torch.nn.Parameter(torch.rand(3, 4))
                     self.linear = torch.nn.Linear(4, 5)

                 def forward(self, x, y):
                     z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
                     w = self.linear(y).clamp(min=0.0, max=1.0)
                     return z + w

             # symbolically trace model
             my_module = MyModule()
             my_module_traced = symbolic_trace(my_module)

             # random mod partitioning
             partition_counter = 0
             NPARTITIONS = 3

             def mod_partition(node: Node):
                 global partition_counter
                 partition = partition_counter % NPARTITIONS
                 partition_counter = (partition_counter + 1) % NPARTITIONS
                 return partition

             # split module in module with submodules
             module_with_submodules = split_module(
                 my_module_traced, my_module, mod_partition
             )

         Output looks like this. Original graph is broken into partitions

             > print(module_with_submodules)
             GraphModule(
                 (submod_0): GraphModule(
@@ -108,7 +97,6 @@ def split_module(
                 )
                 (submod_2): GraphModule()
             )

             def forward(self, x, y):
                 param = self.param
                 submod_0 = self.submod_0(x, param, y); x = param = y = None
@@ -119,10 +107,8 @@ def split_module(
                 getitem_3 = submod_1[1]; submod_1 = None
                 submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None
                 return submod_2

         Output of split module is the same as output of input traced module.
         This is an example within a test setting:

             > orig_out = my_module_traced(x, y)
             > submodules_out = module_with_submodules(x, y)
             > self.assertEqual(orig_out, submodules_out)
@@ -148,6 +134,29 @@ def split_module(
             if def_partition_name is not None:
                 use_partition.partitions_dependent_on.setdefault(def_partition_name)

+    def record_output(
+            def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]
+    ):    # noqa: B950
+        def_partition_name = getattr(def_node, "_fx_partition", None)
+        use_partition_name = getattr(use_node, "_fx_partition", None)
+        if def_partition_name != use_partition_name:
+            if def_partition_name is not None:
+                def_partition = partitions[def_partition_name]
+                def_partition.outputs.setdefault(def_node.name)
+                if use_partition_name is not None:
+                    def_partition.partition_dependents.setdefault(use_partition_name)
+
+            if use_partition_name is not None:
+                use_partition = partitions[use_partition_name]
+                use_partition.inputs.setdefault(def_node.name)
+                if def_partition_name is not None:
+                    use_partition.partitions_dependent_on.setdefault(def_partition_name)
+                use_partition.outputs.setdefault(def_node.name)
+        else:
+            if use_partition_name is not None:
+                use_partition = partitions[use_partition_name]
+                use_partition.outputs.setdefault(def_node.name)
+
     # split nodes into parititons
     for node in m.graph.nodes:
         orig_nodes[node.name] = node
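
A note on the bookkeeping style used by record_output above (and by the pre-existing record_cross_partition_use): in the upstream split_module this file is adapted from, a partition's `inputs`, `outputs`, `partition_dependents` and `partitions_dependent_on` are plain dicts used as insertion-ordered sets, so `setdefault(name)` records each name at most once. A standalone illustration, not part of the diff:

    # dicts preserve insertion order, so setdefault(key) acts like
    # "add to an ordered set": the key is stored once with value None
    outputs = {}
    outputs.setdefault('linear_1')
    outputs.setdefault('add_2')
    outputs.setdefault('linear_1')    # duplicate use, recorded only once
    print(list(outputs))              # ['linear_1', 'add_2']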
@@ -155,7 +164,10 @@ def split_module(
         if node.op in ["placeholder"]:
             continue
         if node.op == 'output':
-            torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None))
+            if merge_output:
+                torch.fx.graph.map_arg(node.args[0], lambda n: record_output(n, node.prev))
+            else:
+                torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None))
             continue
         partition_name = str(split_callback(node))

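For context on `node.prev` in the merge_output branch: an fx Graph keeps its nodes in a doubly linked list, so for the `output` node, `prev` is the last computation node in the graph, and record_output therefore attributes the final outputs to that node's partition. A small illustration of the traversal (standalone, not part of the diff):

    import torch
    from torch.fx import symbolic_trace

    def f(x):
        y = x + 1
        return y * 2

    gm = symbolic_trace(f)
    out = list(gm.graph.nodes)[-1]        # the 'output' node
    print(out.op, '<-', out.prev.name)    # output <- mul (the last compute node)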
@@ -235,10 +247,10 @@ def split_module(
     for node in m.graph.nodes:
         if node.op == 'placeholder':
             if version.parse(torch.__version__) < version.parse('1.11.0'):
-                base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
+                base_mod_env[node.name] = base_mod_graph.placeholder(node.target, type_expr=node.type)
             else:
                 default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
-                base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
+                base_mod_env[node.name] = base_mod_graph.placeholder(node.target,
                                                                      type_expr=node.type,
                                                                      default_value=default_value)
             base_mod_env[node.name].meta = node.meta.copy()
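
The `node.name` to `node.target` switch in this hunk matters because, for placeholder nodes, `target` holds the argument name as written in the traced forward() signature, while `name` is the graph-level node name that fx may rewrite to keep it unique and identifier-safe; building the base module's placeholders from `target` keeps the split module callable with the original (keyword) argument names. A quick way to inspect both fields, as an aside rather than part of the diff:

    import torch
    from torch.fx import symbolic_trace

    class M(torch.nn.Module):
        def forward(self, input_ids):
            return input_ids + 1

    gm = symbolic_trace(M())
    ph = next(n for n in gm.graph.nodes if n.op == 'placeholder')
    # for simple signatures target and name coincide; they diverge when fx
    # has to uniquify or sanitize the node name, which is why target is the
    # safer choice when re-creating placeholders
    print(ph.target, ph.name)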
@@ -278,4 +290,15 @@ def split_module(
         if node.op == 'output':
             base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name]))    # noqa: B950

-    return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
+    for partition_name in sorted_partitions:
+        partition = partitions[partition_name]
+
+    new_gm = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
+
+    DAG = get_DAG(new_gm)
+
+    for _, submodule in new_gm.named_modules():
+        if isinstance(submodule, torch.fx.GraphModule):
+            setattr(submodule, '_DAG', DAG)
+
+    return new_gm
colossalai/fx/passes/utils.py

@@ -2,7 +2,7 @@ import torch
 from typing import Dict, Set
 from torch.fx.node import Node, map_arg
 from torch.fx.graph import Graph
+from torch.fx.graph_module import GraphModule

 def get_comm_size(prev_partition, next_partition):
     """
@@ -32,7 +32,6 @@ def get_comm_size(prev_partition, next_partition):
 def get_leaf(graph: Graph):
     """
     Given a graph, return leaf nodes of this graph.

     Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output`` nodes from fx graph,
     we will get a normal DAG. Leaf nodes in this context means leaf nodes in that DAG.
     """
@@ -57,7 +56,6 @@ def is_leaf(graph: Graph, node: Node):
 def get_top(graph: Graph):
     """
     Given a graph, return top nodes of this graph.

     Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output`` nodes from fx graph,
     we will get a normal DAG. Top nodes in this context means nodes with BFS level 0 in that DAG.
     """
@@ -100,7 +98,6 @@ def get_all_consumers(graph: Graph, node: Node):
 def assign_bfs_level_to_nodes(graph: Graph):
     """
     Give a graph, assign bfs level to each node of this graph excluding ``placeholder`` and ``output`` nodes.

     Example:
         class MLP(torch.nn.Module):
             def __init__(self, dim: int):
@@ -110,8 +107,6 @@ def assign_bfs_level_to_nodes(graph: Graph):
                 self.linear3 = torch.nn.Linear(dim, dim)
                 self.linear4 = torch.nn.Linear(dim, dim)
                 self.linear5 = torch.nn.Linear(dim, dim)

             def forward(self, x):
                 l1 = self.linear1(x)
                 l2 = self.linear2(x)
@@ -165,10 +160,8 @@ def assign_bfs_level_to_nodes(graph: Graph):
 def get_node_module(node) -> torch.nn.Module:
     """
     Find the module associated with the given node.

     Args:
         node (torch.fx.Node): a torch.fx.Node object in the fx computation graph

     Returns:
         torch.nn.Module: the module associated with the given node
     """
@@ -177,3 +170,169 @@ def get_node_module(node) -> torch.nn.Module:
     assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}'
     module = node.graph.owning_module.get_submodule(node.target)
     return module
+
+def find_def_in_partition(node, partitions, input_partitions=None, direct=False):
+    # find def in input
+    if input_partitions is not None:
+        for placeholder in input_partitions:
+            if placeholder.name == node.name:
+                return 'MODEL_INPUT'
+
+    # find direct def
+    if direct:
+        for partition in partitions:
+            if node == partition:
+                return partition.name
+    # find def with getitem call
+    else:
+        for partition in partitions:
+            if node in partition.users.keys():
+                return partition.name
+
+    print(f'Not found def in partition {node.name}')
+    return None
+
+def find_user_in_partition(node, partitions, output_partitions=None, direct=False):
+    user_partition_names = []
+    # find direct user
+    if direct:
+        for partition in partitions:
+            if node == partition:
+                user_partition_names.append(partition.name)
+    # find user with getitem call
+    else:
+        for partition in partitions:
+            if node in partition.args:
+                user_partition_names.append(partition.name)
+
+    is_output = False
+    def find_output(def_node, output_node):
+        nonlocal is_output
+        if def_node == output_node:
+            is_output = True
+
+    if output_partitions is not None:
+        output_node = output_partitions[0]
+        torch.fx.graph.map_arg(output_node.args[0], lambda n: find_output(node, n))
+
+    if is_output:
+        user_partition_names.append('MODEL_OUTPUT')
+
+    if len(user_partition_names) > 0:
+        return user_partition_names
+
+    print(f'Not found user in partition {node.name}')
+    return None
+
+def get_partition_depends(partition, partitions, input_partitions=None, output_partitions=None):
+    # e.g. Partition2: {input: {Partition0: [sub1_1], Partition1: [sub2_0]}, output:{Output: [sub3_0]}},
+    input = {}
+    output = {}
+
+    for offset, arg in enumerate(partition.args):
+        def_partition_name = None
+        if not arg.name.startswith('getitem'):
+            def_partition_name = find_def_in_partition(arg, partitions, input_partitions, direct=True)
+        else:
+            def_partition_name = find_def_in_partition(arg, partitions, input_partitions, direct=False)
+        if def_partition_name is None:
+            continue
+        if def_partition_name not in input:
+            input[def_partition_name] = []
+        input[def_partition_name].append(offset)
+
+    offset = -1
+    for user in partition.users.keys():
+        user_partition_names = None
+        if input_partitions is None or not user.name.startswith('getitem'):
+            user_partition_names = find_user_in_partition(user, partitions, output_partitions, direct=True)
+            offset = 0
+        else:
+            user_partition_names = find_user_in_partition(user, partitions, output_partitions, direct=False)
+            offset += 1
+        if user_partition_names is None:
+            continue
+        for user_partition_name in user_partition_names:
+            if user_partition_name not in output:
+                output[user_partition_name] = []
+            output[user_partition_name].append(offset)
+
+    return input, output, offset+1
+
+# DAG just looks like following case.
+# the int in every list represents the offset of the partition's input arg or output arg.
+# {
+#     'input_partition': {
+#         'input_ids': {
+#             'input': {},
+#             'output': {'submod_0': [0], 'submod_1': [1]},
+#             'output_len': 0},
+#         'attention_mask': {
+#             'input': {},
+#             'output': {'submod_2': [0]},
+#             'output_len': 0}},
+#     'submod_0': {
+#         'input': {'MODEL_INPUT': [0]},
+#         'output': {'submod_1': [0], 'submod_2': [0, 1]},
+#         'output_len': 2},
+#     'submod_1': {
+#         'input': {'submod_0': [0], 'MODEL_INPUT': [1]},
+#         'output': {'submod_2': [0]},
+#         'output_len': 1},
+#     'submod_2': {
+#         'input': {'MODEL_INPUT': [0], 'submod_0': [1, 2]},
+#         'output': {'submod_3': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+#                                 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
+#                                 22, 23, 24]},
+#         'output_len': 25},
+#     'submod_3': {
+#         'input': {'submod_2': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+#                                12, 13, 14, 15, 16, 17, 18, 19, 20,
+#                                21, 22, 23, 24]},
+#         'output': {'MODEL_OUTPUT': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
+#                                     11, 12, 13, 14, 15, 16, 17, 18, 19,
+#                                     20, 21, 22, 23, 24]},
+#         'output_len': 25},
+#     'output_partition': {
+#         'input': {'logits': 'submod_3', 'past_key_values': (('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
+#                                                             ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
+#                                                             ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
+#                                                             ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
+#                                                             ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
+#                                                             ('submod_3', 'submod_3'), ('submod_3', 'submod_3'))},
+#         'output': {}, 'output_len': 0}
+# }
+
+# TODO(jiangziyue) Define a Class for DAG.
+def get_DAG(gm: GraphModule):
+    DAG = {}
+    input_partitions = []
+    partitions = []
+    output_partitions = []
+    for node in gm.graph.nodes:
+        if node.op == 'placeholder':
+            input_partitions.append(node)
+        elif node.name.startswith('submod_'):
+            partitions.append(node)
+        elif node.op == 'output':
+            output_partitions.append(node)
+
+    for partition in input_partitions:
+        DAG_node = {'input': {}, 'output': {}, 'output_len': 1}
+        _, output, _ = get_partition_depends(partition, partitions, None, output_partitions)
+        DAG_node['output'] = output
+        if 'input_partition' not in DAG:
+            DAG['input_partition'] = {}
+        DAG['input_partition'][partition.name] = DAG_node
+
+    for partition in partitions:
+        DAG_node = {'input': {}, 'output': {}}
+        DAG_node['input'], DAG_node['output'], DAG_node['output_len'] = get_partition_depends(partition, partitions, input_partitions, output_partitions)
+        DAG[partition.name] = DAG_node
+
+    for partition in output_partitions:
+        DAG_node = {'input': {}, 'output': {}, 'output_len': 0}
+        DAG_node['input'] = torch.fx.graph.map_arg(partition.args[0], lambda n: find_def_in_partition(n, partitions, input_partitions))
+        DAG['output_partition'] = DAG_node
+
+    return DAG
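
Given the DAG layout documented in the comment above, a small helper along these lines (a sketch, not part of the commit) can dump each partition's producers and consumers:

    def print_dag(DAG):
        """Walk the dict produced by get_DAG and print partition dependencies."""
        for name, entry in DAG.items():
            if name == 'input_partition':
                for ph_name, ph_entry in entry.items():
                    print(f"model input {ph_name} feeds: {list(ph_entry['output'])}")
            elif name == 'output_partition':
                print(f"model output gathers from: {entry['input']}")
            else:
                print(f"{name}: inputs from {list(entry['input'])}, "
                      f"outputs to {list(entry['output'])}, output_len={entry['output_len']}")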
tests/test_fx/test_pipeline/test_DAG/dag_utils.py (new file, 85 lines)

import torch
from torch.fx import GraphModule
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from colossalai.fx import ColoTracer
import random
import numpy as np

MANUAL_SEED = 0
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)

def split_model_and_get_DAG(model, data_gen):
    model.eval()

    # generate input sample
    kwargs = data_gen()

    # get origin output and rng state
    cpu_rng_state = torch.get_rng_state()
    output = model(**kwargs)

    # tracing model
    tracer = ColoTracer()
    try:
        meta_args = {k: v.to('meta') for k, v in kwargs.items()}
        graph = tracer.trace(root=model, meta_args=meta_args)
    except Exception as e:
        raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # apply transform passes
    annotated_model = balanced_split_pass(gm, 2)
    top_module, split_submodules = split_with_split_nodes_pass(annotated_model)

    return top_module, split_submodules[0]._DAG

def check_input(input, input_node, top_module):
    for user in input_node.users.keys():
        partition_name = user.name
        assert partition_name in input['output']

def check_submod(submod_partition, node, top_module):
    for arg in node.args:
        input_part_name = None
        if arg.op == 'placeholder':
            input_part_name = 'MODEL_INPUT'
        elif not arg.name.startswith('getitem'):
            input_part_name = arg.name
        else:
            input_part_name = arg.args[0].name
        assert input_part_name in submod_partition['input']

    for user in node.users:
        output_part_names = []
        if user.op == 'output':
            output_part_names.append('MODEL_OUTPUT')
        elif not user.name.startswith('getitem'):
            output_part_names.append(user.name)
        else:
            for n in user.users:
                if n.op == 'output':
                    output_part_names.append('MODEL_OUTPUT')
                else:
                    output_part_names.append(n.name)

        for output_part_name in output_part_names:
            assert output_part_name in submod_partition['output']

def check_DAG(top_module, DAG):
    assert 'input_partition' in DAG
    input_partition = DAG['input_partition']

    for node in top_module.graph.nodes:
        # check input
        if node.op == 'placeholder':
            assert node.name in input_partition
            input = input_partition[node.name]
            check_input(input, node, top_module)
        elif node.op == 'call_module':
            assert node.name in DAG
            submod_partition = DAG[node.name]
            check_submod(submod_partition, node, top_module)
tests/test_fx/test_pipeline/test_DAG/test_dag.py (new file, 31 lines)

import pytest
import torch
import transformers
from dag_utils import split_model_and_get_DAG, check_DAG

BATCH_SIZE = 1
SEQ_LENGHT = 16


@pytest.mark.skip('balance split v2 is not ready')
def test_opt():
    MODEL_LIST = [
        transformers.OPTModel,
        #transformers.OPTForCausalLM,
    ]

    config = transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4)

    def data_gen():
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        return kwargs

    for model_cls in MODEL_LIST:
        model = model_cls(config=config)
        top_mod, DAG = split_model_and_get_DAG(model, data_gen)
        check_DAG(top_mod, DAG)


if __name__ == '__main__':
    test_opt()
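
The test is skipped for now ('balance split v2 is not ready'); once enabled it would presumably be run like any other pytest module from the repository root, for example:

    pytest tests/test_fx/test_pipeline/test_DAG/test_dag.py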