diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 419df5110..7eb4c1684 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -135,9 +135,7 @@ def exam_state_dict( with shared_tempdir() as tempdir: model_ckpt_path = f"{tempdir}/model" optimizer_ckpt_path = f"{tempdir}/optimizer" - if not use_async: - model_ckpt_path = f"{model_ckpt_path}.pt" - if use_async: + if not shard and use_async: model_ckpt_path = f"{model_ckpt_path}.safetensors" booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async) diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index ea3e2aacd..7d6dc69a9 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -87,9 +87,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf with shared_tempdir() as tempdir: model_ckpt_path = f"{tempdir}/model" optimizer_ckpt_path = f"{tempdir}/optimizer" - if not use_async: - model_ckpt_path = f"{model_ckpt_path}.pt" - if use_async: + if not shard and use_async: model_ckpt_path = f"{model_ckpt_path}.safetensors" booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async) booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index 05dfcce4f..f736708db 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -47,8 +47,6 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us model_ckpt_path = f"{tempdir}/model" optimizer_ckpt_path = f"{tempdir}/optimizer" - if not shard and not use_async: - model_ckpt_path = f"{model_ckpt_path}.pt" if not shard and use_async: model_ckpt_path = f"{model_ckpt_path}.safetensors" if not shard and use_async: