[checkpointio] support async model save (#6131)

* [checkpointio] support async model save

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Hongxin Liu
2024-11-14 11:38:10 +08:00
parent 5a03d2696d
commit d4a436051d
7 changed files with 209 additions and 28 deletions

View File

@@ -8,9 +8,13 @@ from typing import Optional
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.utils.safetensors import move_and_save
from .checkpoint_io_base import CheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
async_save_state_dict_shards,
create_pinned_state_dict,
get_model_base_filenames,
get_optimizer_base_filenames,
is_safetensors_available,
@@ -40,15 +44,27 @@ class GeneralCheckpointIO(CheckpointIO):
checkpoint = load_state_dict(checkpoint)
model.load_state_dict(checkpoint, strict=strict)
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(
self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
state_dict = model.state_dict()
# TODO(FrankLeeeee): add support for gather_dtensor
if gather_dtensor:
pass
# save the checkpoint
save_state_dict(state_dict, checkpoint, use_safetensors)
if use_async:
from tensornvme.async_file_io import AsyncFileWriter
writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
self.async_writers.append(writer)
move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
else:
# save the checkpoint
save_state_dict(state_dict, checkpoint, use_safetensors)
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
"""
@@ -151,6 +167,7 @@ class GeneralCheckpointIO(CheckpointIO):
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
):
"""
implement this method as it can be supported by Huggingface model,
@@ -168,16 +185,30 @@ class GeneralCheckpointIO(CheckpointIO):
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint_path)
# Save shards of optimizer states.
# In general cases, is_master is set to True to get the right behavior.
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=True,
use_safetensors=use_safetensors,
)
if use_async:
pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
total_size, new_pinned_state_dict, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=True,
pinned_state_dict=pinned_state_dict,
n_write_entries=self.N_WRITE_ENTRIES,
)
self.pinned_state_dicts[id(model)] = new_pinned_state_dict
self.async_writers.extend(writers)
else:
# Save shards of optimizer states.
# In general cases, is_master is set to True to get the right behavior.
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=True,
use_safetensors=use_safetensors,
)
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)