Files
ColossalAI/tests/test_fx/test_coloproxy.py
Frank Lee 16302a5359 [fx] added unit test for coloproxy (#1119)
* [fx] added unit test for coloproxy

* polish code

* polish code
2022-06-15 15:27:51 +08:00

23 lines
581 B
Python

import torch
from colossalai.fx.proxy import ColoProxy
def test_coloproxy():
    """Exercise the tensor-like accessors of ``ColoProxy``.

    A ``torch.nn.Linear`` module is symbolically traced purely to obtain a
    graph node to wrap. The proxy is then assigned a ``(4, 2)`` tensor on the
    ``meta`` device, and ``len``, ``shape``, ``dim``, ``dtype`` and ``size``
    are asserted to report values consistent with that tensor.
    """
    # Throwaway module: only the first node of its traced graph is needed.
    linear = torch.nn.Linear(10, 10)
    traced = torch.fx.symbolic_trace(linear)
    first_node = next(iter(traced.graph.nodes))

    # Wrap the node in a proxy and attach the meta tensor to it.
    proxy = ColoProxy(node=first_node)
    proxy.meta_tensor = torch.empty(4, 2, device='meta')

    # Length is expected to be the size of the leading dimension.
    assert len(proxy) == 4

    # Shape is checked one dimension at a time, in order.
    assert proxy.shape[0] == 4
    assert proxy.shape[1] == 2

    assert proxy.dim() == 2
    # torch.empty was called without a dtype, so float32 is expected here.
    assert proxy.dtype == torch.float32
    assert proxy.size(0) == 4
# Allow running this test file directly as a script, without a test runner.
if "__main__" == __name__:
    test_coloproxy()