[checkpointio]support asyncio for 3d (#6152)

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
flybird11111
2024-12-23 10:24:22 +08:00
committed by GitHub
parent aaafb38851
commit 130229fdcb
17 changed files with 776 additions and 188 deletions

View File

@@ -8,10 +8,12 @@ from typing import Optional
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.utils.safetensors import load_flat
from .checkpoint_io_base import CheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
async_save_state_dict_shards,
async_move_save_state_dict_shards,
create_pinned_state_dict,
get_model_base_filenames,
get_optimizer_base_filenames,
@@ -47,10 +49,6 @@ class GeneralCheckpointIO(CheckpointIO):
):
state_dict = model.state_dict()
# TODO(FrankLeeeee): add support for gather_dtensor
if gather_dtensor:
pass
if use_async:
from colossalai.utils.safetensors import move_and_save
@@ -58,7 +56,6 @@ class GeneralCheckpointIO(CheckpointIO):
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
self.async_writers.append(writer)
else:
# save the checkpoint
save_state_dict(state_dict, checkpoint, use_safetensors)
@@ -83,7 +80,10 @@ class GeneralCheckpointIO(CheckpointIO):
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
if shard_file.endswith(".safetensors"):
state_dict = load_flat(shard_file)
else:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
load_states_into_optimizer(optimizer, state_dict, id_map)
sharded_optimizer_loading_epilogue(optimizer)
@@ -116,7 +116,7 @@ class GeneralCheckpointIO(CheckpointIO):
sharded_state = shard_optimizer_checkpoint(state_dict, max_shard_size=size_per_shard)
# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
index_file = CheckpointIndexFile(checkpoint)
# Store the information of param groups to param_group_file.
@@ -126,14 +126,28 @@ class GeneralCheckpointIO(CheckpointIO):
# Save shards of optimizer states.
# In general cases, is_master is set to True to get the right behavior.
total_size = save_state_dict_shards(
sharded_state_dict=sharded_state,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=True,
use_safetensors=False,
)
if use_async:
pinned_state_dict = self.pinned_state_dicts.get(id(optimizer), None)
total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards(
sharded_state_dict=sharded_state,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=True,
pinned_state_dict=pinned_state_dict,
state_preprocess=True,
)
self.pinned_state_dicts[id(optimizer)] = new_pinned_state_dict
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=sharded_state,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=True,
use_safetensors=False,
)
# Wrap up index file.
index_file.append_meta_data("total_size", total_size)
@@ -145,7 +159,10 @@ class GeneralCheckpointIO(CheckpointIO):
)
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
checkpoint = load_state_dict(checkpoint)
if checkpoint.endswith(".safetensors"):
checkpoint = load_flat(checkpoint)
else:
checkpoint = load_state_dict(checkpoint)
optimizer.load_state_dict(checkpoint)
def save_unsharded_optimizer(
@@ -156,7 +173,22 @@ class GeneralCheckpointIO(CheckpointIO):
use_async: bool = False,
):
# TODO(FrankLeeeee): handle distributed tensors
save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
state_dict = optimizer.state_dict()
if use_async:
from colossalai.utils.safetensors import _flatten_optim_state_dict, move_and_save
flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
writer = move_and_save(
path=checkpoint,
state_dict=flatten_state_dict,
state_dict_pinned=self.pinned_state_dicts[id(optimizer)],
metadata=metadata,
)
self.async_writers.append(writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)
def save_sharded_model(
self,
@@ -186,7 +218,7 @@ class GeneralCheckpointIO(CheckpointIO):
if use_async:
pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
total_size, new_pinned_state_dict, writers = async_save_state_dict_shards(
total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,