Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 03:20:52 +00:00)
[checkpointio] General Checkpointing of Sharded Optimizers (#3984)
@@ -111,7 +111,7 @@ class CheckpointIndexFile:
             return True
         return False
 
-    def get_checkpoint_fileanames(self) -> List[str]:
+    def get_checkpoint_filenames(self) -> List[str]:
         """
         Get the set of checkpoint filenames in the weight map.
 
@@ -159,6 +159,18 @@ class CheckpointIndexFile:
         """
         return list(self.weight_map.keys())
 
+    def get_param_group_filename(self) -> Union[str, None]:
+        """
+        Get the file name of param_group file if this is a checkpoint for optimizer.
+        Returns:
+            str: param_group file name
+        """
+        filename = self.metadata.get("param_groups", None)
+        if filename:
+            return str(self.root_path.joinpath(filename))
+        else:
+            return None
+
     def write_index_file(self, save_index_file):
         """
         Write index file.