diff --git a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py
index 87d35f252..8a5310993 100644
--- a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py
@@ -13,8 +13,9 @@ from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_ad
 
 
 @parameterize("shard", [True, False])
+@parameterize("use_async", [True, False])
 @parameterize("size_per_shard", [16, 128])
-def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
+def check_torch_ddp_checkpointIO(shard: bool, use_async: bool, size_per_shard: int):
     plugin = TorchDDPPlugin()
     booster = Booster(plugin=plugin)
     model = resnet18()
@@ -39,7 +40,11 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
         lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
-        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+        if not use_async:
+            model_ckpt_path = f"{model_ckpt_path}.pt"
+        if use_async:
+            model_ckpt_path = f"{model_ckpt_path}.safetensors"
+        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
         booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
         booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
         dist.barrier()
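
For context, below is a minimal usage sketch of the use_async flag that this diff threads through the test. It mirrors the Booster/TorchDDPPlugin calls in the test above; the .safetensors-vs-.pt suffix choice follows the test's convention, and the save_resnet_checkpoint helper name is illustrative rather than part of the library. It assumes the distributed environment has already been initialized (e.g. via colossalai.launch), as the test harness does before running this check.

import tempfile

import torch
from torchvision.models import resnet18

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin


def save_resnet_checkpoint(use_async: bool, shard: bool = False, size_per_shard: int = 16) -> None:
    # Assumes the process group has already been set up (e.g. colossalai.launch).
    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)

    model = resnet18()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    model, optimizer, *_ = booster.boost(model, optimizer)

    with tempfile.TemporaryDirectory() as tempdir:
        # Async saves target a .safetensors file, synchronous saves a .pt file,
        # matching the suffix selection added in the diff above.
        suffix = ".safetensors" if use_async else ".pt"
        model_ckpt_path = f"{tempdir}/model{suffix}"
        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)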