[checkpointio] fix checkpoint for 3d (#6187)

* fix checkpoint io for 3d

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update hybrid_parallel_checkpoint_io.py

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
flybird11111 authored on 2025-02-12 11:54:55 +08:00, committed by GitHub
parent 2b415e5999 · commit 5c09d726a6

hybrid_parallel_checkpoint_io.py

@@ -1,6 +1,7 @@
 import copy
 import logging
 import os
+from collections import defaultdict
 from functools import reduce
 from pathlib import Path
 from shutil import rmtree
@@ -10,6 +11,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
+from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map
@@ -37,7 +39,6 @@ from .utils import (
     load_shard_state_dict,
     load_state_dict,
     load_state_dict_into_model,
-    load_states_into_optimizer,
     save_config_file,
     save_param_groups,
     save_state_dict,
@@ -724,26 +725,37 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
                 if not low_cpu_mem_mode:
                     state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
-                load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
+                self.load_states_into_optimizer(optimizer, state_dict, id_map)
                 loaded_file.add(filename)
-        # Then shard the loaded optimizer states if using tp/zero.
-        for param, state in optimizer.optim.state.items():
-            device = param.device
-            if master_to_working_map is not None:
-                working_param = master_to_working_map[id(param)]
-            else:
-                working_param = param
-            original_shape = optimizer.param_info["param2shape"][id(working_param)]
-            sharded_state = self.shard_from_complete_optimizer_state(
-                state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
-            )
-            optimizer.optim.state[param] = sharded_state
         sharded_optimizer_loading_epilogue(optimizer.optim)
         if self.verbose and self.coordinator.is_master():
             logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
 
+    def load_states_into_optimizer(self, optimizer: Optimizer, state_dict: dict, id_map: dict):
+        state_dict = {int(k): v for k, v in state_dict.items()}
+        new_states = defaultdict(dict)
+        master_to_working_map = optimizer.get_master_to_working_map()
+        for k, state in state_dict.items():
+            if k in id_map:
+                param = id_map[k]
+                device = param.device
+                dtype = param.dtype
+                if master_to_working_map is not None:
+                    working_param = master_to_working_map[id(param)]
+                else:
+                    working_param = param
+                original_shape = optimizer.param_info["param2shape"][id(working_param)]
+                new_states[param] = self.shard_from_complete_optimizer_state(
+                    state,
+                    current_shape=working_param.shape,
+                    original_shape=original_shape,
+                    device=device,
+                    dtype=dtype,
+                    inplace=True,
+                )
+        optimizer.optim.state.update(new_states)
+
     def save_unsharded_model(
         self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
     ):
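The hunk above moves per-parameter sharding into a new instance method: each saved state entry is looked up by the id it was saved under, resolved to the live master parameter, sharded to the working parameter's shape, and cast to its dtype and device before optimizer.optim.state is updated. A minimal standalone sketch of that id-map pattern follows; toy_shard and load_states are illustrative stand-ins, not the PR's real helpers, and the real sharding also slices each tensor for TP/ZeRO.

# Standalone sketch of the id-map loading pattern; names are illustrative.
from collections import defaultdict

import torch


def toy_shard(state: dict, device: torch.device, dtype: torch.dtype) -> dict:
    # Stand-in for shard_from_complete_optimizer_state: the real method also
    # slices each tensor to this rank's TP/ZeRO shard; here we only cast it.
    return {k: v.to(device=device, dtype=dtype) if isinstance(v, torch.Tensor) else v
            for k, v in state.items()}


def load_states(optim: torch.optim.Optimizer, saved: dict, id_map: dict) -> None:
    saved = {int(k): v for k, v in saved.items()}  # saved keys may arrive as strings
    new_states = defaultdict(dict)
    for k, state in saved.items():
        if k in id_map:  # skip states whose parameters this rank does not own
            param = id_map[k]
            new_states[param] = toy_shard(state, param.device, param.dtype)
    optim.state.update(new_states)


# Toy usage: key "0" maps to the weight of a small Linear layer.
model = torch.nn.Linear(4, 4)
optim = torch.optim.SGD(model.parameters(), lr=0.1)
id_map = dict(enumerate(model.parameters()))
saved = {"0": {"momentum_buffer": torch.zeros(4, 4)}}
load_states(optim, saved, id_map)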
@@ -988,22 +1000,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             for param in pg["params"]:
                 param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
                 id_map[param_id] = param
-        load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
-        # Then shard the loaded optimizer states if using tp/zero.
-        for param, state in optimizer.optim.state.items():
-            if param is None:
-                continue
-            device = param.device
-            if master_to_working_map is not None:
-                working_param = master_to_working_map[id(param)]
-            else:
-                working_param = param
-            original_shape = optimizer.param_info["param2shape"][id(working_param)]
-            sharded_state = self.shard_from_complete_optimizer_state(
-                state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
-            )
-            optimizer.optim.state[param] = sharded_state
+        self.load_states_into_optimizer(optimizer, state_dict["state"], id_map)
         sharded_optimizer_loading_epilogue(optimizer.optim)
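In this hunk the unsharded-optimizer path is routed through the same new helper: the id_map built just above maps the id each state was saved under back to the live master parameter, so the helper can simply skip entries this rank does not own instead of the old post-hoc loop that had to guard against param being None. A rough sketch of how such an id is derived; it assumes param_info also carries a param2id mapping analogous to the param2shape lookup visible in this diff, which is not shown here.

# Hedged sketch of the id lookup behind _get_param_id_from_optimizer_param;
# the param2id mapping is an assumption, not shown in this diff.
def get_param_id(param, master_to_working_map, param_info) -> int:
    # Saved states are keyed by the *working* (possibly low-precision) parameter,
    # while the optimizer iterates over master copies, hence the translation.
    if master_to_working_map is not None:
        working_param = master_to_working_map[id(param)]
    else:
        working_param = param
    return param_info["param2id"][id(working_param)]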
@@ -1086,6 +1083,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         current_shape: torch.Size,
         original_shape: torch.Size,
         device: torch.device,
+        dtype: torch.dtype,
         inplace: bool,
     ) -> OrderedDict:
         """
@@ -1135,7 +1133,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                         slice_size = v.numel() // self.global_dp_size
                         v = v.split(slice_size, dim=0)[self.dp_rank]
-                state_[k] = v.detach().clone().to(device)
+                state_[k] = v.detach().clone().to(device=device, dtype=dtype)
         return state_
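The last two hunks thread a dtype argument into shard_from_complete_optimizer_state, so a fully gathered state tensor is cast to the working parameter's dtype in the same step that slices it to this rank's ZeRO shard. A toy illustration of the slice-and-cast follows; the function and variable names are illustrative, and only the split/cast logic mirrors the diff.

import torch


def slice_flat_state(v: torch.Tensor, dp_size: int, dp_rank: int,
                     device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    # Mirrors the context lines above: a ZeRO-flattened state tensor is split
    # evenly along dim 0, this rank keeps its slice, then it is moved and cast.
    slice_size = v.numel() // dp_size
    v = v.split(slice_size, dim=0)[dp_rank]
    return v.detach().clone().to(device=device, dtype=dtype)


# Toy usage: a 12-element exp_avg split across 4 ZeRO ranks; rank 1 keeps [3, 4, 5].
full = torch.arange(12, dtype=torch.float32)
print(slice_flat_state(full, dp_size=4, dp_rank=1, device=torch.device("cpu"), dtype=torch.bfloat16))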