[checkpointio]support asyncio for 3d (#6152)

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
flybird11111
2024-12-23 10:24:22 +08:00
committed by GitHub
parent aaafb38851
commit 130229fdcb
17 changed files with 776 additions and 188 deletions

View File

@@ -19,6 +19,7 @@ from colossalai.tensor.d_tensor import (
to_global,
to_global_for_customized_distributed_tensor,
)
from colossalai.utils.safetensors import _flatten_optim_state_dict
SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
@@ -266,6 +267,63 @@ def save_state_dict_shards(
def async_save_state_dict_shards(
sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
checkpoint: str,
index_file: "CheckpointIndexFile",
base_filename: str,
is_master: bool,
use_pp_format: bool = False,
state_preprocess: bool = False,
) -> Tuple[int, list]:
"""
Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
Args:
sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
checkpoint (str): The path of checkpoint directory as string.
index_file (CheckpointIndexFile): The index file object to be updated.
base_filename (str): Decides the prefix of filenames of shards.
is_master (bool): Whether current rank is main process.
use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False.
use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.
Returns:
int: the total size of shards
"""
from colossalai.utils.safetensors import save
total_size = 0
shard_filenames = []
writers = []
for idx, shard_pair in enumerate(sharded_state_dict):
shard, current_size = shard_pair
# Just loop over the sharder and gather to other ranks if not master
if not is_master:
del shard
continue
shard_file = get_shard_filename(base_filename, idx)
total_size = total_size + current_size
for key in shard.keys():
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)
if state_preprocess:
state_dict, _ = _flatten_optim_state_dict(state_dict=shard, seperator=".")
else:
state_dict = shard
# Only save on master rank.
writer = save(checkpoint_file_path, state_dict=state_dict)
writers.append(writer)
shard_filenames.append(shard_file)
del shard
# Clean folder, deleted unneeded files.
clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)
return total_size, writers
def async_move_save_state_dict_shards(
sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
checkpoint: str,
index_file: "CheckpointIndexFile",
@@ -273,6 +331,7 @@ def async_save_state_dict_shards(
is_master: bool,
pinned_state_dict: Optional[Dict[str, torch.Tensor]],
use_pp_format: bool = False,
state_preprocess: bool = False,
) -> Tuple[int, Dict[str, torch.Tensor], list]:
"""
Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
@@ -309,14 +368,19 @@ def async_save_state_dict_shards(
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)
if pinned_state_dict is not None:
sub_pinned_state_dict = {k: pinned_state_dict[k] for k in shard.keys()}
if state_preprocess:
state_dict, _ = _flatten_optim_state_dict(state_dict=shard)
else:
sub_pinned_state_dict = create_pinned_state_dict(shard)
state_dict = shard
if pinned_state_dict is not None:
sub_pinned_state_dict = {k: pinned_state_dict[k] for k in state_dict.keys()}
else:
sub_pinned_state_dict = create_pinned_state_dict(state_dict)
returned_state_dict.update(sub_pinned_state_dict)
# Only save on master rank.
writer = move_and_save(checkpoint_file_path, shard, sub_pinned_state_dict)
writer = move_and_save(checkpoint_file_path, state_dict, sub_pinned_state_dict)
writers.append(writer)
shard_filenames.append(shard_file)
del shard
@@ -327,7 +391,11 @@ def async_save_state_dict_shards(
return total_size, returned_state_dict, writers
def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
def shard_model_checkpoint(
state_dict: torch.Tensor,
max_shard_size: int = 1024,
pinned_state_dicts: Optional[Dict[int, Dict[str, torch.Tensor]]] = None,
) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
@@ -336,6 +404,11 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
for key, weight in state_dict.items():
if not is_distributed_tensor(weight):
if pinned_state_dicts is not None:
if key not in pinned_state_dicts:
pinned_state_dicts[key] = torch.empty_like(weight, pin_memory=True, device="cpu")
pinned_state_dicts[key].copy_(weight)
weight = pinned_state_dicts[key]
block, block_size = state_dict_sharder.append_param(key, weight)
if block != None:
@@ -345,7 +418,9 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
def shard_optimizer_checkpoint(
state_dict: dict, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None
) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits an optimizer state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
@@ -356,6 +431,15 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
state_dict_sharder = StateDictSharder(max_shard_size)
for param_id, state in states.items():
if pinned_state_dicts is not None:
if param_id not in pinned_state_dicts:
pinned_state_dicts[param_id] = {}
for k, v in state.items():
if k not in pinned_state_dicts[param_id]:
pinned_state_dicts[param_id][k] = torch.empty_like(v, pin_memory=True, device="cpu")
pinned_state_dicts[param_id][k].copy_(v)
state[k] = pinned_state_dicts[param_id][k]
block, block_size = state_dict_sharder.append_optim_state(param_id, state)
if block != None:
yield block, block_size