Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-04-27 11:31:58 +00:00)
[checkpointio] fix checkpoint for 3d (#6187)
* fix checkpoint io for 3d
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Update hybrid_parallel_checkpoint_io.py
* fix
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent 2b415e5999
commit 5c09d726a6
hybrid_parallel_checkpoint_io.py

@@ -1,6 +1,7 @@
 import copy
 import logging
 import os
+from collections import defaultdict
 from functools import reduce
 from pathlib import Path
 from shutil import rmtree
@@ -10,6 +11,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map
 
@@ -37,7 +39,6 @@ from .utils import (
     load_shard_state_dict,
     load_state_dict,
     load_state_dict_into_model,
-    load_states_into_optimizer,
     save_config_file,
     save_param_groups,
     save_state_dict,
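The import changes line up with the refactor below: defaultdict backs an accumulate-then-update pattern for optimizer states, and the utils-level load_states_into_optimizer is no longer imported because the class gains its own dtype-aware version. A rough illustration of that pattern follows; it is a toy SGD example, not ColossalAI code, and the names loaded and id_map are made up for the sketch.

from collections import defaultdict

import torch
from torch.optim import SGD

params = [torch.nn.Parameter(torch.randn(4)) for _ in range(2)]
optim = SGD(params, lr=0.1)

# states as a checkpoint shard might provide them, keyed by the saved param id (as a string)
loaded = {"0": {"momentum_buffer": torch.ones(4)}, "1": {"momentum_buffer": torch.zeros(4)}}
id_map = {i: p for i, p in enumerate(params)}  # saved id -> live parameter

new_states = defaultdict(dict)
for k, state in {int(k): v for k, v in loaded.items()}.items():
    if k in id_map:
        param = id_map[k]
        # the real method reshards/casts each tensor here; this sketch only moves and casts it
        new_states[param] = {n: t.to(device=param.device, dtype=param.dtype) for n, t in state.items()}
optim.state.update(new_states)  # one bulk update instead of strict per-key assignment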
@@ -724,26 +725,37 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
                 if not low_cpu_mem_mode:
                     state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
-                load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
+                self.load_states_into_optimizer(optimizer, state_dict, id_map)
                 loaded_file.add(filename)
 
-        # Then shard the loaded optimizer states if using tp/zero.
-        for param, state in optimizer.optim.state.items():
-            device = param.device
-            if master_to_working_map is not None:
-                working_param = master_to_working_map[id(param)]
-            else:
-                working_param = param
-            original_shape = optimizer.param_info["param2shape"][id(working_param)]
-            sharded_state = self.shard_from_complete_optimizer_state(
-                state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
-            )
-            optimizer.optim.state[param] = sharded_state
-
         sharded_optimizer_loading_epilogue(optimizer.optim)
         if self.verbose and self.coordinator.is_master():
             logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
 
+    def load_states_into_optimizer(self, optimizer: Optimizer, state_dict: dict, id_map: dict):
+        state_dict = {int(k): v for k, v in state_dict.items()}
+        new_states = defaultdict(dict)
+        master_to_working_map = optimizer.get_master_to_working_map()
+        for k, state in state_dict.items():
+            if k in id_map:
+                param = id_map[k]
+                device = param.device
+                dtype = param.dtype
+                if master_to_working_map is not None:
+                    working_param = master_to_working_map[id(param)]
+                else:
+                    working_param = param
+                original_shape = optimizer.param_info["param2shape"][id(working_param)]
+                new_states[param] = self.shard_from_complete_optimizer_state(
+                    state,
+                    current_shape=working_param.shape,
+                    original_shape=original_shape,
+                    device=device,
+                    dtype=dtype,
+                    inplace=True,
+                )
+        optimizer.optim.state.update(new_states)
+
     def save_unsharded_model(
         self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
     ):
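The new load_states_into_optimizer method maps each saved state key through id_map to a live master parameter, looks up the working parameter for the saved (full) shape, and reshards every state tensor to the working parameter's current shape while casting it to the master parameter's dtype via the new dtype argument. The standalone sketch below illustrates that per-tensor step only; the dim-0 tensor-parallel chunking, the scalar "step" handling, and the tp_rank/tp_size names are simplifying assumptions standing in for shard_from_complete_optimizer_state, which also deals with ZeRO flattening.

import torch


def shard_state_sketch(full_state, current_shape, device, dtype, tp_rank, tp_size):
    """Cut one parameter's full optimizer state down to this rank's shard and cast it."""
    sharded = {}
    for name, tensor in full_state.items():
        if name == "step" or tensor.dim() == 0:
            sharded[name] = tensor.clone()  # scalar entries (e.g. Adam's step) stay whole
            continue
        if tuple(tensor.shape) != tuple(current_shape):
            tensor = tensor.chunk(tp_size, dim=0)[tp_rank]  # keep only this rank's slice of the full state
        sharded[name] = tensor.detach().clone().to(device=device, dtype=dtype)
    return sharded


full = {"exp_avg": torch.randn(8, 4), "exp_avg_sq": torch.rand(8, 4), "step": torch.tensor(10.0)}
local = shard_state_sketch(full, (4, 4), torch.device("cpu"), torch.float32, tp_rank=0, tp_size=2)
assert local["exp_avg"].shape == (4, 4) and local["exp_avg"].dtype == torch.float32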
@@ -988,22 +1000,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             for param in pg["params"]:
                 param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
                 id_map[param_id] = param
-        load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
-
-        # Then shard the loaded optimizer states if using tp/zero.
-        for param, state in optimizer.optim.state.items():
-            if param is None:
-                continue
-            device = param.device
-            if master_to_working_map is not None:
-                working_param = master_to_working_map[id(param)]
-            else:
-                working_param = param
-            original_shape = optimizer.param_info["param2shape"][id(working_param)]
-            sharded_state = self.shard_from_complete_optimizer_state(
-                state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
-            )
-            optimizer.optim.state[param] = sharded_state
+        self.load_states_into_optimizer(optimizer, state_dict["state"], id_map)
 
         sharded_optimizer_loading_epilogue(optimizer.optim)
 
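With this hunk the unsharded load path goes through the same helper as the sharded one: build id_map from the optimizer's param groups, then hand the loaded "state" dict to load_states_into_optimizer, so the duplicated post-load sharding loop disappears. A minimal, hypothetical sketch of the id_map construction follows; the running integer index is an illustrative stand-in for _get_param_id_from_optimizer_param.

import torch
from torch.optim import AdamW

model = torch.nn.Linear(4, 4)
optim = AdamW(model.parameters(), lr=1e-3)

id_map = {}
next_id = 0
for pg in optim.param_groups:
    for param in pg["params"]:
        id_map[next_id] = param  # checkpoints key optimizer states by an integer id like this
        next_id += 1

# both load paths then make a single call of the form:
#   self.load_states_into_optimizer(optimizer, state_dict["state"], id_map)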
@@ -1086,6 +1083,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         current_shape: torch.Size,
         original_shape: torch.Size,
         device: torch.device,
+        dtype: torch.dtype,
         inplace: bool,
     ) -> OrderedDict:
         """
@@ -1135,7 +1133,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                         slice_size = v.numel() // self.global_dp_size
                         v = v.split(slice_size, dim=0)[self.dp_rank]
 
-            state_[k] = v.detach().clone().to(device)
+            state_[k] = v.detach().clone().to(device=device, dtype=dtype)
 
         return state_
 
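The one-line change above is the dtype half of the fix: the DP/ZeRO slice was previously only moved to device, so it kept the checkpoint's dtype. A minimal sketch with made-up sizes and ranks showing the difference:

import torch

global_dp_size, dp_rank = 2, 0
device, dtype = torch.device("cpu"), torch.float32

v = torch.randn(8, dtype=torch.bfloat16)  # flattened optimizer state tensor for one parameter
slice_size = v.numel() // global_dp_size
v = v.split(slice_size, dim=0)[dp_rank]   # keep only this DP rank's shard

before_fix = v.detach().clone().to(device)                     # dtype stays bfloat16
after_fix = v.detach().clone().to(device=device, dtype=dtype)  # also cast to the target dtype
assert before_fix.dtype == torch.bfloat16 and after_fix.dtype == torch.float32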