[checkpointio] support non blocking pin load (#6172)

* [checkpointio] support non blocking pin load

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Hongxin Liu authored on 2024-12-25 17:03:25 +08:00 (committed by GitHub)
parent 836992438f
commit af06d162cf
15 changed files with 484 additions and 174 deletions

@@ -26,12 +26,21 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         self.coordinator = DistCoordinator()
         self.logger = get_dist_logger()
 
-    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
+    def load_unsharded_model(
+        self,
+        model: ModelWrapper,
+        checkpoint: str,
+        strict: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load model from checkpoint.
         """
         assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
-        super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)
+        super().load_unsharded_model(
+            model.unwrap(), checkpoint, strict=strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+        )
 
     def save_unsharded_model(
         self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
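The new low_cpu_mem_mode and num_threads knobs are threaded through every load path in the class. As a rough caller-side sketch of the signature above (the plugin setup, model, and checkpoint path are illustrative, and a distributed environment is assumed to be launched already; only the load_unsharded_model signature itself is confirmed by this diff):

    import torch.nn as nn
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import TorchDDPPlugin

    colossalai.launch_from_torch()
    booster = Booster(plugin=TorchDDPPlugin())
    model = nn.Linear(8, 8).cuda()
    model, *_ = booster.boost(model)  # must be boosted: a ModelWrapper

    ckpt_io = booster.plugin.get_checkpoint_io()
    ckpt_io.load_unsharded_model(
        model,
        "ckpt/model.pt",         # hypothetical path
        strict=True,
        low_cpu_mem_mode=False,  # opt into the pinned-memory fast path
        num_threads=4,           # parallel loader threads
    )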
@@ -45,12 +54,16 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
             model.unwrap(), checkpoint, gather_dtensor, use_safetensors, use_async=use_async
         )
 
-    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
+    def load_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         """
         Load optimizer from checkpoint.
         """
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
-        super().load_unsharded_optimizer(optimizer, checkpoint)
+        super().load_unsharded_optimizer(
+            optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+        )
 
     def save_unsharded_optimizer(
         self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
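The machinery behind low_cpu_mem_mode=False is not visible in this file, but the technique the commit title names is standard: stage tensors in pinned (page-locked) host memory so host-to-device copies can be issued with non_blocking=True and overlapped, instead of blocking on each pageable copy. A minimal conceptual sketch in plain PyTorch, assuming a CUDA device; this is not ColossalAI's actual loader:

    import torch

    def pinned_nonblocking_load(path: str, device: torch.device) -> dict:
        # Conceptual sketch only: read the checkpoint into CPU memory first.
        state_dict = torch.load(path, map_location="cpu")
        loaded = {}
        for name, tensor in state_dict.items():
            # Pin the host buffer; async H2D copies require page-locked memory.
            pinned = tensor.pin_memory()
            # non_blocking=True queues the copy so the loop doesn't stall on it.
            loaded[name] = pinned.to(device, non_blocking=True)
        # Synchronize once at the end instead of once per tensor.
        torch.cuda.synchronize(device)
        return loaded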
@@ -101,12 +114,22 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         strict: bool = False,
         use_safetensors: bool = False,
         load_sub_module: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
     ):
         """
         Load model from sharded checkpoint.
         """
         assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
-        super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module)
+        super().load_sharded_model(
+            model.unwrap(),
+            checkpoint_index_file,
+            strict,
+            use_safetensors,
+            load_sub_module,
+            low_cpu_mem_mode=low_cpu_mem_mode,
+            num_threads=num_threads,
+        )
 
     def save_sharded_optimizer(
         self,
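num_threads presumably matters most on these sharded paths, where independent shard files can be read concurrently; torch.load releases the GIL during file I/O, so even plain threads overlap reads. A generic sketch of that idea, again not the code this commit adds:

    from concurrent.futures import ThreadPoolExecutor

    import torch

    def load_shards_threaded(shard_paths, num_threads=4):
        # Read shard files in parallel on CPU, then merge into one state dict.
        with ThreadPoolExecutor(max_workers=num_threads) as pool:
            shards = list(pool.map(lambda p: torch.load(p, map_location="cpu"), shard_paths))
        merged = {}
        for shard in shards:
            merged.update(shard)
        return merged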
@@ -131,12 +154,16 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         optimizer: Optimizer,
         index_file_path: str,
         prefix: Optional[str] = None,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
     ):
         """
         Load optimizer from sharded checkpoint.
         """
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
-        super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
+        super().load_sharded_optimizer(
+            optimizer.unwrap(), index_file_path, prefix, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+        )
 
     def save_lora_as_pretrained(
         self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
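A matching caller-side sketch for the sharded optimizer path, mirroring the signature shown above (the ckpt_io object from the earlier sketch and the index file path are illustrative):

    # optimizer must already be boosted, i.e. an OptimizerWrapper
    ckpt_io.load_sharded_optimizer(
        optimizer,
        "ckpt/optimizer.index.json",  # hypothetical index file
        prefix=None,
        low_cpu_mem_mode=False,       # keep the pinned, non-blocking path
        num_threads=4,
    )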