Mirror of https://github.com/hpcaitech/ColossalAI.git
[context]use meta tensor to init model lazily. (#1187)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4.
* [context]use meta tensor to init model lazily.
* polish
* make module with device kwargs bypass the normal init.
* change unit test to adapt to the updated context.
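
For background, PyTorch's meta device is what makes this lazy initialization possible: a meta tensor carries only shape and dtype, no storage, so a whole model can be constructed without allocating memory and its parameters materialized later. Below is a minimal sketch of that idea using only stock PyTorch; it is not ColossalAI's LazyInitContext. Using torch.device('meta') as a context manager assumes PyTorch >= 2.0, whereas LazyInitContext provides the same effect by hooking module construction.

    import torch
    import torch.nn as nn

    # Build under the meta device: no real storage is allocated.
    # (Using torch.device as a context manager needs PyTorch >= 2.0.)
    with torch.device('meta'):
        model = nn.Linear(10, 10)

    assert model.weight.is_meta  # shape/dtype only, no data

    # Materialize: allocate uninitialized storage on a concrete device,
    # then rerun the module's own reset_parameters() for real values.
    model = model.to_empty(device='cpu')
    model.reset_parameters()

    assert not model.weight.is_meta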
@@ -1,23 +1,22 @@
 import torch
-import torch.nn as nn
 from colossalai.utils.model.lazy_init_context import LazyInitContext
+from torchvision.models import resnet34
 
 
-def test_lazy_init_ctx():
-    with LazyInitContext() as ctx:
-        model = nn.Linear(10, 10)
-        model.weight.zero_()
-
-    # make sure the weight is a meta tensor
-    assert model.weight.is_meta
-
-    # initialize weights
+def test_lazy_init():
+    ctx = LazyInitContext()
+    with ctx:
+        model = resnet34(num_classes=10)
+    for param in model.parameters():
+        assert param.is_meta
+    for buffer in model.buffers():
+        assert buffer.is_meta
     ctx.lazy_init_parameters(model)
-
-    # make sure the weight is not a meta tensor
-    # and initialized correctly
-    assert not model.weight.is_meta and torch.all(model.weight == 0)
+    for param in model.parameters():
+        assert not param.is_meta
+    for buffer in model.buffers():
+        assert not buffer.is_meta
 
 
 if __name__ == '__main__':
-    test_lazy_init_ctx()
+    test_lazy_init()
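
The updated test drives an unmodified torchvision model through the context, so the context itself must hook module construction to land parameters on the meta device. The sketch below illustrates one way such interception could work; TinyLazyCtx and everything in it are hypothetical and simplified, not ColossalAI's actual implementation, which also records construction details so that lazy_init_parameters() can later rebuild and properly initialize the real tensors.

    import torch
    import torch.nn as nn

    # Hypothetical sketch (TinyLazyCtx is an invented name, not the
    # ColossalAI API). While active, every parameter a module registers
    # is replaced with a meta-device copy before it is stored, so the
    # finished model holds no real storage. The module's own
    # reset_parameters() still runs, but in-place init ops on meta
    # tensors are shape-only no-ops.
    class TinyLazyCtx:
        def __enter__(self):
            self._orig_register = nn.Module.register_parameter

            def register_on_meta(module, name, param):
                if param is not None:
                    param = nn.Parameter(
                        torch.empty_like(param, device='meta'),
                        requires_grad=param.requires_grad)
                self._orig_register(module, name, param)

            nn.Module.register_parameter = register_on_meta
            return self

        def __exit__(self, *exc):
            # always restore the original hook on exit
            nn.Module.register_parameter = self._orig_register

    with TinyLazyCtx():
        model = nn.Linear(10, 10)

    assert model.weight.is_meta

A real implementation would intercept the allocation itself as well (here each tensor is briefly created for real before being swapped), and materialization would then walk the model replacing every meta parameter with a freshly allocated, re-initialized one, which is roughly the role ctx.lazy_init_parameters(model) plays in the test.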