From 5c09d726a6d72d73607d8a8ae35fbf3527714464 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 12 Feb 2025 11:54:55 +0800 Subject: [PATCH] [checkpointio] fix checkpoint for 3d (#6187) * fix checkpoint io for 3d * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update hybrid_parallel_checkpoint_io.py * fix --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../hybrid_parallel_checkpoint_io.py | 62 +++++++++---------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 1b7ae1888..bd814f426 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -1,6 +1,7 @@ import copy import logging import os +from collections import defaultdict from functools import reduce from pathlib import Path from shutil import rmtree @@ -10,6 +11,7 @@ import torch import torch.distributed as dist import torch.nn as nn from torch.distributed import ProcessGroup +from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils._pytree import tree_map @@ -37,7 +39,6 @@ from .utils import ( load_shard_state_dict, load_state_dict, load_state_dict_into_model, - load_states_into_optimizer, save_config_file, save_param_groups, save_state_dict, @@ -724,26 +725,37 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) if not low_cpu_mem_mode: state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads) - load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True) + self.load_states_into_optimizer(optimizer, state_dict, id_map) loaded_file.add(filename) - # Then shard the loaded optimizer states if using tp/zero. - for param, state in optimizer.optim.state.items(): - device = param.device - if master_to_working_map is not None: - working_param = master_to_working_map[id(param)] - else: - working_param = param - original_shape = optimizer.param_info["param2shape"][id(working_param)] - sharded_state = self.shard_from_complete_optimizer_state( - state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True - ) - optimizer.optim.state[param] = sharded_state - sharded_optimizer_loading_epilogue(optimizer.optim) if self.verbose and self.coordinator.is_master(): logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") + def load_states_into_optimizer(self, optimizer: Optimizer, state_dict: dict, id_map: dict): + state_dict = {int(k): v for k, v in state_dict.items()} + new_states = defaultdict(dict) + master_to_working_map = optimizer.get_master_to_working_map() + for k, state in state_dict.items(): + if k in id_map: + param = id_map[k] + device = param.device + dtype = param.dtype + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + original_shape = optimizer.param_info["param2shape"][id(working_param)] + new_states[param] = self.shard_from_complete_optimizer_state( + state, + current_shape=working_param.shape, + original_shape=original_shape, + device=device, + dtype=dtype, + inplace=True, + ) + optimizer.optim.state.update(new_states) + def save_unsharded_model( self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False ): @@ -988,22 +1000,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): for param in pg["params"]: param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) id_map[param_id] = param - load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True) - - # Then shard the loaded optimizer states if using tp/zero. - for param, state in optimizer.optim.state.items(): - if param is None: - continue - device = param.device - if master_to_working_map is not None: - working_param = master_to_working_map[id(param)] - else: - working_param = param - original_shape = optimizer.param_info["param2shape"][id(working_param)] - sharded_state = self.shard_from_complete_optimizer_state( - state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True - ) - optimizer.optim.state[param] = sharded_state + self.load_states_into_optimizer(optimizer, state_dict["state"], id_map) sharded_optimizer_loading_epilogue(optimizer.optim) @@ -1086,6 +1083,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): current_shape: torch.Size, original_shape: torch.Size, device: torch.device, + dtype: torch.dtype, inplace: bool, ) -> OrderedDict: """ @@ -1135,7 +1133,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): slice_size = v.numel() // self.global_dp_size v = v.split(slice_size, dim=0)[self.dp_rank] - state_[k] = v.detach().clone().to(device) + state_[k] = v.detach().clone().to(device=device, dtype=dtype) return state_