Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-23 18:39:56 +00:00
[checkpointio] support async model save (#6131)
* [checkpointio] support async model save

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
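In outline, the async save path introduced here works in two phases: each shard is first staged from device memory into page-locked (pinned) host buffers, and the file write is then handed off to a background writer so the training loop does not block on disk I/O. A minimal sketch of that pattern, using a plain thread in place of tensornvme's AsyncFileWriter (all names below are illustrative, not the ColossalAI API):

    import threading

    import torch


    def async_save_sketch(state_dict, pinned_buffers, path):
        # Phase 1: stage tensors into pre-allocated pinned host buffers.
        # non_blocking=True lets the device-to-host copies overlap with compute.
        for name, tensor in state_dict.items():
            pinned_buffers[name].copy_(tensor, non_blocking=True)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # staging buffers are now fully populated

        # Phase 2: write from the pinned buffers on a background thread; the
        # caller keeps training and joins the thread before the next save.
        writer = threading.Thread(target=torch.save, args=(pinned_buffers, path))
        writer.start()
        return writer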
@@ -5,7 +5,7 @@ from collections import abc as container_abcs
 from collections import defaultdict
 from itertools import chain
 from pathlib import Path
-from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
+from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple

 import torch
 import torch.nn as nn
@@ -19,6 +19,7 @@ from colossalai.tensor.d_tensor import (
     to_global,
     to_global_for_customized_distributed_tensor,
 )
+from colossalai.utils.safetensors import move_and_save

 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
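The new move_and_save import is the pivot of this change; below it is invoked as move_and_save(writer, shard, sub_pinned_state_dict). Its implementation is not part of this diff, but from the call site it plausibly stages each shard tensor into its matching pinned buffer and then serializes from host memory through the async writer. A hypothetical sketch under that assumption, with an ordinary file object standing in for the AsyncFileWriter:

    import torch


    def move_and_save_sketch(f, state_dict, pinned_state_dict):
        # Assumed behavior; not the real colossalai.utils.safetensors.move_and_save.
        for name, tensor in state_dict.items():
            # Async device-to-host copy into the matching pinned buffer.
            pinned_state_dict[name].copy_(tensor, non_blocking=True)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # ensure all staged copies have completed
        # Serialize from the pinned host buffers; with an AsyncFileWriter the
        # write would be queued and carried out by a background thread instead.
        torch.save(pinned_state_dict, f)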
@@ -263,6 +264,71 @@ def save_state_dict_shards(
     return total_size


+def async_save_state_dict_shards(
+    sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
+    checkpoint: str,
+    index_file: "CheckpointIndexFile",
+    base_filename: str,
+    is_master: bool,
+    pinned_state_dict: Optional[Dict[str, torch.Tensor]],
+    n_write_entries: int,
+    use_pp_format: bool = False,
+) -> Tuple[int, Dict[str, torch.Tensor], list]:
"""
|
||||
Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
|
||||
Args:
|
||||
sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
|
||||
checkpoint (str): The path of checkpoint directory as string.
|
||||
index_file (CheckpointIndexFile): The index file object to be updated.
|
||||
base_filename (str): Decides the prefix of filenames of shards.
|
||||
is_master (bool): Whether current rank is main process.
|
||||
use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False.
|
||||
use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.
|
||||
|
||||
Returns:
|
||||
int: the total size of shards
|
||||
"""
|
||||
+    from tensornvme.async_file_io import AsyncFileWriter
+
+    total_size = 0
+    shard_filenames = []
+    if pinned_state_dict is None:
+        returned_state_dict = {}
+    else:
+        returned_state_dict = pinned_state_dict
+    writers = []
+    for idx, shard_pair in enumerate(sharded_state_dict):
+        shard, current_size = shard_pair
+        # Non-master ranks only drain the shard generator; they do not write.
+        if not is_master:
+            del shard
+            continue
+        shard_file = get_shard_filename(base_filename, idx)
+        total_size = total_size + current_size
+        for key in shard.keys():
+            index_file.append_weight_map(key, shard_file)
+        checkpoint_file_path = os.path.join(checkpoint, shard_file)
+
+        writer = AsyncFileWriter(open(checkpoint_file_path, "wb"), n_write_entries, backend="pthread")
+        writers.append(writer)
+
+        if pinned_state_dict is not None:
+            sub_pinned_state_dict = {k: pinned_state_dict[k] for k in shard.keys()}
+        else:
+            sub_pinned_state_dict = create_pinned_state_dict(shard)
+        returned_state_dict.update(sub_pinned_state_dict)
+
+        # Only save on the master rank.
+        move_and_save(writer, shard, sub_pinned_state_dict)
+        shard_filenames.append(shard_file)
+        del shard
+
+    # Clean the checkpoint folder, deleting files that are no longer needed.
+    clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)
+
+    return total_size, returned_state_dict, writers
+
+
 def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
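For context, here is a sketch of how a caller might drive the new function. The CheckpointIndexFile construction, the rank check, and the flush contract are assumptions based on the signature above, not code from this commit:

    import torch.distributed as dist


    def save_model_async(model, checkpoint_dir: str, pinned_state_dict=None):
        # Hypothetical driver for async_save_state_dict_shards.
        index_file = CheckpointIndexFile(checkpoint_dir)
        total_size, pinned_state_dict, writers = async_save_state_dict_shards(
            sharded_state_dict=shard_model_checkpoint(model.state_dict(), max_shard_size=1024),
            checkpoint=checkpoint_dir,
            index_file=index_file,
            base_filename=SAFE_WEIGHTS_NAME,
            is_master=dist.get_rank() == 0,
            pinned_state_dict=pinned_state_dict,  # pass buffers from a previous save to reuse them
            n_write_entries=128,  # writer queue depth; the value here is illustrative
        )
        # The shard files are complete only once the background writers finish;
        # the caller must keep `writers` alive and drain them before relying on
        # the checkpoint or reusing the pinned buffers.
        return total_size, pinned_state_dict, writers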
@@ -799,3 +865,10 @@ def get_shard_filename(weights_name: str, idx: int):
     shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
     shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors")
     return shard_file
+
+
+def create_pinned_state_dict(state_dict: Dict[str, torch.Tensor]):
+    pin_mem = dict()
+    for name, tensor in state_dict.items():
+        pin_mem[name] = torch.empty_like(tensor, pin_memory=True, device="cpu")
+    return pin_mem
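The new create_pinned_state_dict helper is what makes the non-blocking staging copies effective: only page-locked (pinned) host memory can be the target of a truly asynchronous device-to-host transfer, while a copy into pageable memory silently degrades to a synchronous one. A standalone illustration:

    import torch

    # Pinned host memory lets copy_(..., non_blocking=True) run as a real async
    # DMA transfer that overlaps with GPU compute.
    if torch.cuda.is_available():
        src = torch.randn(1024, 1024, device="cuda")
        dst = torch.empty_like(src, device="cpu", pin_memory=True)  # same call as create_pinned_state_dict
        dst.copy_(src, non_blocking=True)  # returns immediately; copy runs in the background
        torch.cuda.synchronize()           # wait before reading dst on the host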