[checkpoint] refactored the API and added safetensors support (#3427)

* [checkpoint] refactored the API and added safetensors support

* polish code
This commit is contained in:
Frank Lee
2023-04-04 15:23:01 +08:00
committed by GitHub
parent 26b7aac0be
commit 1beb85cc25
9 changed files with 579 additions and 280 deletions

View File

@@ -1,5 +1,6 @@
import tempfile
import pytest
import torch
from torch.optim import Adam
from torchvision.models import resnet18
@@ -14,7 +15,8 @@ from colossalai.checkpoint_io import GeneralCheckpointIO
# ========
def test_unsharded_checkpoint():
@pytest.mark.parametrize('use_safetensors', [True, False])
def test_unsharded_checkpoint(use_safetensors: bool):
# create a model and optimizer
model = resnet18()
optimizer = Adam(model.parameters(), lr=0.001)
@@ -29,12 +31,16 @@ def test_unsharded_checkpoint():
optimizer.step()
# create a temp file for checkpoint
model_ckpt_tempfile = tempfile.NamedTemporaryFile()
if use_safetensors:
suffix = ".safetensors"
else:
suffix = ".bin"
model_ckpt_tempfile = tempfile.NamedTemporaryFile(suffix=suffix)
optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
# save the model and optimizer
ckpt_io = GeneralCheckpointIO()
ckpt_io.save_model(model, model_ckpt_tempfile.name)
ckpt_io.save_model(model, model_ckpt_tempfile.name, use_safetensors=use_safetensors)
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name)
# create new model
@@ -68,3 +74,4 @@ def test_unsharded_checkpoint():
# check for model and optimizer state dict recursively
recursive_check(model.state_dict(), new_model.state_dict())
recursive_check(optimizer.state_dict(), new_optimizer.state_dict())
recursive_check(optimizer.state_dict(), new_optimizer.state_dict())