mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 01:55:12 +00:00
colo init context add device attr. (#866)
This commit is contained in:
@@ -5,17 +5,16 @@ import torch
|
||||
from colossalai.tensor import ColoTensor
|
||||
from copy import deepcopy
|
||||
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
|
||||
def test_linear():
|
||||
|
||||
def test_lazy_init():
|
||||
in_dim = 4
|
||||
out_dim = 5
|
||||
|
||||
with ColoInitContext(lazy_memory_allocate=True) as ctx:
|
||||
fc = torch.nn.Linear(in_dim, out_dim, bias=True)
|
||||
|
||||
print(fc.weight.numel())
|
||||
print(fc.bias.numel())
|
||||
|
||||
# lazy_memory_allocate=True, no payload is maintained
|
||||
assert fc.weight._torch_tensor.numel() == 0
|
||||
|
||||
@@ -23,5 +22,18 @@ def test_linear():
|
||||
assert fc.weight._torch_tensor.numel() == in_dim * out_dim
|
||||
|
||||
|
||||
def test_device():
|
||||
in_dim = 4
|
||||
out_dim = 5
|
||||
|
||||
with ColoInitContext(lazy_memory_allocate=True, device=get_current_device()) as ctx:
|
||||
fc = torch.nn.Linear(in_dim, out_dim, bias=True)
|
||||
|
||||
# eval an lazy parameter
|
||||
fc.weight.torch_tensor()
|
||||
assert fc.weight.device == get_current_device()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_linear()
|
||||
test_lazy_init()
|
||||
test_device()
|
||||
|
Reference in New Issue
Block a user