[checkpointio] General Checkpointing of Sharded Optimizers (#3984)

This commit is contained in:
Baizhou Zhang
2023-06-15 15:21:26 +08:00
committed by GitHub
parent 8bcad73677
commit c9cff7e7fa
8 changed files with 399 additions and 38 deletions

View File

@@ -12,7 +12,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import get_base_filenames, get_shard_filename, save_state_dict
from colossalai.checkpoint_io.utils import get_model_base_filenames, get_shard_filename, save_state_dict
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
@@ -76,14 +76,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
model: GeminiDDP,
checkpoint_path: str,
gather_dtensor: bool = False,
variant: Optional[str] = None,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False):
"""
Save sharded model
"""
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
total_size = 0
index_file = CheckpointIndexFile(checkpoint_path)
for idx, shard_pair in enumerate(state_dict_shard):