colo init context add device attr. (#866)

Author: Jiarui Fang
Date: 2022-04-25 14:24:26 +08:00
Committed by: GitHub
Parent: 2238758c2e
Commit: d01d3b8cb0
3 changed files with 36 additions and 12 deletions


@@ -5,17 +5,16 @@ import torch
 from colossalai.tensor import ColoTensor
 from copy import deepcopy
+from colossalai.utils.cuda import get_current_device


-def test_linear():
+def test_lazy_init():
     in_dim = 4
     out_dim = 5

     with ColoInitContext(lazy_memory_allocate=True) as ctx:
         fc = torch.nn.Linear(in_dim, out_dim, bias=True)

-    print(fc.weight.numel())
-    print(fc.bias.numel())
-
     # lazy_memory_allocate=True, no payload is maintained
     assert fc.weight._torch_tensor.numel() == 0
@@ -23,5 +22,18 @@ def test_linear():
     assert fc.weight._torch_tensor.numel() == in_dim * out_dim


+def test_device():
+    in_dim = 4
+    out_dim = 5
+
+    with ColoInitContext(lazy_memory_allocate=True, device=get_current_device()) as ctx:
+        fc = torch.nn.Linear(in_dim, out_dim, bias=True)
+
+    # evaluate a lazy parameter
+    fc.weight.torch_tensor()
+    assert fc.weight.device == get_current_device()


 if __name__ == '__main__':
-    test_linear()
+    test_lazy_init()
+    test_device()
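
For reference, a minimal sketch of how the new device argument is intended to be used, based only on the API surface visible in this diff (ColoInitContext, lazy_memory_allocate, get_current_device). The import path for ColoInitContext is not shown in the hunk and is an assumption here.

    # Sketch: build a module inside ColoInitContext so its parameters are
    # created lazily and placed on the device passed to the context.
    import torch

    from colossalai.utils.cuda import get_current_device
    # Assumed module path; ColoInitContext's location is not visible in this diff.
    from colossalai.utils.model.colo_init_context import ColoInitContext


    def build_model():
        # lazy_memory_allocate=True defers payload allocation; device tells the
        # context where to place parameters once they are materialized.
        with ColoInitContext(lazy_memory_allocate=True, device=get_current_device()):
            model = torch.nn.Linear(4, 5, bias=True)

        # Materialize the lazy parameter, then check its placement.
        model.weight.torch_tensor()
        assert model.weight.device == get_current_device()
        return model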