From f4f3d529242babc7ecb9401efe6c8a7df4bf0c56 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Mon, 18 Nov 2024 17:51:15 +0800
Subject: [PATCH] fix: restore full test matrix in Gemini checkpoint IO test

Re-enable the sharded (shard=True) and synchronous (use_async=False)
save paths in exam_state_dict, drop a leftover commented-out debug
block that saved checkpoints to a hard-coded ./checkpoints directory,
and re-enable exam_lazy_from_pretrained in run_dist.
---
 .../test_gemini_checkpoint_io.py | 28 ++-----------------
 1 file changed, 3 insertions(+), 25 deletions(-)

diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
index 6462f65c2..419df5110 100644
--- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
@@ -86,17 +86,12 @@ def exam_state_dict_with_origin(
 
 @clear_cache_before_run()
 @parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS)
-@parameterize("shard", [False])
+@parameterize("shard", [False, True])
 @parameterize("model_name", ["transformers_llama_for_causal_lm"])
 @parameterize("size_per_shard", [32])
 @parameterize("tp_size", [1, 2])
 @parameterize("zero_size", [2])
-@parameterize(
-    "use_async",
-    [
-        True,
-    ],
-)
+@parameterize("use_async", [False, True])
 def exam_state_dict(
     placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int, use_async: bool
 ):
@@ -137,23 +132,6 @@ def exam_state_dict(
     for group in optimizer.param_groups:
         group["lr"] = 0.1
 
-    """output_dir = "./checkpoints"
-    import os
-    os.makedirs(output_dir, exist_ok=True)
-    model_ckpt_path = f"{output_dir}/model"
-    optimizer_ckpt_path = f"{output_dir}/optimizer"
-    if not shard:
-        model_ckpt_path = f"{model_ckpt_path}.safetensors"
-    print("model_ckpt_path", model_ckpt_path)
-    booster.save_model(
-        model,
-        model_ckpt_path,
-        shard=shard,
-        size_per_shard=size_per_shard,
-        use_async=use_async
-    )
-    booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)"""
-
     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
@@ -214,7 +192,7 @@ def run_dist(rank, world_size, port):
     colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     exam_state_dict()
     exam_state_dict_with_origin()
-    # exam_lazy_from_pretrained()
+    exam_lazy_from_pretrained()
 
 
 @pytest.mark.dist
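
Reviewer note: "parameterize" here is colossalai.testing.parameterize; when the
decorators are stacked, the wrapped test runs once per combination of the listed
values, so shard x use_async x tp_size now covers 2 x 2 x 2 cases for each
placement config. Below is a minimal sketch of that expansion, assuming simple
Cartesian-product semantics; parameterize_sketch and fake_exam_state_dict are
illustrative stand-ins, not the actual ColossalAI implementation:

    import functools

    def parameterize_sketch(name, values):
        # Hypothetical stand-in for colossalai.testing.parameterize: call the
        # wrapped function once per value, injecting it as a keyword argument.
        # Stacking the decorator therefore nests the loops, producing the
        # Cartesian product of all parameterized values.
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                for value in values:
                    fn(*args, **{**kwargs, name: value})
            return wrapper
        return decorator

    @parameterize_sketch("shard", [False, True])
    @parameterize_sketch("use_async", [False, True])
    def fake_exam_state_dict(shard: bool, use_async: bool):
        print(f"shard={shard} use_async={use_async}")

    fake_exam_state_dict()  # runs all 2 x 2 = 4 combinations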
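
Reviewer note: the deleted debug block wrote checkpoints to a hard-coded
./checkpoints directory, while the kept code path uses shared_tempdir so that
every rank works in the same throwaway directory. A hypothetical sketch of such
a helper, assuming rank 0 creates the directory and broadcasts its path to the
other ranks (an assumption about the behavior, not the actual code in
tests/test_checkpoint_io/utils.py):

    import tempfile
    from contextlib import contextmanager

    import torch.distributed as dist

    @contextmanager
    def shared_tempdir_sketch():
        # Hypothetical helper: rank 0 owns the TemporaryDirectory and shares
        # its path so all ranks save/load checkpoints in the same place.
        # Requires an initialized process group, as in run_dist above.
        tmpdir = None
        path = [None]
        if dist.get_rank() == 0:
            tmpdir = tempfile.TemporaryDirectory()
            path[0] = tmpdir.name
        dist.broadcast_object_list(path, src=0)
        try:
            yield path[0]
        finally:
            dist.barrier()  # wait until no rank still touches the directory
            if tmpdir is not None:
                tmpdir.cleanup()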