[checkpointio] support non blocking pin load (#6172)

* [checkpointio] support non blocking pin load

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Hongxin Liu
2024-12-25 17:03:25 +08:00
committed by GitHub
parent 836992438f
commit af06d162cf
15 changed files with 484 additions and 174 deletions

View File

@@ -510,7 +510,14 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
f"index located at {final_index_file_path}."
)
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
def load_sharded_optimizer(
self,
optimizer: OptimizerWrapper,
checkpoint_index_file: str,
prefix: str = "",
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load sharded optimizer with the given path to index file of checkpoint folder.
@@ -795,7 +802,14 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
dist.barrier()
# Copied from colossalai.moe
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False):
def load_unsharded_optimizer(
self,
optimizer: OptimizerWrapper,
checkpoint: str,
strict: bool = False,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load optimizer from a file with given path.