Mirror of https://github.com/hpcaitech/ColossalAI.git
[checkpointio] Sharded Optimizer Checkpoint for Gemini Plugin (#4302)
* sharded optimizer checkpoint for gemini plugin
* modify test to reduce testing time
* update doc
* fix bug when keep_gathered is true under GeminiPlugin
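For orientation, here is a rough sketch of how the new sharded optimizer checkpointing is reached from user code through the Booster API. The model, optimizer, and checkpoint paths are illustrative, and the shard/size_per_shard argument names follow the public Booster.save_optimizer signature as understood at the time of this PR; treat this as an assumption-laden sketch, not a definitive recipe.

# Sketch: saving a sharded optimizer checkpoint with the Gemini plugin.
# All concrete names below (paths, model size, learning rate) are illustrative.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})

model = torch.nn.Linear(1024, 1024)
optimizer = HybridAdam(model.parameters(), lr=1e-3)

booster = Booster(plugin=GeminiPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# ... training steps ...

# shard=True writes a set of shard files plus an index file instead of one
# monolithic file; size_per_shard caps the size of each shard (in MB).
booster.save_model(model, "./ckpt/model", shard=True)
booster.save_optimizer(optimizer, "./ckpt/optimizer", shard=True, size_per_shard=1024)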
@@ -1,4 +1,5 @@
# coding=utf-8
import os
import re
from collections import abc as container_abcs
from collections import defaultdict
@@ -103,6 +104,43 @@ def unwrap_optimizer(optimizer: OptimizerWrapper):
    return unwrapped_optim


def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
                           checkpoint: str,
                           index_file: "CheckpointIndexFile",
                           base_filename: str,
                           is_master: bool,
                           use_safetensors: bool = False) -> int:
    '''
    Save a sharded state dict only on the master rank. This method can be used for both model and optimizer states.

    Args:
        sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each yielding a state dict and its shard size.
        checkpoint (str): The path of the checkpoint directory as a string.
        index_file (CheckpointIndexFile): The index file object to be updated.
        base_filename (str): Decides the prefix of the shard filenames.
        is_master (bool): Whether the current rank is the master rank.
        use_safetensors (bool): Whether to use safetensors to save the checkpoint.

    Returns:
        int: the total size of the shards
    '''

    total_size = 0
    for idx, shard_pair in enumerate(sharded_state_dict):
        if not is_master:
            continue
        shard, current_size = shard_pair
        shard_file = get_shard_filename(base_filename, idx)
        total_size = total_size + current_size
        for key in shard.keys():
            index_file.append_weight_map(key, shard_file)
        checkpoint_file_path = os.path.join(checkpoint, shard_file)

        # Only save on master rank.
        save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors)

    return total_size


def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
    """
    Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a given size.
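To make the contract of save_state_dict_shards concrete, below is a minimal sketch of a generator that yields (state_dict, size) pairs of the expected shape, together with a commented call site. The generator and the CheckpointIndexFile construction shown in comments are illustrative assumptions, not the shard_model_checkpoint implementation that this hunk covers.

# Sketch: the (OrderedDict, int) shard stream that save_state_dict_shards consumes.
# `toy_shards` is a made-up stand-in for shard_model_checkpoint / the optimizer
# sharding added in this PR; only the yielded shape matters here.
from collections import OrderedDict
from typing import Dict, Iterator, Tuple

import torch


def toy_shards(state_dict: Dict[str, torch.Tensor], max_shard_size: int) -> Iterator[Tuple[OrderedDict, int]]:
    shard, size = OrderedDict(), 0
    for name, tensor in state_dict.items():
        tensor_size = tensor.numel() * tensor.element_size()
        if shard and size + tensor_size > max_shard_size:
            yield shard, size  # a full shard: (state dict, accumulated byte size)
            shard, size = OrderedDict(), 0
        shard[name] = tensor
        size += tensor_size
    if shard:
        yield shard, size  # trailing partial shard


# Assumed call site (the CheckpointIndexFile usage follows the pattern in checkpoint_io):
#   index_file = CheckpointIndexFile(checkpoint_dir)
#   total_size = save_state_dict_shards(toy_shards(model.state_dict(), 1024 ** 3),
#                                       checkpoint=checkpoint_dir,
#                                       index_file=index_file,
#                                       base_filename="model.bin",
#                                       is_master=(dist.get_rank() == 0),
#                                       use_safetensors=False)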