[checkpointio] support non blocking pin load (#6172)

* [checkpointio] support non blocking pin load

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Hongxin Liu
2024-12-25 17:03:25 +08:00
committed by GitHub
parent 836992438f
commit af06d162cf
15 changed files with 484 additions and 174 deletions

View File

@@ -1,4 +1,3 @@
import gc
import os
import random
from pathlib import Path
@@ -97,13 +96,22 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
else:
save_state_dict(state_dict, checkpoint, use_safetensors)
def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
def load_unsharded_model(
self,
model: GeminiDDP,
checkpoint: str,
strict: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load model from checkpoint with automatic unwrapping.
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
"""
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
super().load_unsharded_model(model, checkpoint, strict=strict)
super().load_unsharded_model(
model, checkpoint, strict=strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
def save_unsharded_optimizer(
self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool, use_async: bool = False
@@ -131,13 +139,17 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)
def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
def load_unsharded_optimizer(
self, optimizer: GeminiOptimizer, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
"""
Loading unsharded optimizer from checkpoint file.
For each process, only loading optimizer states of parameters it controls.
"""
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
super().load_unsharded_optimizer(optimizer, checkpoint)
super().load_unsharded_optimizer(
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
def save_sharded_model(
self,
@@ -206,13 +218,27 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
)
def load_sharded_model(
self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
self,
model: GeminiDDP,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load shard model, load model from multiple files.
"""
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
return super().load_sharded_model(
model,
checkpoint_index_file,
strict,
use_safetensors,
load_sub_module=False,
low_cpu_mem_mode=low_cpu_mem_mode,
num_threads=num_threads,
)
def save_sharded_optimizer(
self,
@@ -289,7 +315,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
ranks=[0],
)
def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
def load_sharded_optimizer(
self,
optimizer: GeminiOptimizer,
checkpoint_index_file: Path,
prefix: str,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Loading sharded optimizer from checkpoint folder, with index file given.
For each process, only loading optimizer states of parameters it controls.
@@ -322,9 +355,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
state_dict_shard = load_flat(shard_file)
else:
state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
if not low_cpu_mem_mode:
state_dict_shard = create_pinned_state_dict(state_dict_shard, empty=False, num_threads=num_threads)
optimizer.load_param_states(state_dict_shard)
del state_dict_shard
gc.collect()
optimizer.optimizer_loading_epilogue()