mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 01:55:12 +00:00
[fx] methods to get fx graph property. (#1246)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4
.
* manipulation
* [fx]add graph manipulation methods.
* [fx]methods to get fx graph property.
* add unit test
* add docstring to explain top node and leaf node in this context
This commit is contained in:
50
tests/test_fx/test_graph_manipulation.py
Normal file
50
tests/test_fx/test_graph_manipulation.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import colossalai
|
||||
import torch
|
||||
from colossalai.fx.passes.utils import get_leaf, get_top, assign_bfs_level_to_nodes
|
||||
from colossalai.fx import ColoTracer
|
||||
from torch.fx import GraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
|
||||
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(dim, dim)
|
||||
self.linear2 = torch.nn.Linear(dim, dim)
|
||||
self.linear3 = torch.nn.Linear(dim, dim)
|
||||
self.linear4 = torch.nn.Linear(dim, dim)
|
||||
self.linear5 = torch.nn.Linear(dim, dim)
|
||||
|
||||
def forward(self, x):
|
||||
l1 = self.linear1(x)
|
||||
l2 = self.linear2(x)
|
||||
l3 = self.linear3(l1)
|
||||
l4 = self.linear4(l2)
|
||||
l5 = self.linear5(l3)
|
||||
return l4, l5
|
||||
|
||||
|
||||
def test_graph_manipulation():
|
||||
model = MLP(4)
|
||||
tracer = ColoTracer()
|
||||
graph = tracer.trace(model)
|
||||
nodes = list(graph.nodes)
|
||||
x, l1, l2, l3, l4, l5, output = nodes
|
||||
|
||||
leaf_nodes = set(get_leaf(graph))
|
||||
top_nodes = set(get_top(graph))
|
||||
compare_dict = {x: None, l1: 0, l2: 0, l3: 1, l4: 1, l5: 2, output: None}
|
||||
assign_bfs_level_to_nodes(graph)
|
||||
|
||||
assert leaf_nodes == set([l4, l5])
|
||||
assert top_nodes == set([l1, l2])
|
||||
for node in graph.nodes:
|
||||
if node.op in ('placeholder', 'output'):
|
||||
assert not hasattr(node, 'bfs_level')
|
||||
else:
|
||||
assert node.bfs_level == compare_dict[node]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_graph_manipulation()
|
Reference in New Issue
Block a user