[checkpointio] sharded optimizer checkpoint for DDP plugin (#4002)

Baizhou Zhang committed on 2023-06-16 14:14:05 +08:00 (via GitHub)
parent 725af3eeeb · commit 822c3d4d66
6 changed files with 79 additions and 34 deletions
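In substance, the commit removes the `if not shard:` guard from the DDP checkpoint test: `booster.save_optimizer` and `booster.load_optimizer` now work for sharded checkpoints too. A minimal sketch of the usage this enables, assuming the test's unshown setup (the SGD optimizer, the `boost` call, and the MB reading of `size_per_shard` are assumptions, not shown in this diff):

```python
import torch
from torchvision.models import resnet18

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# Assumption: a distributed environment is already up, e.g. launched
# under torchrun; the config argument matches the API of this era.
colossalai.launch_from_torch(config={})

plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)

model = resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# New code path exercised by this commit: the optimizer checkpoint may be
# sharded; size_per_shard is assumed to cap each shard at roughly that many MB.
booster.save_optimizer(optimizer, 'ckpt/optimizer', shard=True, size_per_shard=16)
```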


```diff
@@ -13,7 +13,8 @@ from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_ad
 @parameterize('shard', [True, False])
-def check_torch_ddp_checkpointIO(shard: bool):
+@parameterize('size_per_shard', [16, 128])
+def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
     plugin = TorchDDPPlugin()
     booster = Booster(plugin=plugin)
     model = resnet18()
```
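The new decorator stacks onto the existing one, so the test now covers every (shard, size_per_shard) combination. A hedged illustration of the expansion, assuming `colossalai.testing.parameterize` calls the wrapped function once per listed value, so that stacked decorators yield the cross product:

```python
from colossalai.testing import parameterize

@parameterize('shard', [True, False])
@parameterize('size_per_shard', [16, 128])
def check(shard: bool, size_per_shard: int):
    print(shard, size_per_shard)

# Calling the wrapped function sweeps all four combinations:
# (True, 16), (True, 128), (False, 16), (False, 128)
check()
```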
```diff
@@ -38,11 +39,9 @@ def check_torch_ddp_checkpointIO(shard: bool):
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
         lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
-        booster.save_model(model, model_ckpt_path, shard=shard)
-        if not shard:
-            # TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
-            booster.save_optimizer(optimizer, optimizer_ckpt_path)
-            booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
+        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+        booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
         dist.barrier()

         new_model = resnet18()
```
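The `dist.barrier()` matters here: with `shard=True` the save produces a directory of shard files, and every rank must finish writing before any rank starts loading. Continuing the sketch above, one way to eyeball the result; the index-plus-shards layout is an assumption about ColossalAI's sharded format, not confirmed by this diff:

```python
import os

# Assumption: after booster.save_optimizer(optimizer, optimizer_ckpt_path,
# shard=True, size_per_shard=16), the checkpoint path is a directory
# rather than a single file.
for name in sorted(os.listdir(optimizer_ckpt_path)):
    size_mb = os.path.getsize(os.path.join(optimizer_ckpt_path, name)) / 2**20
    print(f"{name}: {size_mb:.1f} MB")  # shards expected near size_per_shard
```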
```diff
@@ -55,11 +54,10 @@ def check_torch_ddp_checkpointIO(shard: bool):
         booster.load_model(new_model, model_ckpt_path)
         check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)

-        if not shard:
-            booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-            check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
-            booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
-            check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)
+        booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+        check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
+        booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
+        check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)


 def run_dist(rank, world_size, port):
```
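Taken together, the test now exercises the same save/load round trip whether or not the checkpoint is sharded. A hedged end-to-end sketch of that round trip; the fresh-object creation and the SGD/boost details are assumptions carried over from the test's unshown setup:

```python
from colossalai.testing import check_state_dict_equal

# Rebuild fresh objects, load the checkpoints back, and compare states.
new_model = resnet18()
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=1e-3)
new_model, new_optimizer, *_ = booster.boost(new_model, new_optimizer)

booster.load_model(new_model, model_ckpt_path)
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)

# Sharded and unsharded checkpoints now pass the same equality checks.
check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
```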