[fx] added unit test for coloproxy (#1119)

* [fx] added unit test for coloproxy

* polish code

* polish code
This commit is contained in:
Frank Lee
2022-06-15 15:27:51 +08:00
committed by GitHub
parent 7d14b473f0
commit 16302a5359
2 changed files with 40 additions and 4 deletions

View File

@@ -0,0 +1,23 @@
import torch
from colossalai.fx.proxy import ColoProxy
def test_coloproxy():
# create a dummy node only for testing purpose
model = torch.nn.Linear(10, 10)
gm = torch.fx.symbolic_trace(model)
node = list(gm.graph.nodes)[0]
# create proxy
proxy = ColoProxy(node=node)
proxy.meta_tensor = torch.empty(4, 2, device='meta')
assert len(proxy) == 4
assert proxy.shape[0] == 4 and proxy.shape[1] == 2
assert proxy.dim() == 2
assert proxy.dtype == torch.float32
assert proxy.size(0) == 4
if __name__ == '__main__':
test_coloproxy()