[checkpointio] sharded optimizer checkpoint for DDP plugin (#4002)

This commit is contained in:
Baizhou Zhang
2023-06-16 14:14:05 +08:00
committed by GitHub
parent 725af3eeeb
commit 822c3d4d66
6 changed files with 79 additions and 34 deletions

View File

@@ -148,6 +148,9 @@ class CheckpointIO(ABC):
Args:
optimizer (Optimizer): optimizer to be loaded.
checkpoint (str): checkpoint path. This value is made compatible with the model checkpoints.
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
index_file_exists, index_file_path = has_index_file(checkpoint)
@@ -157,7 +160,7 @@ class CheckpointIO(ABC):
if index_file_exists:
# the existence of index file means it is a sharded checkpoint
self.load_sharded_optimizer(optimizer, index_file_path, prefix, size_per_shard)
self.load_sharded_optimizer(optimizer, index_file_path, prefix)
else:
self.load_unsharded_optimizer(optimizer, checkpoint)
@@ -251,7 +254,7 @@ class CheckpointIO(ABC):
# ========================================================
@abstractmethod
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
"""
Load optimizer from sharded checkpoint.
@@ -259,7 +262,6 @@ class CheckpointIO(ABC):
optimizer (Optimizer): optimizer to be loaded.
index_file_path (str): checkpoint path. It should be the path to the .index.json file or a path to a directory which contains a .index.json file.
prefix (str): prefix for the optimizer checkpoint.
size_per_shard (int): size per shard in MB.
"""
pass