mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[lr-scheduler] fix load state dict and add test (#5369)
This commit is contained in:
20
tests/test_optimizer/test_lr_scheduler.py
Normal file
20
tests/test_optimizer/test_lr_scheduler.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import torch.nn as nn
|
||||
from torch.optim import Adam
|
||||
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
|
||||
|
||||
def test_lr_scheduler_save_load():
|
||||
model = nn.Linear(10, 10)
|
||||
optimizer = Adam(model.parameters(), lr=1e-3)
|
||||
scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=5, warmup_steps=2)
|
||||
new_scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=5, warmup_steps=2)
|
||||
for _ in range(5):
|
||||
scheduler.step()
|
||||
state_dict = scheduler.state_dict()
|
||||
new_scheduler.load_state_dict(state_dict)
|
||||
assert state_dict == new_scheduler.state_dict()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_lr_scheduler_save_load()
|
Reference in New Issue
Block a user