mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-24 14:33:20 +00:00
[checkpointio] fix checkpoint for 3d (#6187)
* fix checkpoint io for 3d
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci
* Update hybrid_parallel_checkpoint_io.py
* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
2b415e5999
commit
5c09d726a6
@ -1,6 +1,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import rmtree
|
from shutil import rmtree
|
||||||
@ -10,6 +11,7 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||||
from torch.utils._pytree import tree_map
|
from torch.utils._pytree import tree_map
|
||||||
|
|
||||||
@ -37,7 +39,6 @@ from .utils import (
|
|||||||
load_shard_state_dict,
|
load_shard_state_dict,
|
||||||
load_state_dict,
|
load_state_dict,
|
||||||
load_state_dict_into_model,
|
load_state_dict_into_model,
|
||||||
load_states_into_optimizer,
|
|
||||||
save_config_file,
|
save_config_file,
|
||||||
save_param_groups,
|
save_param_groups,
|
||||||
save_state_dict,
|
save_state_dict,
|
||||||
@ -724,26 +725,37 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
|
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
|
||||||
if not low_cpu_mem_mode:
|
if not low_cpu_mem_mode:
|
||||||
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
|
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
|
||||||
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
|
self.load_states_into_optimizer(optimizer, state_dict, id_map)
|
||||||
loaded_file.add(filename)
|
loaded_file.add(filename)
|
||||||
|
|
||||||
# Then shard the loaded optimizer states if using tp/zero.
|
|
||||||
for param, state in optimizer.optim.state.items():
|
|
||||||
device = param.device
|
|
||||||
if master_to_working_map is not None:
|
|
||||||
working_param = master_to_working_map[id(param)]
|
|
||||||
else:
|
|
||||||
working_param = param
|
|
||||||
original_shape = optimizer.param_info["param2shape"][id(working_param)]
|
|
||||||
sharded_state = self.shard_from_complete_optimizer_state(
|
|
||||||
state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
|
|
||||||
)
|
|
||||||
optimizer.optim.state[param] = sharded_state
|
|
||||||
|
|
||||||
sharded_optimizer_loading_epilogue(optimizer.optim)
|
sharded_optimizer_loading_epilogue(optimizer.optim)
|
||||||
if self.verbose and self.coordinator.is_master():
|
if self.verbose and self.coordinator.is_master():
|
||||||
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
|
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
|
||||||
|
|
||||||
|
def load_states_into_optimizer(self, optimizer: Optimizer, state_dict: dict, id_map: dict) -> None:
    """
    Shard loaded optimizer states and store them in the wrapped optimizer.

    Every entry of ``state_dict`` whose integer key appears in ``id_map`` is passed to
    ``self.shard_from_complete_optimizer_state`` together with the device and dtype of the
    matching parameter; the returned state is then written to ``optimizer.optim.state``.
    Entries whose key is not present in ``id_map`` are skipped.

    Args:
        optimizer (Optimizer): Object exposing ``optim``, ``param_info["param2shape"]`` and
            ``get_master_to_working_map()``.
        state_dict (dict): Maps a parameter id (an int, or anything ``int()`` accepts) to the
            state of that parameter. Each state is handed to the shard helper with ``inplace=True``.
        id_map (dict): Maps integer parameter ids to the parameters tracked by ``optimizer.optim``.
    """
    # Normalise every key up front so a malformed id fails before any state is touched.
    states_by_id = {int(param_id): state for param_id, state in state_dict.items()}
    master_to_working_map = optimizer.get_master_to_working_map()

    sharded_states = {}
    for param_id, state in states_by_id.items():
        if param_id not in id_map:
            # NOTE(review): presumably a parameter that this rank does not hold -- confirm.
            continue

        param = id_map[param_id]
        # Shapes are looked up through the working parameter; without a master->working
        # mapping the optimizer parameter is used as its own working parameter.
        if master_to_working_map is not None:
            working_param = master_to_working_map[id(param)]
        else:
            working_param = param

        sharded_states[param] = self.shard_from_complete_optimizer_state(
            state,
            current_shape=working_param.shape,
            original_shape=optimizer.param_info["param2shape"][id(working_param)],
            device=param.device,
            dtype=param.dtype,
            inplace=True,
        )

    # Install all states at once, after every shard has been computed.
    optimizer.optim.state.update(sharded_states)
|
||||||
|
|
||||||
def save_unsharded_model(
|
def save_unsharded_model(
|
||||||
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
|
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
|
||||||
):
|
):
|
||||||
@ -988,22 +1000,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
for param in pg["params"]:
|
for param in pg["params"]:
|
||||||
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
|
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
|
||||||
id_map[param_id] = param
|
id_map[param_id] = param
|
||||||
load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
|
self.load_states_into_optimizer(optimizer, state_dict["state"], id_map)
|
||||||
|
|
||||||
# Then shard the loaded optimizer states if using tp/zero.
|
|
||||||
for param, state in optimizer.optim.state.items():
|
|
||||||
if param is None:
|
|
||||||
continue
|
|
||||||
device = param.device
|
|
||||||
if master_to_working_map is not None:
|
|
||||||
working_param = master_to_working_map[id(param)]
|
|
||||||
else:
|
|
||||||
working_param = param
|
|
||||||
original_shape = optimizer.param_info["param2shape"][id(working_param)]
|
|
||||||
sharded_state = self.shard_from_complete_optimizer_state(
|
|
||||||
state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
|
|
||||||
)
|
|
||||||
optimizer.optim.state[param] = sharded_state
|
|
||||||
|
|
||||||
sharded_optimizer_loading_epilogue(optimizer.optim)
|
sharded_optimizer_loading_epilogue(optimizer.optim)
|
||||||
|
|
||||||
@ -1086,6 +1083,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
current_shape: torch.Size,
|
current_shape: torch.Size,
|
||||||
original_shape: torch.Size,
|
original_shape: torch.Size,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
inplace: bool,
|
inplace: bool,
|
||||||
) -> OrderedDict:
|
) -> OrderedDict:
|
||||||
"""
|
"""
|
||||||
@ -1135,7 +1133,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
slice_size = v.numel() // self.global_dp_size
|
slice_size = v.numel() // self.global_dp_size
|
||||||
v = v.split(slice_size, dim=0)[self.dp_rank]
|
v = v.split(slice_size, dim=0)[self.dp_rank]
|
||||||
|
|
||||||
state_[k] = v.detach().clone().to(device)
|
state_[k] = v.detach().clone().to(device=device, dtype=dtype)
|
||||||
|
|
||||||
return state_
|
return state_
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user