Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-22 09:59:38 +00:00)
[checkpointio]support asyncio for 3d (#6152)
* fix
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Update utils.py
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -19,6 +19,7 @@ from colossalai.tensor.d_tensor import (
     to_global,
     to_global_for_customized_distributed_tensor,
 )
+from colossalai.utils.safetensors import _flatten_optim_state_dict
 
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
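For background: safetensors can only serialize a flat mapping from string names to tensors, which is why the optimizer saving path below flattens nested state before writing. The helper here is a hypothetical, simplified illustration of that flattening idea; the real _flatten_optim_state_dict in colossalai.utils.safetensors may differ in detail.

    # Hypothetical sketch: flatten a nested optimizer-style state dict into
    # dotted keys so every tensor is addressable by a single string name.
    from typing import Dict

    import torch

    def flatten_nested(state: dict, seperator: str = ".", prefix: str = "") -> Dict[str, torch.Tensor]:
        flat = {}
        for key, value in state.items():
            name = f"{prefix}{seperator}{key}" if prefix else str(key)
            if isinstance(value, dict):
                flat.update(flatten_nested(value, seperator, name))
            else:
                flat[name] = value
        return flat

    nested = {0: {"exp_avg": torch.zeros(4), "exp_avg_sq": torch.zeros(4)}}
    print(list(flatten_nested(nested)))  # ['0.exp_avg', '0.exp_avg_sq']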
@@ -266,6 +267,63 @@ def save_state_dict_shards(
 
 
+def async_save_state_dict_shards(
+    sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
+    checkpoint: str,
+    index_file: "CheckpointIndexFile",
+    base_filename: str,
+    is_master: bool,
+    use_pp_format: bool = False,
+    state_preprocess: bool = False,
+) -> Tuple[int, list]:
+    """
+    Save sharded state dict only on master rank; this method can be used by both model and optimizer states.
+    Args:
+        sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards; each shard contains a state dict and its shard size.
+        checkpoint (str): The path of the checkpoint directory as a string.
+        index_file (CheckpointIndexFile): The index file object to be updated.
+        base_filename (str): Decides the prefix of the shard filenames.
+        is_master (bool): Whether the current rank is the main process.
+        use_pp_format (bool, optional): Whether to save the files in pipeline format, including stage information. Defaults to False.
+        state_preprocess (bool, optional): Whether to flatten the (optimizer) state dict before saving. Defaults to False.
+
+    Returns:
+        Tuple[int, list]: the total size of shards and the list of background safetensors writers
+    """
+    from colossalai.utils.safetensors import save
+
+    total_size = 0
+    shard_filenames = []
+    writers = []
+    for idx, shard_pair in enumerate(sharded_state_dict):
+        shard, current_size = shard_pair
+        # Just loop over the sharder and gather to other ranks if not master
+        if not is_master:
+            del shard
+            continue
+        shard_file = get_shard_filename(base_filename, idx)
+        total_size = total_size + current_size
+        for key in shard.keys():
+            index_file.append_weight_map(key, shard_file)
+        checkpoint_file_path = os.path.join(checkpoint, shard_file)
+
+        if state_preprocess:
+            state_dict, _ = _flatten_optim_state_dict(state_dict=shard, seperator=".")
+        else:
+            state_dict = shard
+
+        # Only save on master rank.
+        writer = save(checkpoint_file_path, state_dict=state_dict)
+        writers.append(writer)
+        shard_filenames.append(shard_file)
+        del shard
+
+    # Clean folder, deleted unneeded files.
+    clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)
+
+    return total_size, writers
+
+
-def async_save_state_dict_shards(
+def async_move_save_state_dict_shards(
     sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
     checkpoint: str,
     index_file: "CheckpointIndexFile",
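A hypothetical caller of the new non-blocking save path could look like the sketch below. Only the function name and its keyword arguments are taken from the diff; the toy shard generator, the directory name, and the comments about waiting on the writers are illustrative assumptions (the writer objects come from colossalai.utils.safetensors.save, and how they are awaited depends on the surrounding checkpoint plugin).

    # Hypothetical usage sketch; everything except async_save_state_dict_shards
    # and its argument names is illustrative and not part of this commit.
    from collections import OrderedDict
    from typing import Iterator, Tuple

    import torch

    def toy_shards() -> Iterator[Tuple[OrderedDict, int]]:
        # Yield (state_dict_shard, shard_size) pairs, mimicking the sharder generators.
        shard = OrderedDict(weight=torch.randn(256, 256))
        yield shard, shard["weight"].numel() * shard["weight"].element_size()

    # total_size, writers = async_save_state_dict_shards(
    #     sharded_state_dict=toy_shards(),
    #     checkpoint="ckpt_dir",              # existing checkpoint directory
    #     index_file=index_file,              # a CheckpointIndexFile to be updated
    #     base_filename="model.safetensors",
    #     is_master=True,
    #     state_preprocess=False,             # True only for flattened optimizer states
    # )
    # The call returns as soon as the shards are handed to background writers;
    # the caller keeps `writers` and waits on them before the checkpoint is final.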
@@ -273,6 +331,7 @@ def async_save_state_dict_shards(
     is_master: bool,
     pinned_state_dict: Optional[Dict[str, torch.Tensor]],
     use_pp_format: bool = False,
+    state_preprocess: bool = False,
 ) -> Tuple[int, Dict[str, torch.Tensor], list]:
     """
     Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
@@ -309,14 +368,19 @@ def async_save_state_dict_shards(
             index_file.append_weight_map(key, shard_file)
         checkpoint_file_path = os.path.join(checkpoint, shard_file)
 
-        if pinned_state_dict is not None:
-            sub_pinned_state_dict = {k: pinned_state_dict[k] for k in shard.keys()}
+        if state_preprocess:
+            state_dict, _ = _flatten_optim_state_dict(state_dict=shard)
         else:
-            sub_pinned_state_dict = create_pinned_state_dict(shard)
+            state_dict = shard
+
+        if pinned_state_dict is not None:
+            sub_pinned_state_dict = {k: pinned_state_dict[k] for k in state_dict.keys()}
+        else:
+            sub_pinned_state_dict = create_pinned_state_dict(state_dict)
         returned_state_dict.update(sub_pinned_state_dict)
 
         # Only save on master rank.
-        writer = move_and_save(checkpoint_file_path, shard, sub_pinned_state_dict)
+        writer = move_and_save(checkpoint_file_path, state_dict, sub_pinned_state_dict)
         writers.append(writer)
         shard_filenames.append(shard_file)
         del shard
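The move-and-save path stages each shard in page-locked (pinned) CPU buffers, either reusing the caller-provided pinned_state_dict or allocating new buffers via create_pinned_state_dict; after this change the buffers are keyed by the flattened state_dict rather than the raw shard. A minimal sketch of that create-or-reuse pattern, assuming create_pinned_state_dict does little more than allocate matching pinned CPU tensors (the real helper may differ):

    # Sketch of the pinned-buffer create-or-reuse pattern; "make_pinned_like"
    # stands in for create_pinned_state_dict and is an assumption.
    # Note: pinning host memory requires a CUDA-enabled PyTorch build.
    from typing import Dict

    import torch

    def make_pinned_like(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        return {k: torch.empty_like(v, pin_memory=True, device="cpu") for k, v in state_dict.items()}

    shard = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}
    pinned_cache: Dict[str, torch.Tensor] = {}

    # Reuse cached pinned buffers when they exist, otherwise allocate them once.
    if all(k in pinned_cache for k in shard):
        sub_pinned = {k: pinned_cache[k] for k in shard}
    else:
        sub_pinned = make_pinned_like(shard)
        pinned_cache.update(sub_pinned)

    # Stage the shard into pinned memory; with CUDA tensors this copy can be non-blocking.
    for k, v in shard.items():
        sub_pinned[k].copy_(v)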
@@ -327,7 +391,11 @@ def async_save_state_dict_shards(
     return total_size, returned_state_dict, writers
 
 
-def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
+def shard_model_checkpoint(
+    state_dict: torch.Tensor,
+    max_shard_size: int = 1024,
+    pinned_state_dicts: Optional[Dict[int, Dict[str, torch.Tensor]]] = None,
+) -> Iterator[Tuple[OrderedDict, int]]:
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
     given size.
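A hypothetical caller of the extended sharder, reusing one pinned cache across repeated saves; the import path and the surrounding loop are assumptions for illustration, only the signature comes from the diff above.

    # Hypothetical usage of shard_model_checkpoint with a persistent pinned cache.
    # Assumes the helper lives in colossalai.checkpoint_io.utils and that a
    # CUDA-enabled build is available so host memory can actually be pinned.
    import torch.nn as nn

    from colossalai.checkpoint_io.utils import shard_model_checkpoint

    model = nn.Linear(512, 512)
    pinned_cache = {}  # kept alive across checkpoints so buffers are allocated only once

    for shard, shard_size in shard_model_checkpoint(
        model.state_dict(),
        max_shard_size=1024,
        pinned_state_dicts=pinned_cache,
    ):
        # Each yielded shard now references pinned CPU copies of the weights,
        # ready to be handed to an asynchronous writer.
        print(shard_size, list(shard))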
@@ -336,6 +404,11 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
 
     for key, weight in state_dict.items():
         if not is_distributed_tensor(weight):
+            if pinned_state_dicts is not None:
+                if key not in pinned_state_dicts:
+                    pinned_state_dicts[key] = torch.empty_like(weight, pin_memory=True, device="cpu")
+                pinned_state_dicts[key].copy_(weight)
+                weight = pinned_state_dicts[key]
             block, block_size = state_dict_sharder.append_param(key, weight)
 
         if block != None:
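Pinned (page-locked) host memory is what allows a device-to-host copy to be issued without blocking the host, which is the point of caching these buffers between saves. A standalone background sketch, not part of the commit, requiring a CUDA device:

    import torch

    if torch.cuda.is_available():
        gpu_weight = torch.randn(1024, 1024, device="cuda")

        # Page-locked CPU buffer, allocated once and reused across checkpoints.
        pinned_buf = torch.empty_like(gpu_weight, pin_memory=True, device="cpu")

        # With pinned memory the device-to-host copy can run asynchronously,
        # letting training continue while the data drains to the host.
        pinned_buf.copy_(gpu_weight, non_blocking=True)
        torch.cuda.synchronize()  # wait only when the staged data is actually needed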
@@ -345,7 +418,9 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
     yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
 
 
-def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
+def shard_optimizer_checkpoint(
+    state_dict: dict, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None
+) -> Iterator[Tuple[OrderedDict, int]]:
     """
     Splits an optimizer state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
     given size.
@@ -356,6 +431,15 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
     state_dict_sharder = StateDictSharder(max_shard_size)
 
     for param_id, state in states.items():
+        if pinned_state_dicts is not None:
+            if param_id not in pinned_state_dicts:
+                pinned_state_dicts[param_id] = {}
+            for k, v in state.items():
+                if k not in pinned_state_dicts[param_id]:
+                    pinned_state_dicts[param_id][k] = torch.empty_like(v, pin_memory=True, device="cpu")
+                pinned_state_dicts[param_id][k].copy_(v)
+                state[k] = pinned_state_dicts[param_id][k]
+
         block, block_size = state_dict_sharder.append_optim_state(param_id, state)
         if block != None:
             yield block, block_size
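Optimizer state is nested (a dict of tensors such as exp_avg per parameter id), so its pinned cache is nested in the same way: one inner dict of pinned buffers per param_id. A self-contained sketch of that structure with a toy Adam-like state rather than a real optimizer (pinning again assumes a CUDA-enabled build):

    # Toy illustration of the nested pinned cache: param_id -> state name -> pinned buffer.
    from typing import Dict

    import torch

    states = {
        0: {"exp_avg": torch.randn(8), "exp_avg_sq": torch.randn(8)},
        1: {"exp_avg": torch.randn(4), "exp_avg_sq": torch.randn(4)},
    }

    pinned_state_dicts: Dict[int, Dict[str, torch.Tensor]] = {}

    for param_id, state in states.items():
        cache = pinned_state_dicts.setdefault(param_id, {})
        for name, value in state.items():
            if name not in cache:
                cache[name] = torch.empty_like(value, pin_memory=True, device="cpu")
            cache[name].copy_(value)
            state[name] = cache[name]  # the shard now points at reusable pinned buffers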