[checkpointio] support non blocking pin load (#6172)

* [checkpointio] support non blocking pin load

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Hongxin Liu 2024-12-25 17:03:25 +08:00 committed by GitHub
parent 836992438f
commit af06d162cf
15 changed files with 484 additions and 174 deletions

View File

@ -288,7 +288,14 @@ class Booster:
return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config)
def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
def load_model(
self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
strict: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
) -> None:
"""Load model from checkpoint.
Args:
@ -298,8 +305,12 @@ class Booster:
strict (bool, optional): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Defaults to True.
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
"""
self.checkpoint_io.load_model(model, checkpoint, strict)
self.checkpoint_io.load_model(
model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
def save_model(
self,
@ -338,18 +349,25 @@ class Booster:
use_async=use_async,
)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
def load_optimizer(
self,
optimizer: Optimizer,
checkpoint: str,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
) -> None:
"""Load optimizer from checkpoint.
Args:
optimizer (Optimizer): An optimizer boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local path.
It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
low_cpu_mem_mode (bool): whether to load the optimizer in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
num_threads (int): number of threads to use when loading the optimizer. Only useful when disabling low cpu mem mode. Default: 1.
"""
self.checkpoint_io.load_optimizer(optimizer, checkpoint)
self.checkpoint_io.load_optimizer(
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
def save_optimizer(
self,
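
For reference, these two Booster methods are where the new keyword arguments reach end users. A minimal usage sketch follows, assuming a process launched with torchrun and placeholder checkpoint paths (the model, plugin choice, and paths are illustrative, not part of this commit):

import torch.nn as nn
from torch.optim import Adam

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch()  # assumes a torchrun-launched process
model = nn.Linear(1024, 1024)
optimizer = Adam(model.parameters(), lr=1e-3)
booster = Booster(plugin=TorchDDPPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# Default behaviour: low CPU memory mode, tensors are loaded and consumed shard by shard.
booster.load_model(model, "ckpt/model")

# Opting out: loaded tensors are first copied into pinned host memory (4 threads here),
# which lets the subsequent host-to-device copies run non-blocking.
booster.load_model(model, "ckpt/model", low_cpu_mem_mode=False, num_threads=4)
booster.load_optimizer(optimizer, "ckpt/optimizer", low_cpu_mem_mode=False, num_threads=4)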

View File

@ -1,4 +1,3 @@
import gc
import os
import random
from pathlib import Path
@ -97,13 +96,22 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
else:
save_state_dict(state_dict, checkpoint, use_safetensors)
def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
def load_unsharded_model(
self,
model: GeminiDDP,
checkpoint: str,
strict: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load model from checkpoint with automatic unwrapping.
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
"""
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
super().load_unsharded_model(model, checkpoint, strict=strict)
super().load_unsharded_model(
model, checkpoint, strict=strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
def save_unsharded_optimizer(
self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool, use_async: bool = False
@ -131,13 +139,17 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)
def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
def load_unsharded_optimizer(
self, optimizer: GeminiOptimizer, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
"""
Loading unsharded optimizer from checkpoint file.
For each process, only loading optimizer states of parameters it controls.
"""
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
super().load_unsharded_optimizer(optimizer, checkpoint)
super().load_unsharded_optimizer(
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
def save_sharded_model(
self,
@ -206,13 +218,27 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
)
def load_sharded_model(
self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
self,
model: GeminiDDP,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load shard model, load model from multiple files.
"""
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
return super().load_sharded_model(
model,
checkpoint_index_file,
strict,
use_safetensors,
load_sub_module=False,
low_cpu_mem_mode=low_cpu_mem_mode,
num_threads=num_threads,
)
def save_sharded_optimizer(
self,
@ -289,7 +315,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
ranks=[0],
)
def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
def load_sharded_optimizer(
self,
optimizer: GeminiOptimizer,
checkpoint_index_file: Path,
prefix: str,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Loading sharded optimizer from checkpoint folder, with index file given.
For each process, only loading optimizer states of parameters it controls.
@ -322,9 +355,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
state_dict_shard = load_flat(shard_file)
else:
state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
if not low_cpu_mem_mode:
state_dict_shard = create_pinned_state_dict(state_dict_shard, empty=False, num_threads=num_threads)
optimizer.load_param_states(state_dict_shard)
del state_dict_shard
gc.collect()
optimizer.optimizer_loading_epilogue()
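
The sharded-optimizer loop above shows the pattern this commit applies throughout: read one shard, optionally re-materialise it as pinned host tensors, hand it to its consumer, then move to the next shard. A hedged, stripped-down sketch of that loop (the shard paths and the boosted GeminiOptimizer are placeholders):

from pathlib import Path
from colossalai.checkpoint_io.utils import create_pinned_state_dict, load_shard_state_dict

low_cpu_mem_mode, num_threads = False, 4
shard_files = ["ckpt/optimizer/shard-00001.bin", "ckpt/optimizer/shard-00002.bin"]  # placeholder paths

for shard_file in shard_files:
    state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
    if not low_cpu_mem_mode:
        # copy the shard into pinned host memory so downstream device copies can be asynchronous
        state_dict_shard = create_pinned_state_dict(state_dict_shard, empty=False, num_threads=num_threads)
    optimizer.load_param_states(state_dict_shard)  # `optimizer` is a boosted GeminiOptimizer (placeholder)
optimizer.optimizer_loading_epilogue()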

View File

@ -20,6 +20,7 @@ from torch.utils.data import DataLoader
from colossalai.accelerator import get_accelerator
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
from colossalai.checkpoint_io.utils import (
create_pinned_state_dict,
get_optimizer_base_filenames,
get_shard_filename,
load_param_groups_into_optimizer,
@ -145,7 +146,9 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
def load_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
use_async = checkpoint.endswith(".safetensors")
if use_async:
from colossalai.utils.safetensors import load_flat
@ -153,6 +156,8 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
checkpoint = load_flat(checkpoint)
else:
checkpoint = load_state_dict(checkpoint)
if not low_cpu_mem_mode:
checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads)
optimizer.load_state_dict(checkpoint)
def save_sharded_optimizer(
@ -239,7 +244,14 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
ranks=[0],
)
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
def load_sharded_optimizer(
self,
optimizer: OptimizerWrapper,
index_file_path: str,
prefix: str,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""Load sharded optimizer with the given path to index file.
Args:
@ -283,14 +295,28 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
v_list = v.split(v.numel() // self.coordinator.world_size)
state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()
if low_cpu_mem_mode:
state_dict[param_idx][k] = state_dict[param_idx][k].clone()
if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
load_states_into_optimizer(optimizer, state_dict, id_map)
sharded_optimizer_loading_epilogue(optimizer)
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
def load_unsharded_model(
self,
model: ModelWrapper,
checkpoint: str,
strict: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
super().load_unsharded_model(model, checkpoint, strict)
super().load_unsharded_model(
model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
model.update_master_params()
def load_sharded_model(
@ -300,10 +326,20 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
super().load_sharded_model(
model,
checkpoint_index_file,
strict,
use_safetensors,
load_sub_module,
low_cpu_mem_mode=low_cpu_mem_mode,
num_threads=num_threads,
)
model.update_master_params()
def save_unsharded_model(
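
One subtle change above: the per-rank slice is no longer unconditionally cloned. In low CPU memory mode a clone is still needed so the padded full tensor can be freed, while in the pinned path create_pinned_state_dict already produces a fresh pinned copy. A small sketch of that slicing, with hypothetical sizes, rank, and padding:

import torch
from colossalai.checkpoint_io.utils import create_pinned_state_dict

world_size, rank = 4, 1      # placeholders for self.coordinator.world_size / .rank
low_cpu_mem_mode = False
padding_size = 2             # placeholder; the real value comes from the working-param padding

v = torch.rand(1030)         # a flattened optimizer state read from the checkpoint
if padding_size > 0:
    v = torch.nn.functional.pad(v, [0, padding_size])
local = v.split(v.numel() // world_size)[rank].detach()

if low_cpu_mem_mode:
    local = local.clone()    # drop the reference to the padded full tensor
else:
    # pinning copies the slice anyway, so an extra clone would be redundant
    local = create_pinned_state_dict({"exp_avg": local}, empty=False)["exp_avg"]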

View File

@ -26,12 +26,21 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
self.coordinator = DistCoordinator()
self.logger = get_dist_logger()
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
def load_unsharded_model(
self,
model: ModelWrapper,
checkpoint: str,
strict: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load model from checkpoint.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)
super().load_unsharded_model(
model.unwrap(), checkpoint, strict=strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
def save_unsharded_model(
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
@ -45,12 +54,16 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
model.unwrap(), checkpoint, gather_dtensor, use_safetensors, use_async=use_async
)
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
def load_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
"""
Load optimizer from checkpoint.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
super().load_unsharded_optimizer(optimizer, checkpoint)
super().load_unsharded_optimizer(
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
@ -101,12 +114,22 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load model from sharded checkpoint.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module)
super().load_sharded_model(
model.unwrap(),
checkpoint_index_file,
strict,
use_safetensors,
load_sub_module,
low_cpu_mem_mode=low_cpu_mem_mode,
num_threads=num_threads,
)
def save_sharded_optimizer(
self,
@ -131,12 +154,16 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
optimizer: Optimizer,
index_file_path: str,
prefix: Optional[str] = None,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load optimizer from sharded checkpoint.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
super().load_sharded_optimizer(
optimizer.unwrap(), index_file_path, prefix, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
def save_lora_as_pretrained(
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False

View File

@ -43,13 +43,17 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
self.coordinator = DistCoordinator()
self.logger = get_dist_logger()
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool):
def load_unsharded_model(
self, model: ModelWrapper, checkpoint: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
model = model.unwrap()
checkpoint = utils.load_state_dict(checkpoint)
model.load_state_dict(checkpoint)
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path):
def load_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
if checkpoint.endswith(".safetensors"):
checkpoint = load_flat(checkpoint, seperator=".")
@ -232,6 +236,8 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load model to checkpoint but only on master process.
@ -354,7 +360,14 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
f"index located at {save_index_file}."
)
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
def load_sharded_optimizer(
self,
optimizer: Optimizer,
index_file_path: str,
size_per_shard: int,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load optimizer to checkpoint but only on master process.
"""

View File

@ -85,7 +85,12 @@ class CheckpointIO(ABC):
self._sync_io()
def load_model(
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True
self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
strict: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
) -> Union[nn.Module, ModelWrapper]:
"""
Load model from checkpoint.
@ -100,6 +105,8 @@ class CheckpointIO(ABC):
Distributed tensors cannot be loaded directly unless gathered offline via our CLI.
strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's.
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
"""
# since we only support loaded sharded and unsharded weight format
# containing no distributed tensors, dtensor -> full tensor conversion
@ -111,17 +118,25 @@ class CheckpointIO(ABC):
origin_model = model
if index_file_exists:
self.load_sharded_model(model, index_file_path, strict)
self.load_sharded_model(
model, index_file_path, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
else:
path = Path(checkpoint, SAFE_WEIGHTS_NAME)
if path.is_file():
self.load_unsharded_model(model, str(path), strict)
self.load_unsharded_model(
model, str(path), strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
else:
path = Path(checkpoint, WEIGHTS_NAME)
if path.is_file():
self.load_unsharded_model(model, str(path), strict)
self.load_unsharded_model(
model, str(path), strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
else:
self.load_unsharded_model(model, checkpoint, strict)
self.load_unsharded_model(
model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
return origin_model
@ -178,7 +193,14 @@ class CheckpointIO(ABC):
else:
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
def load_optimizer(
self,
optimizer: Optimizer,
checkpoint: str,
prefix: str = None,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load optimizer from checkpoint.
@ -187,7 +209,8 @@ class CheckpointIO(ABC):
checkpoint (str): checkpoint path. This value is made compatibility with the model checkpoints in the
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
low_cpu_mem_mode (bool): whether to load the optimizer in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
num_threads (int): number of threads to use when loading the optimizer. Only useful when disabling low cpu mem mode. Default: 1.
"""
index_file_exists, index_file_path = has_index_file(checkpoint)
@ -198,9 +221,13 @@ class CheckpointIO(ABC):
if index_file_exists:
# the existence of index file means it is a sharded checkpoint
self.load_sharded_optimizer(optimizer, index_file_path, prefix)
self.load_sharded_optimizer(
optimizer, index_file_path, prefix, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
else:
self.load_unsharded_optimizer(optimizer, checkpoint)
self.load_unsharded_optimizer(
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
def save_optimizer(
self,
@ -238,7 +265,9 @@ class CheckpointIO(ABC):
# Abstract methods for model loading/saving implementation
# ========================================================
@abstractmethod
def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: bool):
def load_sharded_model(
self, model: nn.Module, index_file_path: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
"""
Load model from sharded checkpoint.
@ -247,10 +276,14 @@ class CheckpointIO(ABC):
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's.
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
"""
@abstractmethod
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
def load_unsharded_model(
self, model: nn.Module, checkpoint: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
"""
Load model from unsharded checkpoint.
@ -259,6 +292,8 @@ class CheckpointIO(ABC):
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's.
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
"""
@abstractmethod
@ -303,7 +338,14 @@ class CheckpointIO(ABC):
# ========================================================
@abstractmethod
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
def load_sharded_optimizer(
self,
optimizer: Optimizer,
index_file_path: str,
prefix: str,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load optimizer from sharded checkpoint.
@ -311,16 +353,22 @@ class CheckpointIO(ABC):
optimizer (Optimizer): optimizer to be loaded.
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
prefix (str): prefix for the optimizer checkpoint.
low_cpu_mem_mode (bool): whether to load the optimizer in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
num_threads (int): number of threads to use when loading the optimizer. Only useful when disabling low cpu mem mode. Default: 1.
"""
@abstractmethod
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
def load_unsharded_optimizer(
self, optimizer: Optimizer, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
"""
Load optimizer from unsharded checkpoint.
Args:
optimizer (Optimizer): optimizer to be loaded.
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
low_cpu_mem_mode (bool): whether to load the optimizer in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
num_threads (int): number of threads to use when loading the optimizer. Only useful when disabling low cpu mem mode. Default: 1.
"""
@abstractmethod
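
Because every backend now shares these keyword arguments, the same knobs are available when a CheckpointIO object is used directly rather than through the Booster. A brief hedged sketch (the model, optimizer, and paths are placeholders; the routing between sharded and unsharded checkpoints happens in load_model/load_optimizer above):

from colossalai.checkpoint_io import GeneralCheckpointIO

ckpt_io = GeneralCheckpointIO()
# works for a directory containing a *.index.json file (sharded) or a single weight file (unsharded)
ckpt_io.load_model(model, "ckpt/model", strict=True, low_cpu_mem_mode=False, num_threads=4)
ckpt_io.load_optimizer(optimizer, "ckpt/optimizer", low_cpu_mem_mode=False, num_threads=4)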

View File

@ -1,4 +1,3 @@
import gc
import logging
import os
from functools import reduce
@ -40,8 +39,17 @@ class GeneralCheckpointIO(CheckpointIO):
Checkpoint IO
"""
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
def load_unsharded_model(
self,
model: nn.Module,
checkpoint: str,
strict: bool,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
checkpoint = load_state_dict(checkpoint)
if not low_cpu_mem_mode:
checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads)
model.load_state_dict(checkpoint, strict=strict)
def save_unsharded_model(
@ -60,7 +68,14 @@ class GeneralCheckpointIO(CheckpointIO):
# save the checkpoint
save_state_dict(state_dict, checkpoint, use_safetensors)
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
def load_sharded_optimizer(
self,
optimizer: Optimizer,
index_file_path: str,
prefix: str,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load sharded optimizer with the given path to index file.
"""
@ -84,6 +99,8 @@ class GeneralCheckpointIO(CheckpointIO):
state_dict = load_flat(shard_file)
else:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
load_states_into_optimizer(optimizer, state_dict, id_map)
sharded_optimizer_loading_epilogue(optimizer)
@ -158,11 +175,15 @@ class GeneralCheckpointIO(CheckpointIO):
f"index located at {save_index_file}."
)
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
def load_unsharded_optimizer(
self, optimizer: Optimizer, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
if checkpoint.endswith(".safetensors"):
checkpoint = load_flat(checkpoint)
else:
checkpoint = load_state_dict(checkpoint)
if not low_cpu_mem_mode:
checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads)
optimizer.load_state_dict(checkpoint)
def save_unsharded_optimizer(
@ -256,6 +277,8 @@ class GeneralCheckpointIO(CheckpointIO):
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
load shard model, load model from multiple files
@ -274,9 +297,9 @@ class GeneralCheckpointIO(CheckpointIO):
for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module)
del state_dict
gc.collect()
if strict:
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))

View File

@ -355,7 +355,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
f"index located at {final_index_file_path}."
)
def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False):
def load_sharded_model(
self,
model: ModelWrapper,
checkpoint_index_file: Path,
strict: bool = False,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load sharded model with the given path to index file of checkpoint folder.
@ -403,6 +410,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
load_state_dict_into_model(
model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
@ -632,7 +641,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
f"index located at {final_index_file_path}."
)
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
def load_sharded_optimizer(
self,
optimizer: OptimizerWrapper,
checkpoint_index_file: str,
prefix: str = "",
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load sharded optimizer with the given path to index file of checkpoint folder.
@ -706,6 +722,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_dict = load_flat(file_path)
else:
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
loaded_file.add(filename)
@ -789,7 +807,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
else:
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False):
def load_unsharded_model(
self,
model: ModelWrapper,
checkpoint: str,
strict: bool = False,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load model from a single file with the given path of checkpoint.
@ -812,6 +837,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# has been implemented by _load_from_state_dict method of ParallelModule in Shardformer,
# model.load_state_dict can be directly called.
state_dict = load_state_dict(checkpoint)
if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
model.load_state_dict(state_dict, strict=strict)
# Update master params if mixed-precision training is enabled.
@ -912,7 +939,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
def load_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
"""
Load optimizer from a file with given path.
@ -940,6 +969,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_dict = load_flat(checkpoint)
else:
state_dict = load_state_dict(checkpoint)
if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
# Load param_groups.
updated_groups = []

View File

@ -510,7 +510,14 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
f"index located at {final_index_file_path}."
)
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
def load_sharded_optimizer(
self,
optimizer: OptimizerWrapper,
checkpoint_index_file: str,
prefix: str = "",
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load sharded optimizer with the given path to index file of checkpoint folder.
@ -795,7 +802,14 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
dist.barrier()
# Copied from colossalai.moe
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False):
def load_unsharded_optimizer(
self,
optimizer: OptimizerWrapper,
checkpoint: str,
strict: bool = False,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load optimizer from a file with given path.

View File

@ -1,18 +1,20 @@
# coding=utf-8
import concurrent.futures
import os
import re
from collections import abc as container_abcs
from collections import defaultdict
from itertools import chain
from pathlib import Path
from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple
from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union
import torch
import torch.nn as nn
from packaging.version import Version
from torch.optim import Optimizer
from torch.utils._pytree import tree_map
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
from colossalai.accelerator import get_accelerator
from colossalai.tensor.d_tensor import (
is_customized_distributed_tensor,
is_distributed_tensor,
@ -791,7 +793,7 @@ def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: d
if key != "step":
if param.is_floating_point():
value = value.to(param.dtype)
value = value.to(param.device)
value = value.to(param.device, non_blocking=True)
return value
elif isinstance(value, dict):
return {k: cast(param, v, key=k) for k, v in value.items()}
@ -811,6 +813,7 @@ def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: d
elif not strict:
new_states[k] = v
get_accelerator().synchronize()
optimizer.state.update(new_states)
@ -945,8 +948,27 @@ def get_shard_filename(weights_name: str, idx: int):
return shard_file
def create_pinned_state_dict(state_dict: Dict[str, torch.Tensor]):
pin_mem = dict()
for name, tensor in state_dict.items():
pin_mem[name] = torch.empty_like(tensor, pin_memory=True, device="cpu")
return pin_mem
def _pin_tensor(tensor: torch.Tensor, empty: bool = True) -> torch.Tensor:
if empty:
return torch.empty_like(tensor, pin_memory=True, device="cpu")
return tensor.pin_memory()
def create_pinned_state_dict(
state_dict: Union[Dict[str, torch.Tensor], Dict[int, Dict[str, torch.Tensor]]],
empty: bool = True,
num_threads: int = 1,
) -> Dict[str, torch.Tensor]:
if num_threads == 1:
return tree_map(lambda x: _pin_tensor(x, empty=empty) if isinstance(x, torch.Tensor) else x, state_dict)
else:
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
elems, spec = tree_flatten(state_dict)
future_to_idx = {}
for i, elem in enumerate(elems):
if isinstance(elem, torch.Tensor):
future_to_idx[executor.submit(_pin_tensor, elem, empty)] = i
for future in concurrent.futures.as_completed(future_to_idx):
idx = future_to_idx[future]
elems[idx] = future.result()
return tree_unflatten(elems, spec)
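
The rewritten create_pinned_state_dict is the workhorse of this commit: tree_flatten lets it pin arbitrarily nested state dicts (including optimizer state keyed by parameter index), empty=False copies the values into page-locked memory, and num_threads > 1 pins tensors in parallel. Pinned sources are what make the non_blocking=True copies and the single synchronize in load_states_into_optimizer above effective. A self-contained sketch of that combination (CUDA and arbitrary tensor names assumed):

import torch
from colossalai.checkpoint_io.utils import create_pinned_state_dict

state_dict = {"weight": torch.rand(1024, 1024), "bias": torch.rand(1024)}

# empty=False copies the values; empty=True only allocates pinned buffers with the same shapes
pinned = create_pinned_state_dict(state_dict, empty=False, num_threads=4)
assert all(t.is_pinned() for t in pinned.values())

# pinned sources make host-to-device copies truly asynchronous
on_device = {k: v.to("cuda", non_blocking=True) for k, v in pinned.items()}
torch.cuda.synchronize()  # wait once for all queued copies before using the tensors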

View File

@ -90,8 +90,16 @@ def exam_state_dict_with_origin(
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
@parameterize("use_async", [False, True])
@parameterize("low_cpu_mem_mode", [True, False])
def exam_state_dict(
placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int, use_async: bool
placement_config,
shard: bool,
model_name: str,
size_per_shard: int,
tp_size: int,
zero_size: int,
use_async: bool,
low_cpu_mem_mode: bool,
):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
@ -147,12 +155,12 @@ def exam_state_dict(
booster.checkpoint_io._sync_io()
dist.barrier()
booster.load_model(new_model, model_ckpt_path)
booster.load_model(new_model, model_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
check_state_dict_equal(
model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), ignore_dtype=True
)
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
booster.load_optimizer(new_optimizer, optimizer_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False))
for group in new_optimizer.param_groups:
assert group["lr"] == 0.1

View File

@ -43,8 +43,11 @@ else:
@parameterize("size_per_shard", [32])
@parameterize("test_config", TEST_CONFIGS)
@parameterize("use_async", [False, True])
@parameterize("low_cpu_mem_mode", [False, True])
@clear_cache_before_run()
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool):
def exam_state_dict(
shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool, low_cpu_mem_mode: bool
):
(model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
iter(model_zoo.get_sub_registry(model_name).values())
)
@ -102,9 +105,9 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
booster.load_model(new_model, model_ckpt_path)
booster.load_model(new_model, model_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict())
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
booster.load_optimizer(new_optimizer, optimizer_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
check_state_dict_equal(optimizer.unwrap().state_dict(), new_optimizer.unwrap().state_dict())
dist.barrier()

View File

@ -29,7 +29,8 @@ from tests.kit.model_zoo import model_zoo
@parameterize("shard", [False, True])
@parameterize("offload", [False, True])
@parameterize("use_async", [False, True])
def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, use_async: bool):
@parameterize("low_cpu_mem_mode", [False, True])
def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, use_async: bool, low_cpu_mem_mode: bool):
plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload)
booster = Booster(plugin=plugin)
model = resnet18()
@ -70,7 +71,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us
new_optimizer = HybridAdam((new_model.parameters()), lr=0.001)
new_model, new_optimizer, _, _, _ = booster.boost(new_model, new_optimizer)
booster.load_model(new_model, model_ckpt_path)
booster.load_model(new_model, model_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
check_state_dict_equal(model.state_dict(), new_model.state_dict())
# check master weight
assert isinstance(new_optimizer, LowLevelZeroOptimizer)
@ -85,7 +86,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us
working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device)
)
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
booster.load_optimizer(new_optimizer, optimizer_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict())
torch.cuda.empty_cache()

View File

@ -1,108 +1,144 @@
import tempfile
import pytest
import torch
from safetensors.torch import load_file
from colossalai.checkpoint_io.utils import create_pinned_state_dict
from colossalai.testing import check_state_dict_equal, clear_cache_before_run
from colossalai.utils import get_current_device
from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested
def gen_optim_state_dict():
return {
"state": {
0: {
"step": torch.tensor(1.0),
"exp_avg": torch.rand((1024, 1024)),
"exp_avg_sq": torch.rand((1024, 1024)),
},
1: {
"step": torch.tensor(1.0),
"exp_avg": torch.rand((1024, 1024)),
"exp_avg_sq": torch.rand((1024, 1024)),
},
2: {
"step": torch.tensor(1.0),
"exp_avg": torch.rand((1024, 1024)),
"exp_avg_sq": torch.rand((1024, 1024)),
},
},
"param_groups": [
{
"lr": 0.001,
"betas": (0.9, 0.999),
"eps": 1e-08,
"weight_decay": 0,
"bias_correction": True,
"params": [
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30,
31,
32,
33,
34,
35,
36,
37,
38,
39,
40,
41,
42,
43,
44,
45,
46,
47,
48,
49,
50,
51,
52,
53,
54,
55,
56,
57,
58,
59,
60,
61,
],
}
],
}
def gen_model_state_dict():
return {
"module.weight0": torch.rand((1024, 1024)),
"module.weight1": torch.rand((1024, 1024)),
"module.weight2": torch.rand((1024, 1024)),
}
@pytest.mark.parametrize("empty", [True, False])
@pytest.mark.parametrize("num_threads", [1, 4])
def test_create_pin(empty: bool, num_threads: int):
model_state_dict = gen_model_state_dict()
model_state_dict_pinned = create_pinned_state_dict(model_state_dict, empty=empty, num_threads=num_threads)
for k in model_state_dict.keys():
assert model_state_dict_pinned[k].is_pinned()
if not empty:
assert torch.equal(model_state_dict_pinned[k], model_state_dict[k])
optim_state_dict = gen_optim_state_dict()
optim_state_dict_pinned = create_pinned_state_dict(optim_state_dict, empty=empty, num_threads=num_threads)
for k in optim_state_dict.keys():
if k == "state":
for idx in optim_state_dict[k].keys():
for kk in optim_state_dict[k][idx].keys():
assert optim_state_dict_pinned[k][idx][kk].is_pinned()
if not empty:
assert torch.equal(optim_state_dict_pinned[k][idx][kk], optim_state_dict[k][idx][kk])
else:
assert optim_state_dict[k] == optim_state_dict_pinned[k]
@clear_cache_before_run()
def test_save_load():
with tempfile.TemporaryDirectory() as tempdir:
optimizer_state_dict = {
"state": {
0: {
"step": torch.tensor(1.0),
"exp_avg": torch.rand((1024, 1024)),
"exp_avg_sq": torch.rand((1024, 1024)),
},
1: {
"step": torch.tensor(1.0),
"exp_avg": torch.rand((1024, 1024)),
"exp_avg_sq": torch.rand((1024, 1024)),
},
2: {
"step": torch.tensor(1.0),
"exp_avg": torch.rand((1024, 1024)),
"exp_avg_sq": torch.rand((1024, 1024)),
},
},
"param_groups": [
{
"lr": 0.001,
"betas": (0.9, 0.999),
"eps": 1e-08,
"weight_decay": 0,
"bias_correction": True,
"params": [
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30,
31,
32,
33,
34,
35,
36,
37,
38,
39,
40,
41,
42,
43,
44,
45,
46,
47,
48,
49,
50,
51,
52,
53,
54,
55,
56,
57,
58,
59,
60,
61,
],
}
],
}
optimizer_state_dict = gen_optim_state_dict()
optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
f_writer = save_nested(optimizer_saved_path, optimizer_state_dict)
@ -120,11 +156,7 @@ def test_save_load():
load_state_dict_shard = load_flat(optimizer_shard_saved_path)
check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"])
model_state_dict = {
"module.weight0": torch.rand((1024, 1024)),
"module.weight1": torch.rand((1024, 1024)),
"module.weight2": torch.rand((1024, 1024)),
}
model_state_dict = gen_model_state_dict()
model_saved_path = f"{tempdir}/save_model.safetensors"
f_writer = save(model_saved_path, model_state_dict)
f_writer.sync_before_step()

View File

@ -15,7 +15,8 @@ from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_ad
@parameterize("shard", [False, True])
@parameterize("size_per_shard", [16, 128])
@parameterize("use_async", [False, True])
def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int, use_async: bool):
@parameterize("low_cpu_mem_mode", [False, True])
def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int, use_async: bool, low_cpu_mem_mode: bool):
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model = resnet18()
@ -61,10 +62,10 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int, use_async: bo
new_model, new_optimizer, lr_scheduler=new_scheduler
)
booster.load_model(new_model, model_ckpt_path)
booster.load_model(new_model, model_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
check_state_dict_equal(model.state_dict(), new_model.state_dict())
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
booster.load_optimizer(new_optimizer, optimizer_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict())