mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 18:40:28 +00:00
[fx] added unit test for coloproxy (#1119)
* [fx] added unit test for coloproxy * polish code * polish code
This commit is contained in:
23
tests/test_fx/test_coloproxy.py
Normal file
23
tests/test_fx/test_coloproxy.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import torch
|
||||
from colossalai.fx.proxy import ColoProxy
|
||||
|
||||
|
||||
def test_coloproxy():
|
||||
# create a dummy node only for testing purpose
|
||||
model = torch.nn.Linear(10, 10)
|
||||
gm = torch.fx.symbolic_trace(model)
|
||||
node = list(gm.graph.nodes)[0]
|
||||
|
||||
# create proxy
|
||||
proxy = ColoProxy(node=node)
|
||||
proxy.meta_tensor = torch.empty(4, 2, device='meta')
|
||||
|
||||
assert len(proxy) == 4
|
||||
assert proxy.shape[0] == 4 and proxy.shape[1] == 2
|
||||
assert proxy.dim() == 2
|
||||
assert proxy.dtype == torch.float32
|
||||
assert proxy.size(0) == 4
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_coloproxy()
|
Reference in New Issue
Block a user