[checkpointio] General Checkpointing of Sharded Optimizers (#3984)

This commit is contained in:
Baizhou Zhang
2023-06-15 15:21:26 +08:00
committed by GitHub
parent 8bcad73677
commit c9cff7e7fa
8 changed files with 399 additions and 38 deletions

View File

@@ -111,7 +111,7 @@ class CheckpointIndexFile:
return True
return False
def get_checkpoint_fileanames(self) -> List[str]:
def get_checkpoint_filenames(self) -> List[str]:
"""
Get the set of checkpoint filenames in the weight map.
@@ -159,6 +159,18 @@ class CheckpointIndexFile:
"""
return list(self.weight_map.keys())
def get_param_group_filename(self) -> Union[str, None]:
"""
Get the file name of param_group file if this is a checkpoint for optimizer.
Returns:
str: param_group file name
"""
filename = self.metadata.get("param_groups", None)
if filename:
return str(self.root_path.joinpath(filename))
else:
return None
def write_index_file(self, save_index_file):
"""
Write index file.