[checkpoint] refactored the API and added safetensors support (#3427)

* [checkpoint] refactored the API and added safetensors support

* polish code
This commit is contained in:
Frank Lee
2023-04-04 15:23:01 +08:00
committed by GitHub
parent 26b7aac0be
commit 1beb85cc25
9 changed files with 579 additions and 280 deletions

View File

@@ -71,6 +71,29 @@ def check_dataloader_sharding():
batch_to_compare), 'Same number was found across ranks but expected it to be different'
def check_checkpoint_save_and_load():
model_fn, data_gen_fn, output_transform_fn, _ = model_zoo['timm_resnet']
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model = model_fn()
optimizer = SGD(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()
data = data_gen_fn()
data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
output = model(**data)
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])
booster.backward(loss, optimizer)
def run_dist(rank, world_size, port):
# init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

View File

@@ -1,5 +1,6 @@
import tempfile
import pytest
import torch
from torch.optim import Adam
from torchvision.models import resnet18
@@ -14,7 +15,8 @@ from colossalai.checkpoint_io import GeneralCheckpointIO
# ========
def test_unsharded_checkpoint():
@pytest.mark.parametrize('use_safetensors', [True, False])
def test_unsharded_checkpoint(use_safetensors: bool):
# create a model and optimizer
model = resnet18()
optimizer = Adam(model.parameters(), lr=0.001)
@@ -29,12 +31,16 @@ def test_unsharded_checkpoint():
optimizer.step()
# create a temp file for checkpoint
model_ckpt_tempfile = tempfile.NamedTemporaryFile()
if use_safetensors:
suffix = ".safetensors"
else:
suffix = ".bin"
model_ckpt_tempfile = tempfile.NamedTemporaryFile(suffix=suffix)
optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
# save the model and optimizer
ckpt_io = GeneralCheckpointIO()
ckpt_io.save_model(model, model_ckpt_tempfile.name)
ckpt_io.save_model(model, model_ckpt_tempfile.name, use_safetensors=use_safetensors)
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name)
# create new model
@@ -68,3 +74,4 @@ def test_unsharded_checkpoint():
# check for model and optimizer state dict recursively
recursive_check(model.state_dict(), new_model.state_dict())
recursive_check(optimizer.state_dict(), new_optimizer.state_dict())
recursive_check(optimizer.state_dict(), new_optimizer.state_dict())