This commit is contained in:
wangbluo 2024-11-19 17:32:20 +08:00
parent 65c487499e
commit 2aa6e44355

View File

@ -13,8 +13,9 @@ from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_ad
@parameterize("shard", [True, False])
@parameterize("use_async", [True, False])
@parameterize("size_per_shard", [16, 128])
def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
def check_torch_ddp_checkpointIO(shard: bool, use_async: bool, size_per_shard: int):
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model = resnet18()
@ -39,7 +40,11 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
if not use_async:
model_ckpt_path = f"{model_ckpt_path}.pt"
if use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
dist.barrier()