[checkpointio] fix checkpoint for 3d (#6187)

* fix checkpoint io for 3d

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update hybrid_parallel_checkpoint_io.py

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
flybird11111 2025-02-12 11:54:55 +08:00 committed by GitHub
parent 2b415e5999
commit 5c09d726a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,7 @@
import copy
import logging
import os
from collections import defaultdict
from functools import reduce
from pathlib import Path
from shutil import rmtree
@ -10,6 +11,7 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
@ -37,7 +39,6 @@ from .utils import (
load_shard_state_dict,
load_state_dict,
load_state_dict_into_model,
load_states_into_optimizer,
save_config_file,
save_param_groups,
save_state_dict,
@ -724,26 +725,37 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
self.load_states_into_optimizer(optimizer, state_dict, id_map)
loaded_file.add(filename)
# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
device = param.device
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.shard_from_complete_optimizer_state(
state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
)
optimizer.optim.state[param] = sharded_state
sharded_optimizer_loading_epilogue(optimizer.optim)
if self.verbose and self.coordinator.is_master():
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
def load_states_into_optimizer(self, optimizer: Optimizer, state_dict: dict, id_map: dict):
state_dict = {int(k): v for k, v in state_dict.items()}
new_states = defaultdict(dict)
master_to_working_map = optimizer.get_master_to_working_map()
for k, state in state_dict.items():
if k in id_map:
param = id_map[k]
device = param.device
dtype = param.dtype
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
new_states[param] = self.shard_from_complete_optimizer_state(
state,
current_shape=working_param.shape,
original_shape=original_shape,
device=device,
dtype=dtype,
inplace=True,
)
optimizer.optim.state.update(new_states)
def save_unsharded_model(
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
@ -988,22 +1000,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
id_map[param_id] = param
load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
if param is None:
continue
device = param.device
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.shard_from_complete_optimizer_state(
state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
)
optimizer.optim.state[param] = sharded_state
self.load_states_into_optimizer(optimizer, state_dict["state"], id_map)
sharded_optimizer_loading_epilogue(optimizer.optim)
@ -1086,6 +1083,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
current_shape: torch.Size,
original_shape: torch.Size,
device: torch.device,
dtype: torch.dtype,
inplace: bool,
) -> OrderedDict:
"""
@ -1135,7 +1133,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
slice_size = v.numel() // self.global_dp_size
v = v.split(slice_size, dim=0)[self.dp_rank]
state_[k] = v.detach().clone().to(device)
state_[k] = v.detach().clone().to(device=device, dtype=dtype)
return state_