Next commit [checkpointio] Unsharded Optimizer Checkpoint for Gemini Plugin (#4141)

* [checkpointio] unsharded optimizer checkpoint for Gemini plugin

* [checkpointio] unsharded optimizer checkpoint for Gemini using all_gather
This commit is contained in:
Baizhou Zhang
2023-07-07 16:33:06 +08:00
committed by GitHub
parent fee32a3b78
commit 58913441a1
9 changed files with 684 additions and 83 deletions

View File

@@ -152,6 +152,7 @@ class CheckpointIO(ABC):
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
index_file_exists, index_file_path = has_index_file(checkpoint)
if Path(checkpoint).is_dir() and not index_file_exists:
@@ -186,6 +187,7 @@ class CheckpointIO(ABC):
prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
"""
if shard:
self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
else: