diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index b7900bc0f..5f832f13c 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -1,7 +1,7 @@
 import copy
-from functools import reduce
 import logging
 import os
+from functools import reduce
 from pathlib import Path
 from shutil import rmtree
 from typing import Dict, Iterator, Optional, OrderedDict, Tuple
@@ -445,7 +445,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 # Store param groups.
                 index_file.append_meta_data("param_groups", param_group_file)
                 group_file_path = os.path.join(checkpoint, param_group_file)
-                save_param_groups(optimizer.param_info, group_file_path)
+                param_groups = [
+                    {**group, "params": group_info["params"]}
+                    for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+                ]
+                save_param_groups({"param_groups": param_groups}, group_file_path)
                 # Store index file.
                 index_file.append_meta_data("total_size", total_size)
                 index_file.write_index_file(save_index_file)
@@ -504,7 +508,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 # Store param groups.
                 final_index_file.append_meta_data("param_groups", param_group_file)
                 group_file_path = os.path.join(checkpoint, param_group_file)
-                save_param_groups(optimizer.param_info, group_file_path)
+                param_groups = [
+                    {**group, "params": group_info["params"]}
+                    for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+                ]
+                save_param_groups({"param_groups": param_groups}, group_file_path)
                 final_index_file.write_index_file(final_index_file_path)
                 rmtree(tmp_index_file_folder)
 
@@ -718,7 +726,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
 
         if self.pp_size == 1:
             # When pipeline is not used, let master rank directly save the collected state_dict.
-            state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": local_states}
+            param_groups = [
+                {**group, "params": group_info["params"]}
+                for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+            ]
+            state_dict = {"param_groups": param_groups, "state": local_states}
             if self.coordinator.is_master():
                 save_state_dict(state_dict, checkpoint, use_safetensors=False)
         else:
@@ -729,7 +741,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
 
             # Only the master rank do the saving.
             if self.coordinator.is_master():
-                state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": dict()}
+                param_groups = [
+                    {**group, "params": group_info["params"]}
+                    for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+                ]
+                state_dict = {"param_groups": param_groups, "state": dict()}
                 for _states in states_list:
                     state_dict["state"].update(_states)
                 save_state_dict(state_dict, checkpoint, use_safetensors=False)
diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py
index 98fbb0c50..18367af59 100644
--- a/colossalai/zero/gemini/gemini_optimizer.py
+++ b/colossalai/zero/gemini/gemini_optimizer.py
@@ -621,7 +621,10 @@ class GeminiOptimizer(OptimizerWrapper):
         Return the param_groups in Pytorch format when saving to checkpoint.
         """
 
-        param_groups = copy.deepcopy(self.param_groups_backup)
+        param_groups = [
+            {**group, "params": group_info["params"]}
+            for group, group_info in zip(self.optim.param_groups, self.param_groups_backup)
+        ]
 
         # To be compatible with pytorch checkpointing,
         # store extra hyperparameters used by pytorch Adam optimizer.
diff --git a/extensions/cuda_extension.py b/extensions/cuda_extension.py
index b5e8a285b..842cd9713 100644
--- a/extensions/cuda_extension.py
+++ b/extensions/cuda_extension.py
@@ -1,7 +1,10 @@
 import os
+import time
 from abc import abstractmethod
+from pathlib import Path
 from typing import List
 
+from .base_extension import _Extension
 from .cpp_extension import _CppExtension
 from .utils import check_pytorch_version, check_system_pytorch_cuda_match, set_cuda_arch_list
 
diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
index 708a1906b..61cac1d83 100644
--- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
@@ -97,7 +97,7 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
     new_model = model_fn()
     optimizer = HybridAdam(model.parameters(), lr=0.001)
     model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
-    new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
+    new_optimizer = HybridAdam(new_model.parameters(), lr=0.01)
     new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
 
     data = data_gen_fn()
@@ -109,6 +109,8 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
 
     booster.backward(loss, optimizer)
     optimizer.step()
+    for group in optimizer.param_groups:
+        group["lr"] = 0.1
 
     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
@@ -127,6 +129,8 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
         check_state_dict_equal(
            optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False), False
         )
+        for group in new_optimizer.param_groups:
+            assert group["lr"] == 0.1
 
         # Check the new model/optimizer can successfully run.
         data = data_gen_fn()
diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
index a42b550cd..b5cb31715 100644
--- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
@@ -83,7 +83,8 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
 
     optimizer.backward(loss)
    optimizer.step()
-
+    for group in optimizer.param_groups:
+        group["lr"] = 0.1
     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"