mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
[checkpointio] support non blocking pin load (#6172)
* [checkpointio] support non blocking pin load * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -288,7 +288,14 @@ class Booster:
|
||||
|
||||
return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config)
|
||||
|
||||
def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
|
||||
def load_model(
|
||||
self,
|
||||
model: Union[nn.Module, ModelWrapper],
|
||||
checkpoint: str,
|
||||
strict: bool = True,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
) -> None:
|
||||
"""Load model from checkpoint.
|
||||
|
||||
Args:
|
||||
@@ -298,8 +305,12 @@ class Booster:
|
||||
strict (bool, optional): whether to strictly enforce that the keys
|
||||
in :attr:`state_dict` match the keys returned by this module's
|
||||
:meth:`~torch.nn.Module.state_dict` function. Defaults to True.
|
||||
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
|
||||
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
|
||||
"""
|
||||
self.checkpoint_io.load_model(model, checkpoint, strict)
|
||||
self.checkpoint_io.load_model(
|
||||
model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
|
||||
def save_model(
|
||||
self,
|
||||
@@ -338,18 +349,25 @@ class Booster:
|
||||
use_async=use_async,
|
||||
)
|
||||
|
||||
def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
|
||||
def load_optimizer(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
checkpoint: str,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
) -> None:
|
||||
"""Load optimizer from checkpoint.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): An optimizer boosted by Booster.
|
||||
checkpoint (str): Path to the checkpoint. It must be a local path.
|
||||
It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
|
||||
prefix (str, optional): A prefix added to parameter and buffer
|
||||
names to compose the keys in state_dict. Defaults to None.
|
||||
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
|
||||
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
|
||||
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
|
||||
"""
|
||||
self.checkpoint_io.load_optimizer(optimizer, checkpoint)
|
||||
self.checkpoint_io.load_optimizer(
|
||||
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
|
||||
def save_optimizer(
|
||||
self,
|
||||
|
@@ -1,4 +1,3 @@
|
||||
import gc
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
@@ -97,13 +96,22 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
||||
else:
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors)
|
||||
|
||||
def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
|
||||
def load_unsharded_model(
|
||||
self,
|
||||
model: GeminiDDP,
|
||||
checkpoint: str,
|
||||
strict: bool = True,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load model from checkpoint with automatic unwrapping.
|
||||
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
|
||||
"""
|
||||
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
|
||||
super().load_unsharded_model(model, checkpoint, strict=strict)
|
||||
super().load_unsharded_model(
|
||||
model, checkpoint, strict=strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
|
||||
def save_unsharded_optimizer(
|
||||
self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool, use_async: bool = False
|
||||
@@ -131,13 +139,17 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
||||
else:
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
|
||||
def load_unsharded_optimizer(
|
||||
self, optimizer: GeminiOptimizer, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
|
||||
):
|
||||
"""
|
||||
Loading unsharded optimizer from checkpoint file.
|
||||
For each process, only loading optimizer states of parameters it controls.
|
||||
"""
|
||||
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
|
||||
super().load_unsharded_optimizer(optimizer, checkpoint)
|
||||
super().load_unsharded_optimizer(
|
||||
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
|
||||
def save_sharded_model(
|
||||
self,
|
||||
@@ -206,13 +218,27 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
||||
)
|
||||
|
||||
def load_sharded_model(
|
||||
self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
|
||||
self,
|
||||
model: GeminiDDP,
|
||||
checkpoint_index_file: Path,
|
||||
strict: bool = False,
|
||||
use_safetensors: bool = False,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load shard model, load model from multiple files.
|
||||
"""
|
||||
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
|
||||
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
|
||||
return super().load_sharded_model(
|
||||
model,
|
||||
checkpoint_index_file,
|
||||
strict,
|
||||
use_safetensors,
|
||||
load_sub_module=False,
|
||||
low_cpu_mem_mode=low_cpu_mem_mode,
|
||||
num_threads=num_threads,
|
||||
)
|
||||
|
||||
def save_sharded_optimizer(
|
||||
self,
|
||||
@@ -289,7 +315,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
||||
ranks=[0],
|
||||
)
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
|
||||
def load_sharded_optimizer(
|
||||
self,
|
||||
optimizer: GeminiOptimizer,
|
||||
checkpoint_index_file: Path,
|
||||
prefix: str,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Loading sharded optimizer from checkpoint folder, with index file given.
|
||||
For each process, only loading optimizer states of parameters it controls.
|
||||
@@ -322,9 +355,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
||||
state_dict_shard = load_flat(shard_file)
|
||||
else:
|
||||
state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
|
||||
if not low_cpu_mem_mode:
|
||||
state_dict_shard = create_pinned_state_dict(state_dict_shard, empty=False, num_threads=num_threads)
|
||||
optimizer.load_param_states(state_dict_shard)
|
||||
del state_dict_shard
|
||||
gc.collect()
|
||||
|
||||
optimizer.optimizer_loading_epilogue()
|
||||
|
||||
|
@@ -20,6 +20,7 @@ from torch.utils.data import DataLoader
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
|
||||
from colossalai.checkpoint_io.utils import (
|
||||
create_pinned_state_dict,
|
||||
get_optimizer_base_filenames,
|
||||
get_shard_filename,
|
||||
load_param_groups_into_optimizer,
|
||||
@@ -145,7 +146,9 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
else:
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
|
||||
def load_unsharded_optimizer(
|
||||
self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
|
||||
):
|
||||
use_async = checkpoint.endswith(".safetensors")
|
||||
if use_async:
|
||||
from colossalai.utils.safetensors import load_flat
|
||||
@@ -153,6 +156,8 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
checkpoint = load_flat(checkpoint)
|
||||
else:
|
||||
checkpoint = load_state_dict(checkpoint)
|
||||
if not low_cpu_mem_mode:
|
||||
checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads)
|
||||
optimizer.load_state_dict(checkpoint)
|
||||
|
||||
def save_sharded_optimizer(
|
||||
@@ -239,7 +244,14 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
ranks=[0],
|
||||
)
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
|
||||
def load_sharded_optimizer(
|
||||
self,
|
||||
optimizer: OptimizerWrapper,
|
||||
index_file_path: str,
|
||||
prefix: str,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""Load sharded optimizer with the given path to index file.
|
||||
|
||||
Args:
|
||||
@@ -283,14 +295,28 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
if padding_size > 0:
|
||||
v = torch.nn.functional.pad(v, [0, padding_size])
|
||||
v_list = v.split(v.numel() // self.coordinator.world_size)
|
||||
state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
|
||||
state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()
|
||||
if low_cpu_mem_mode:
|
||||
state_dict[param_idx][k] = state_dict[param_idx][k].clone()
|
||||
|
||||
if not low_cpu_mem_mode:
|
||||
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
|
||||
load_states_into_optimizer(optimizer, state_dict, id_map)
|
||||
sharded_optimizer_loading_epilogue(optimizer)
|
||||
|
||||
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
|
||||
def load_unsharded_model(
|
||||
self,
|
||||
model: ModelWrapper,
|
||||
checkpoint: str,
|
||||
strict: bool = True,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
|
||||
model._force_wait_all_gather()
|
||||
super().load_unsharded_model(model, checkpoint, strict)
|
||||
super().load_unsharded_model(
|
||||
model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
model.update_master_params()
|
||||
|
||||
def load_sharded_model(
|
||||
@@ -300,10 +326,20 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
strict: bool = False,
|
||||
use_safetensors: bool = False,
|
||||
load_sub_module: bool = True,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
|
||||
model._force_wait_all_gather()
|
||||
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
|
||||
super().load_sharded_model(
|
||||
model,
|
||||
checkpoint_index_file,
|
||||
strict,
|
||||
use_safetensors,
|
||||
load_sub_module,
|
||||
low_cpu_mem_mode=low_cpu_mem_mode,
|
||||
num_threads=num_threads,
|
||||
)
|
||||
model.update_master_params()
|
||||
|
||||
def save_unsharded_model(
|
||||
|
@@ -26,12 +26,21 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
|
||||
self.coordinator = DistCoordinator()
|
||||
self.logger = get_dist_logger()
|
||||
|
||||
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
|
||||
def load_unsharded_model(
|
||||
self,
|
||||
model: ModelWrapper,
|
||||
checkpoint: str,
|
||||
strict: bool = True,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load model from checkpoint.
|
||||
"""
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
|
||||
super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)
|
||||
super().load_unsharded_model(
|
||||
model.unwrap(), checkpoint, strict=strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
|
||||
def save_unsharded_model(
|
||||
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
|
||||
@@ -45,12 +54,16 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
|
||||
model.unwrap(), checkpoint, gather_dtensor, use_safetensors, use_async=use_async
|
||||
)
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
|
||||
def load_unsharded_optimizer(
|
||||
self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
|
||||
):
|
||||
"""
|
||||
Load optimizer from checkpoint.
|
||||
"""
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
|
||||
super().load_unsharded_optimizer(optimizer, checkpoint)
|
||||
super().load_unsharded_optimizer(
|
||||
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
|
||||
def save_unsharded_optimizer(
|
||||
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
|
||||
@@ -101,12 +114,22 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
|
||||
strict: bool = False,
|
||||
use_safetensors: bool = False,
|
||||
load_sub_module: bool = True,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load model from sharded checkpoint.
|
||||
"""
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
|
||||
super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module)
|
||||
super().load_sharded_model(
|
||||
model.unwrap(),
|
||||
checkpoint_index_file,
|
||||
strict,
|
||||
use_safetensors,
|
||||
load_sub_module,
|
||||
low_cpu_mem_mode=low_cpu_mem_mode,
|
||||
num_threads=num_threads,
|
||||
)
|
||||
|
||||
def save_sharded_optimizer(
|
||||
self,
|
||||
@@ -131,12 +154,16 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
|
||||
optimizer: Optimizer,
|
||||
index_file_path: str,
|
||||
prefix: Optional[str] = None,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load optimizer from sharded checkpoint.
|
||||
"""
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
|
||||
super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
|
||||
super().load_sharded_optimizer(
|
||||
optimizer.unwrap(), index_file_path, prefix, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
|
||||
def save_lora_as_pretrained(
|
||||
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
|
||||
|
@@ -43,13 +43,17 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
||||
self.coordinator = DistCoordinator()
|
||||
self.logger = get_dist_logger()
|
||||
|
||||
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool):
|
||||
def load_unsharded_model(
|
||||
self, model: ModelWrapper, checkpoint: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
|
||||
):
|
||||
assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
|
||||
model = model.unwrap()
|
||||
checkpoint = utils.load_state_dict(checkpoint)
|
||||
model.load_state_dict(checkpoint)
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path):
|
||||
def load_unsharded_optimizer(
|
||||
self, optimizer: OptimizerWrapper, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
|
||||
):
|
||||
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
|
||||
if checkpoint.endswith(".safetensors"):
|
||||
checkpoint = load_flat(checkpoint, seperator=".")
|
||||
@@ -232,6 +236,8 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
||||
strict: bool = False,
|
||||
use_safetensors: bool = False,
|
||||
load_sub_module: bool = True,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load model to checkpoint but only on master process.
|
||||
@@ -354,7 +360,14 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
|
||||
def load_sharded_optimizer(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
index_file_path: str,
|
||||
size_per_shard: int,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load optimizer to checkpoint but only on master process.
|
||||
"""
|
||||
|
Reference in New Issue
Block a user