[booster] implemented the torch ddd + resnet example (#3232)

* [booster] implemented the torch ddd + resnet example

* polish code
This commit is contained in:
Frank Lee
2023-03-27 10:24:14 +08:00
committed by GitHub
parent 1a229045af
commit 73d3e4d309
22 changed files with 608 additions and 128 deletions

View File

@@ -42,8 +42,8 @@ def test_unsharded_checkpoint():
new_optimizer = Adam(new_model.parameters(), lr=0.001)
# load the model and optimizer
new_model = ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
new_optimizer = ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
# do recursive check for the optimizer state dict
# if the value is a dict, compare its values