mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 02:26:51 +00:00
[graph] improve the graph building. (#1157)
This commit is contained in:
@@ -1,84 +0,0 @@
|
||||
import pytest
|
||||
from torch import nn
|
||||
import torch
|
||||
from colossalai.tensor import ColoTensor
|
||||
from colossalai.tensor.graph import GraphContext
|
||||
import gc
|
||||
|
||||
|
||||
class SimpleNet(nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.proj1 = nn.Linear(4, 8)
|
||||
self.proj2 = nn.Linear(8, 4)
|
||||
self.proj3 = nn.Linear(4, 4)
|
||||
self.proj4 = nn.Linear(4, 4)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj1(x)
|
||||
x = self.proj2(x)
|
||||
x = self.proj3(x)
|
||||
x = self.proj4(x)
|
||||
return x
|
||||
|
||||
|
||||
def _visit_graph(start_node):
|
||||
if start_node is None:
|
||||
return
|
||||
|
||||
start_node.print()
|
||||
|
||||
post_node_list = start_node.post_nodes
|
||||
for node in post_node_list:
|
||||
_visit_graph(node)
|
||||
|
||||
|
||||
def _get_tensors():
|
||||
for obj in gc.get_objects():
|
||||
try:
|
||||
if torch.is_tensor(obj):
|
||||
yield obj
|
||||
except Exception as e:
|
||||
print('A trivial exception occured: {}'.format(e))
|
||||
|
||||
|
||||
def _count_tensors():
|
||||
cnt = 0
|
||||
for t in _get_tensors():
|
||||
cnt += 1
|
||||
return cnt
|
||||
|
||||
|
||||
def count_tensors(use_colossal):
|
||||
model = SimpleNet()
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
if use_colossal:
|
||||
colo_input = ColoTensor.from_torch_tensor(torch.randn(4))
|
||||
graph_ctx = GraphContext()
|
||||
with graph_ctx:
|
||||
output = model(colo_input)
|
||||
output = model(colo_input)
|
||||
ret = _count_tensors()
|
||||
|
||||
_visit_graph(graph_ctx.graph_nodes[0])
|
||||
|
||||
del graph_ctx
|
||||
return ret
|
||||
else:
|
||||
input_t = torch.randn(4)
|
||||
output = model(input_t)
|
||||
output = model(input_t)
|
||||
return _count_tensors()
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
# FIXME(ver217)
|
||||
def test_check_activation_tensors():
|
||||
assert count_tensors(False) == count_tensors(True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
count_tensors(True)
|
Reference in New Issue
Block a user