mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 18:19:58 +00:00
[context] support lazy init of module (#1088)
* [context] support lazy init of module * polish code
This commit is contained in:
23
tests/test_utils/test_lazy_init_ctx.py
Normal file
23
tests/test_utils/test_lazy_init_ctx.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.utils.model.lazy_init_context import LazyInitContext
|
||||
|
||||
def test_lazy_init_ctx():
|
||||
|
||||
with LazyInitContext() as ctx:
|
||||
model = nn.Linear(10, 10)
|
||||
model.weight.zero_()
|
||||
|
||||
# make sure the weight is a meta tensor
|
||||
assert model.weight.is_meta
|
||||
|
||||
# initialize weights
|
||||
ctx.lazy_init_parameters(model)
|
||||
|
||||
# make sure the weight is not a meta tensor
|
||||
# and initialized correctly
|
||||
assert not model.weight.is_meta and torch.all(model.weight == 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_lazy_init_ctx()
|
Reference in New Issue
Block a user