[context] use meta tensor to init model lazily. (#1187)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [context] use meta tensor to init model lazily.

* polish

* make modules with device kwargs bypass the normal init.

* change the unit test to adapt to the updated context.
Author: YuliangLiu0306
Date: 2022-06-29 21:02:30 +08:00 (committed by GitHub)
Parent: 2c8c05675d
Commit: 2053e138a2
3 changed files with 77 additions and 67 deletions
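
For context, the idea behind the change in plain PyTorch terms: a module constructed on the meta device only records parameter shapes and dtypes, so no real memory is allocated until the tensors are materialized later. The snippet below is a minimal sketch of that mechanism, not the ColossalAI code path.

# Sketch only: how meta tensors enable lazy initialization in plain PyTorch.
# This is not the ColossalAI implementation, just the underlying mechanism.
import torch.nn as nn

# Constructing on the meta device records shape/dtype but allocates no storage.
layer = nn.Linear(1024, 1024, device='meta')
assert layer.weight.is_meta                    # no real memory yet
print(layer.weight.shape, layer.weight.dtype)  # metadata is still available

# Materialize later, on the device we actually want.
layer = layer.to_empty(device='cpu')           # allocate uninitialized storage
layer.reset_parameters()                       # run the module's own init
assert not layer.weight.is_meta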


@@ -1,23 +1,22 @@
 import torch
 import torch.nn as nn
 from colossalai.utils.model.lazy_init_context import LazyInitContext
+from torchvision.models import resnet34
 
-def test_lazy_init_ctx():
-    with LazyInitContext() as ctx:
-        model = nn.Linear(10, 10)
-        model.weight.zero_()
-
-    # make sure the weight is a meta tensor
-    assert model.weight.is_meta
-
-    # initialize weights
+def test_lazy_init():
+    ctx = LazyInitContext()
+    with ctx:
+        model = resnet34(num_classes=10)
+    for param in model.parameters():
+        assert param.is_meta
+    for buffer in model.buffers():
+        assert buffer.is_meta
 
     ctx.lazy_init_parameters(model)
-
-    # make sure the weight is not a meta tensor
-    # and initialized correctly
-    assert not model.weight.is_meta and torch.all(model.weight == 0)
+    for param in model.parameters():
+        assert not param.is_meta
+    for buffer in model.buffers():
+        assert not buffer.is_meta
 
 if __name__ == '__main__':
-    test_lazy_init_ctx()
+    test_lazy_init()
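
As an aside, a rough sketch of what a materialization step such as ctx.lazy_init_parameters(model) has to accomplish: give every meta parameter and buffer real storage, then re-run each submodule's own initialization. The helper materialize_ below is hypothetical; the actual LazyInitContext may do this differently.

# Hypothetical helper, not the library's implementation: replace meta
# parameters/buffers with real storage and re-run each submodule's init.
import torch.nn as nn

def materialize_(module: nn.Module, device: str = 'cpu') -> nn.Module:
    module.to_empty(device=device)         # allocate real, uninitialized tensors
    for m in module.modules():
        if hasattr(m, 'reset_parameters'):
            m.reset_parameters()           # re-run the default initialization
    return module

# Usage with a model built on the meta device:
model = nn.Sequential(nn.Linear(8, 8, device='meta'),
                      nn.ReLU(),
                      nn.Linear(8, 2, device='meta'))
model = materialize_(model)
assert all(not p.is_meta for p in model.parameters())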