mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-19 12:12:46 +00:00
[checkpointio] support non blocking pin load (#6172)
* [checkpointio] support non blocking pin load * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
836992438f
commit
af06d162cf
@ -288,7 +288,14 @@ class Booster:
|
||||
|
||||
return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config)
|
||||
|
||||
def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
|
||||
def load_model(
|
||||
self,
|
||||
model: Union[nn.Module, ModelWrapper],
|
||||
checkpoint: str,
|
||||
strict: bool = True,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
) -> None:
|
||||
"""Load model from checkpoint.
|
||||
|
||||
Args:
|
||||
@ -298,8 +305,12 @@ class Booster:
|
||||
strict (bool, optional): whether to strictly enforce that the keys
|
||||
in :attr:`state_dict` match the keys returned by this module's
|
||||
:meth:`~torch.nn.Module.state_dict` function. Defaults to True.
|
||||
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
|
||||
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
|
||||
"""
|
||||
self.checkpoint_io.load_model(model, checkpoint, strict)
|
||||
self.checkpoint_io.load_model(
|
||||
model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
|
||||
def save_model(
|
||||
self,
|
||||
@ -338,18 +349,25 @@ class Booster:
|
||||
use_async=use_async,
|
||||
)
|
||||
|
||||
def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
|
||||
def load_optimizer(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
checkpoint: str,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
) -> None:
|
||||
"""Load optimizer from checkpoint.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): An optimizer boosted by Booster.
|
||||
checkpoint (str): Path to the checkpoint. It must be a local path.
|
||||
It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
|
||||
prefix (str, optional): A prefix added to parameter and buffer
|
||||
names to compose the keys in state_dict. Defaults to None.
|
||||
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
|
||||
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
|
||||
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
|
||||
"""
|
||||
self.checkpoint_io.load_optimizer(optimizer, checkpoint)
|
||||
self.checkpoint_io.load_optimizer(
|
||||
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
|
||||
def save_optimizer(
|
||||
self,
|
||||
|
@ -1,4 +1,3 @@
|
||||
import gc
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
@ -97,13 +96,22 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
||||
else:
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors)
|
||||
|
||||
def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
|
||||
def load_unsharded_model(
|
||||
self,
|
||||
model: GeminiDDP,
|
||||
checkpoint: str,
|
||||
strict: bool = True,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load model from checkpoint with automatic unwrapping.
|
||||
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
|
||||
"""
|
||||
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
|
||||
super().load_unsharded_model(model, checkpoint, strict=strict)
|
||||
super().load_unsharded_model(
|
||||
model, checkpoint, strict=strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
|
||||
def save_unsharded_optimizer(
|
||||
self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool, use_async: bool = False
|
||||
@ -131,13 +139,17 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
||||
else:
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
|
||||
def load_unsharded_optimizer(
|
||||
self, optimizer: GeminiOptimizer, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
|
||||
):
|
||||
"""
|
||||
Loading unsharded optimizer from checkpoint file.
|
||||
For each process, only loading optimizer states of parameters it controls.
|
||||
"""
|
||||
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
|
||||
super().load_unsharded_optimizer(optimizer, checkpoint)
|
||||
super().load_unsharded_optimizer(
|
||||
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
|
||||
def save_sharded_model(
|
||||
self,
|
||||
@ -206,13 +218,27 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
||||
)
|
||||
|
||||
def load_sharded_model(
|
||||
self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
|
||||
self,
|
||||
model: GeminiDDP,
|
||||
checkpoint_index_file: Path,
|
||||
strict: bool = False,
|
||||
use_safetensors: bool = False,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load shard model, load model from multiple files.
|
||||
"""
|
||||
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
|
||||
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
|
||||
return super().load_sharded_model(
|
||||
model,
|
||||
checkpoint_index_file,
|
||||
strict,
|
||||
use_safetensors,
|
||||
load_sub_module=False,
|
||||
low_cpu_mem_mode=low_cpu_mem_mode,
|
||||
num_threads=num_threads,
|
||||
)
|
||||
|
||||
def save_sharded_optimizer(
|
||||
self,
|
||||
@ -289,7 +315,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
||||
ranks=[0],
|
||||
)
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
|
||||
def load_sharded_optimizer(
|
||||
self,
|
||||
optimizer: GeminiOptimizer,
|
||||
checkpoint_index_file: Path,
|
||||
prefix: str,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Loading sharded optimizer from checkpoint folder, with index file given.
|
||||
For each process, only loading optimizer states of parameters it controls.
|
||||
@ -322,9 +355,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
||||
state_dict_shard = load_flat(shard_file)
|
||||
else:
|
||||
state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
|
||||
if not low_cpu_mem_mode:
|
||||
state_dict_shard = create_pinned_state_dict(state_dict_shard, empty=False, num_threads=num_threads)
|
||||
optimizer.load_param_states(state_dict_shard)
|
||||
del state_dict_shard
|
||||
gc.collect()
|
||||
|
||||
optimizer.optimizer_loading_epilogue()
|
||||
|
||||
|
@ -20,6 +20,7 @@ from torch.utils.data import DataLoader
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
|
||||
from colossalai.checkpoint_io.utils import (
|
||||
create_pinned_state_dict,
|
||||
get_optimizer_base_filenames,
|
||||
get_shard_filename,
|
||||
load_param_groups_into_optimizer,
|
||||
@ -145,7 +146,9 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
else:
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
|
||||
def load_unsharded_optimizer(
|
||||
self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
|
||||
):
|
||||
use_async = checkpoint.endswith(".safetensors")
|
||||
if use_async:
|
||||
from colossalai.utils.safetensors import load_flat
|
||||
@ -153,6 +156,8 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
checkpoint = load_flat(checkpoint)
|
||||
else:
|
||||
checkpoint = load_state_dict(checkpoint)
|
||||
if not low_cpu_mem_mode:
|
||||
checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads)
|
||||
optimizer.load_state_dict(checkpoint)
|
||||
|
||||
def save_sharded_optimizer(
|
||||
@ -239,7 +244,14 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
ranks=[0],
|
||||
)
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
|
||||
def load_sharded_optimizer(
|
||||
self,
|
||||
optimizer: OptimizerWrapper,
|
||||
index_file_path: str,
|
||||
prefix: str,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""Load sharded optimizer with the given path to index file.
|
||||
|
||||
Args:
|
||||
@ -283,14 +295,28 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
if padding_size > 0:
|
||||
v = torch.nn.functional.pad(v, [0, padding_size])
|
||||
v_list = v.split(v.numel() // self.coordinator.world_size)
|
||||
state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
|
||||
state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()
|
||||
if low_cpu_mem_mode:
|
||||
state_dict[param_idx][k] = state_dict[param_idx][k].clone()
|
||||
|
||||
if not low_cpu_mem_mode:
|
||||
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
|
||||
load_states_into_optimizer(optimizer, state_dict, id_map)
|
||||
sharded_optimizer_loading_epilogue(optimizer)
|
||||
|
||||
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
|
||||
def load_unsharded_model(
|
||||
self,
|
||||
model: ModelWrapper,
|
||||
checkpoint: str,
|
||||
strict: bool = True,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
|
||||
model._force_wait_all_gather()
|
||||
super().load_unsharded_model(model, checkpoint, strict)
|
||||
super().load_unsharded_model(
|
||||
model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
model.update_master_params()
|
||||
|
||||
def load_sharded_model(
|
||||
@ -300,10 +326,20 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
strict: bool = False,
|
||||
use_safetensors: bool = False,
|
||||
load_sub_module: bool = True,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
|
||||
model._force_wait_all_gather()
|
||||
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
|
||||
super().load_sharded_model(
|
||||
model,
|
||||
checkpoint_index_file,
|
||||
strict,
|
||||
use_safetensors,
|
||||
load_sub_module,
|
||||
low_cpu_mem_mode=low_cpu_mem_mode,
|
||||
num_threads=num_threads,
|
||||
)
|
||||
model.update_master_params()
|
||||
|
||||
def save_unsharded_model(
|
||||
|
@ -26,12 +26,21 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
|
||||
self.coordinator = DistCoordinator()
|
||||
self.logger = get_dist_logger()
|
||||
|
||||
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
|
||||
def load_unsharded_model(
|
||||
self,
|
||||
model: ModelWrapper,
|
||||
checkpoint: str,
|
||||
strict: bool = True,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load model from checkpoint.
|
||||
"""
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
|
||||
super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)
|
||||
super().load_unsharded_model(
|
||||
model.unwrap(), checkpoint, strict=strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
|
||||
def save_unsharded_model(
|
||||
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
|
||||
@ -45,12 +54,16 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
|
||||
model.unwrap(), checkpoint, gather_dtensor, use_safetensors, use_async=use_async
|
||||
)
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
|
||||
def load_unsharded_optimizer(
|
||||
self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
|
||||
):
|
||||
"""
|
||||
Load optimizer from checkpoint.
|
||||
"""
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
|
||||
super().load_unsharded_optimizer(optimizer, checkpoint)
|
||||
super().load_unsharded_optimizer(
|
||||
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
|
||||
def save_unsharded_optimizer(
|
||||
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
|
||||
@ -101,12 +114,22 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
|
||||
strict: bool = False,
|
||||
use_safetensors: bool = False,
|
||||
load_sub_module: bool = True,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load model from sharded checkpoint.
|
||||
"""
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
|
||||
super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module)
|
||||
super().load_sharded_model(
|
||||
model.unwrap(),
|
||||
checkpoint_index_file,
|
||||
strict,
|
||||
use_safetensors,
|
||||
load_sub_module,
|
||||
low_cpu_mem_mode=low_cpu_mem_mode,
|
||||
num_threads=num_threads,
|
||||
)
|
||||
|
||||
def save_sharded_optimizer(
|
||||
self,
|
||||
@ -131,12 +154,16 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
|
||||
optimizer: Optimizer,
|
||||
index_file_path: str,
|
||||
prefix: Optional[str] = None,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load optimizer from sharded checkpoint.
|
||||
"""
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
|
||||
super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
|
||||
super().load_sharded_optimizer(
|
||||
optimizer.unwrap(), index_file_path, prefix, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
|
||||
def save_lora_as_pretrained(
|
||||
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
|
||||
|
@ -43,13 +43,17 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
||||
self.coordinator = DistCoordinator()
|
||||
self.logger = get_dist_logger()
|
||||
|
||||
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool):
|
||||
def load_unsharded_model(
|
||||
self, model: ModelWrapper, checkpoint: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
|
||||
):
|
||||
assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
|
||||
model = model.unwrap()
|
||||
checkpoint = utils.load_state_dict(checkpoint)
|
||||
model.load_state_dict(checkpoint)
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path):
|
||||
def load_unsharded_optimizer(
|
||||
self, optimizer: OptimizerWrapper, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
|
||||
):
|
||||
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
|
||||
if checkpoint.endswith(".safetensors"):
|
||||
checkpoint = load_flat(checkpoint, seperator=".")
|
||||
@ -232,6 +236,8 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
||||
strict: bool = False,
|
||||
use_safetensors: bool = False,
|
||||
load_sub_module: bool = True,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load model to checkpoint but only on master process.
|
||||
@ -354,7 +360,14 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
|
||||
def load_sharded_optimizer(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
index_file_path: str,
|
||||
size_per_shard: int,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load optimizer to checkpoint but only on master process.
|
||||
"""
|
||||
|
@ -85,7 +85,12 @@ class CheckpointIO(ABC):
|
||||
self._sync_io()
|
||||
|
||||
def load_model(
|
||||
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True
|
||||
self,
|
||||
model: Union[nn.Module, ModelWrapper],
|
||||
checkpoint: str,
|
||||
strict: bool = True,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
) -> Union[nn.Module, ModelWrapper]:
|
||||
"""
|
||||
Load model from checkpoint.
|
||||
@ -100,6 +105,8 @@ class CheckpointIO(ABC):
|
||||
Distributed tensors cannot be loaded directly unless gathered offline via our CLI.
|
||||
strict (bool): whether to strictly enforce that the param name in
|
||||
the checkpoint match the keys returned by this module's.
|
||||
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
|
||||
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
|
||||
"""
|
||||
# since we only support loaded sharded and unsharded weight format
|
||||
# containing no distributed tensors, dtensor -> full tensor conversion
|
||||
@ -111,17 +118,25 @@ class CheckpointIO(ABC):
|
||||
origin_model = model
|
||||
|
||||
if index_file_exists:
|
||||
self.load_sharded_model(model, index_file_path, strict)
|
||||
self.load_sharded_model(
|
||||
model, index_file_path, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
else:
|
||||
path = Path(checkpoint, SAFE_WEIGHTS_NAME)
|
||||
if path.is_file():
|
||||
self.load_unsharded_model(model, str(path), strict)
|
||||
self.load_unsharded_model(
|
||||
model, str(path), strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
else:
|
||||
path = Path(checkpoint, WEIGHTS_NAME)
|
||||
if path.is_file():
|
||||
self.load_unsharded_model(model, str(path), strict)
|
||||
self.load_unsharded_model(
|
||||
model, str(path), strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
else:
|
||||
self.load_unsharded_model(model, checkpoint, strict)
|
||||
self.load_unsharded_model(
|
||||
model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
|
||||
return origin_model
|
||||
|
||||
@ -178,7 +193,14 @@ class CheckpointIO(ABC):
|
||||
else:
|
||||
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
|
||||
|
||||
def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
|
||||
def load_optimizer(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
checkpoint: str,
|
||||
prefix: str = None,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load optimizer from checkpoint.
|
||||
|
||||
@ -187,7 +209,8 @@ class CheckpointIO(ABC):
|
||||
checkpoint (str): checkpoint path. This value is made compatibility with the model checkpoints in the
|
||||
prefix (str, optional): A prefix added to parameter and buffer
|
||||
names to compose the keys in state_dict. Defaults to None.
|
||||
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
|
||||
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
|
||||
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
|
||||
"""
|
||||
|
||||
index_file_exists, index_file_path = has_index_file(checkpoint)
|
||||
@ -198,9 +221,13 @@ class CheckpointIO(ABC):
|
||||
|
||||
if index_file_exists:
|
||||
# the existence of index file means it is a sharded checkpoint
|
||||
self.load_sharded_optimizer(optimizer, index_file_path, prefix)
|
||||
self.load_sharded_optimizer(
|
||||
optimizer, index_file_path, prefix, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
else:
|
||||
self.load_unsharded_optimizer(optimizer, checkpoint)
|
||||
self.load_unsharded_optimizer(
|
||||
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
|
||||
def save_optimizer(
|
||||
self,
|
||||
@ -238,7 +265,9 @@ class CheckpointIO(ABC):
|
||||
# Abstract methods for model loading/saving implementation
|
||||
# ========================================================
|
||||
@abstractmethod
|
||||
def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: bool):
|
||||
def load_sharded_model(
|
||||
self, model: nn.Module, index_file_path: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
|
||||
):
|
||||
"""
|
||||
Load model from sharded checkpoint.
|
||||
|
||||
@ -247,10 +276,14 @@ class CheckpointIO(ABC):
|
||||
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
|
||||
strict (bool): whether to strictly enforce that the param name in
|
||||
the checkpoint match the keys returned by this module's.
|
||||
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
|
||||
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
|
||||
def load_unsharded_model(
|
||||
self, model: nn.Module, checkpoint: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
|
||||
):
|
||||
"""
|
||||
Load model from unsharded checkpoint.
|
||||
|
||||
@ -259,6 +292,8 @@ class CheckpointIO(ABC):
|
||||
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
|
||||
strict (bool): whether to strictly enforce that the param name in
|
||||
the checkpoint match the keys returned by this module's.
|
||||
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
|
||||
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@ -303,7 +338,14 @@ class CheckpointIO(ABC):
|
||||
# ========================================================
|
||||
|
||||
@abstractmethod
|
||||
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
|
||||
def load_sharded_optimizer(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
index_file_path: str,
|
||||
prefix: str,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load optimizer from sharded checkpoint.
|
||||
|
||||
@ -311,16 +353,22 @@ class CheckpointIO(ABC):
|
||||
optimizer (Optimizer): optimizer to be loaded.
|
||||
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
|
||||
prefix (str): prefix for the optimizer checkpoint.
|
||||
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
|
||||
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
|
||||
def load_unsharded_optimizer(
|
||||
self, optimizer: Optimizer, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
|
||||
):
|
||||
"""
|
||||
Load optimizer from unsharded checkpoint.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): optimizer to be loaded.
|
||||
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
|
||||
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
|
||||
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
|
@ -1,4 +1,3 @@
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
from functools import reduce
|
||||
@ -40,8 +39,17 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||
Checkpoint IO
|
||||
"""
|
||||
|
||||
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
|
||||
def load_unsharded_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
checkpoint: str,
|
||||
strict: bool,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
checkpoint = load_state_dict(checkpoint)
|
||||
if not low_cpu_mem_mode:
|
||||
checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads)
|
||||
model.load_state_dict(checkpoint, strict=strict)
|
||||
|
||||
def save_unsharded_model(
|
||||
@ -60,7 +68,14 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||
# save the checkpoint
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors)
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
|
||||
def load_sharded_optimizer(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
index_file_path: str,
|
||||
prefix: str,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load sharded optimizer with the given path to index file.
|
||||
"""
|
||||
@ -84,6 +99,8 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||
state_dict = load_flat(shard_file)
|
||||
else:
|
||||
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
|
||||
if not low_cpu_mem_mode:
|
||||
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
|
||||
load_states_into_optimizer(optimizer, state_dict, id_map)
|
||||
|
||||
sharded_optimizer_loading_epilogue(optimizer)
|
||||
@ -158,11 +175,15 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
|
||||
def load_unsharded_optimizer(
|
||||
self, optimizer: Optimizer, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
|
||||
):
|
||||
if checkpoint.endswith(".safetensors"):
|
||||
checkpoint = load_flat(checkpoint)
|
||||
else:
|
||||
checkpoint = load_state_dict(checkpoint)
|
||||
if not low_cpu_mem_mode:
|
||||
checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads)
|
||||
optimizer.load_state_dict(checkpoint)
|
||||
|
||||
def save_unsharded_optimizer(
|
||||
@ -256,6 +277,8 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||
strict: bool = False,
|
||||
use_safetensors: bool = False,
|
||||
load_sub_module: bool = True,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
load shard model, load model from multiple files
|
||||
@ -274,9 +297,9 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||
|
||||
for shard_file in checkpoint_files:
|
||||
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
|
||||
if not low_cpu_mem_mode:
|
||||
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
|
||||
load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module)
|
||||
del state_dict
|
||||
gc.collect()
|
||||
|
||||
if strict:
|
||||
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
|
||||
|
@ -355,7 +355,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
f"index located at {final_index_file_path}."
|
||||
)
|
||||
|
||||
def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False):
|
||||
def load_sharded_model(
|
||||
self,
|
||||
model: ModelWrapper,
|
||||
checkpoint_index_file: Path,
|
||||
strict: bool = False,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load sharded model with the given path to index file of checkpoint folder.
|
||||
|
||||
@ -403,6 +410,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
|
||||
file_path = os.path.join(ckpt_root_path, filename)
|
||||
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
|
||||
if not low_cpu_mem_mode:
|
||||
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
|
||||
|
||||
load_state_dict_into_model(
|
||||
model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
|
||||
@ -632,7 +641,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
f"index located at {final_index_file_path}."
|
||||
)
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
|
||||
def load_sharded_optimizer(
|
||||
self,
|
||||
optimizer: OptimizerWrapper,
|
||||
checkpoint_index_file: str,
|
||||
prefix: str = "",
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load sharded optimizer with the given path to index file of checkpoint folder.
|
||||
|
||||
@ -706,6 +722,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
state_dict = load_flat(file_path)
|
||||
else:
|
||||
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
|
||||
if not low_cpu_mem_mode:
|
||||
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
|
||||
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
|
||||
loaded_file.add(filename)
|
||||
|
||||
@ -789,7 +807,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
else:
|
||||
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
|
||||
|
||||
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False):
|
||||
def load_unsharded_model(
|
||||
self,
|
||||
model: ModelWrapper,
|
||||
checkpoint: str,
|
||||
strict: bool = False,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load model from a single file with the given path of checkpoint.
|
||||
|
||||
@ -812,6 +837,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
# has been implemented by _load_from_state_dict method of ParallelModule in Shardformer,
|
||||
# model.load_state_dict can be directly called.
|
||||
state_dict = load_state_dict(checkpoint)
|
||||
if not low_cpu_mem_mode:
|
||||
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
|
||||
model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
# Update master params if mixed-precision training is enabled.
|
||||
@ -912,7 +939,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
else:
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
|
||||
def load_unsharded_optimizer(
|
||||
self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
|
||||
):
|
||||
"""
|
||||
Load optimizer from a file with given path.
|
||||
|
||||
@ -940,6 +969,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
state_dict = load_flat(checkpoint)
|
||||
else:
|
||||
state_dict = load_state_dict(checkpoint)
|
||||
if not low_cpu_mem_mode:
|
||||
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
|
||||
|
||||
# Load param_groups.
|
||||
updated_groups = []
|
||||
|
@ -510,7 +510,14 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
|
||||
f"index located at {final_index_file_path}."
|
||||
)
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
|
||||
def load_sharded_optimizer(
|
||||
self,
|
||||
optimizer: OptimizerWrapper,
|
||||
checkpoint_index_file: str,
|
||||
prefix: str = "",
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load sharded optimizer with the given path to index file of checkpoint folder.
|
||||
|
||||
@ -795,7 +802,14 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
|
||||
dist.barrier()
|
||||
|
||||
# Copied from colossalai.moe
|
||||
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False):
|
||||
def load_unsharded_optimizer(
|
||||
self,
|
||||
optimizer: OptimizerWrapper,
|
||||
checkpoint: str,
|
||||
strict: bool = False,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load optimizer from a file with given path.
|
||||
|
||||
|
@ -1,18 +1,20 @@
|
||||
# coding=utf-8
|
||||
import concurrent.futures
|
||||
import os
|
||||
import re
|
||||
from collections import abc as container_abcs
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple
|
||||
from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from packaging.version import Version
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils._pytree import tree_map
|
||||
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.tensor.d_tensor import (
|
||||
is_customized_distributed_tensor,
|
||||
is_distributed_tensor,
|
||||
@ -791,7 +793,7 @@ def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: d
|
||||
if key != "step":
|
||||
if param.is_floating_point():
|
||||
value = value.to(param.dtype)
|
||||
value = value.to(param.device)
|
||||
value = value.to(param.device, non_blocking=True)
|
||||
return value
|
||||
elif isinstance(value, dict):
|
||||
return {k: cast(param, v, key=k) for k, v in value.items()}
|
||||
@ -811,6 +813,7 @@ def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: d
|
||||
elif not strict:
|
||||
new_states[k] = v
|
||||
|
||||
get_accelerator().synchronize()
|
||||
optimizer.state.update(new_states)
|
||||
|
||||
|
||||
@ -945,8 +948,27 @@ def get_shard_filename(weights_name: str, idx: int):
|
||||
return shard_file
|
||||
|
||||
|
||||
def create_pinned_state_dict(state_dict: Dict[str, torch.Tensor]):
|
||||
pin_mem = dict()
|
||||
for name, tensor in state_dict.items():
|
||||
pin_mem[name] = torch.empty_like(tensor, pin_memory=True, device="cpu")
|
||||
return pin_mem
|
||||
def _pin_tensor(tensor: torch.Tensor, empty: bool = True) -> torch.Tensor:
|
||||
if empty:
|
||||
return torch.empty_like(tensor, pin_memory=True, device="cpu")
|
||||
return tensor.pin_memory()
|
||||
|
||||
|
||||
def create_pinned_state_dict(
|
||||
state_dict: Union[Dict[str, torch.Tensor], Dict[int, Dict[str, torch.Tensor]]],
|
||||
empty: bool = True,
|
||||
num_threads: int = 1,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
if num_threads == 1:
|
||||
return tree_map(lambda x: _pin_tensor(x, empty=empty) if isinstance(x, torch.Tensor) else x, state_dict)
|
||||
else:
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
elems, spec = tree_flatten(state_dict)
|
||||
future_to_idx = {}
|
||||
for i, elem in enumerate(elems):
|
||||
if isinstance(elem, torch.Tensor):
|
||||
future_to_idx[executor.submit(_pin_tensor, elem, empty)] = i
|
||||
for future in concurrent.futures.as_completed(future_to_idx):
|
||||
idx = future_to_idx[future]
|
||||
elems[idx] = future.result()
|
||||
return tree_unflatten(elems, spec)
|
||||
|
@ -90,8 +90,16 @@ def exam_state_dict_with_origin(
|
||||
@parameterize("tp_size", [1, 2])
|
||||
@parameterize("zero_size", [2])
|
||||
@parameterize("use_async", [False, True])
|
||||
@parameterize("low_cpu_mem_mode", [True, False])
|
||||
def exam_state_dict(
|
||||
placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int, use_async: bool
|
||||
placement_config,
|
||||
shard: bool,
|
||||
model_name: str,
|
||||
size_per_shard: int,
|
||||
tp_size: int,
|
||||
zero_size: int,
|
||||
use_async: bool,
|
||||
low_cpu_mem_mode: bool,
|
||||
):
|
||||
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
|
||||
criterion = lambda x: x.mean()
|
||||
@ -147,12 +155,12 @@ def exam_state_dict(
|
||||
booster.checkpoint_io._sync_io()
|
||||
dist.barrier()
|
||||
|
||||
booster.load_model(new_model, model_ckpt_path)
|
||||
booster.load_model(new_model, model_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
|
||||
check_state_dict_equal(
|
||||
model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), ignore_dtype=True
|
||||
)
|
||||
|
||||
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
|
||||
booster.load_optimizer(new_optimizer, optimizer_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
|
||||
check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False))
|
||||
for group in new_optimizer.param_groups:
|
||||
assert group["lr"] == 0.1
|
||||
|
@ -43,8 +43,11 @@ else:
|
||||
@parameterize("size_per_shard", [32])
|
||||
@parameterize("test_config", TEST_CONFIGS)
|
||||
@parameterize("use_async", [False, True])
|
||||
@parameterize("low_cpu_mem_mode", [False, True])
|
||||
@clear_cache_before_run()
|
||||
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool):
|
||||
def exam_state_dict(
|
||||
shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool, low_cpu_mem_mode: bool
|
||||
):
|
||||
(model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
|
||||
iter(model_zoo.get_sub_registry(model_name).values())
|
||||
)
|
||||
@ -102,9 +105,9 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
|
||||
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
|
||||
new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
|
||||
|
||||
booster.load_model(new_model, model_ckpt_path)
|
||||
booster.load_model(new_model, model_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
|
||||
check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict())
|
||||
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
|
||||
booster.load_optimizer(new_optimizer, optimizer_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
|
||||
check_state_dict_equal(optimizer.unwrap().state_dict(), new_optimizer.unwrap().state_dict())
|
||||
dist.barrier()
|
||||
|
||||
|
@ -29,7 +29,8 @@ from tests.kit.model_zoo import model_zoo
|
||||
@parameterize("shard", [False, True])
|
||||
@parameterize("offload", [False, True])
|
||||
@parameterize("use_async", [False, True])
|
||||
def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, use_async: bool):
|
||||
@parameterize("low_cpu_mem_mode", [False, True])
|
||||
def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, use_async: bool, low_cpu_mem_mode: bool):
|
||||
plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload)
|
||||
booster = Booster(plugin=plugin)
|
||||
model = resnet18()
|
||||
@ -70,7 +71,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us
|
||||
new_optimizer = HybridAdam((new_model.parameters()), lr=0.001)
|
||||
new_model, new_optimizer, _, _, _ = booster.boost(new_model, new_optimizer)
|
||||
|
||||
booster.load_model(new_model, model_ckpt_path)
|
||||
booster.load_model(new_model, model_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
|
||||
check_state_dict_equal(model.state_dict(), new_model.state_dict())
|
||||
# check master weight
|
||||
assert isinstance(new_optimizer, LowLevelZeroOptimizer)
|
||||
@ -85,7 +86,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us
|
||||
working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device)
|
||||
)
|
||||
|
||||
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
|
||||
booster.load_optimizer(new_optimizer, optimizer_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
|
||||
check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict())
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
@ -1,108 +1,144 @@
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from colossalai.checkpoint_io.utils import create_pinned_state_dict
|
||||
from colossalai.testing import check_state_dict_equal, clear_cache_before_run
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested
|
||||
|
||||
|
||||
def gen_optim_state_dict():
|
||||
return {
|
||||
"state": {
|
||||
0: {
|
||||
"step": torch.tensor(1.0),
|
||||
"exp_avg": torch.rand((1024, 1024)),
|
||||
"exp_avg_sq": torch.rand((1024, 1024)),
|
||||
},
|
||||
1: {
|
||||
"step": torch.tensor(1.0),
|
||||
"exp_avg": torch.rand((1024, 1024)),
|
||||
"exp_avg_sq": torch.rand((1024, 1024)),
|
||||
},
|
||||
2: {
|
||||
"step": torch.tensor(1.0),
|
||||
"exp_avg": torch.rand((1024, 1024)),
|
||||
"exp_avg_sq": torch.rand((1024, 1024)),
|
||||
},
|
||||
},
|
||||
"param_groups": [
|
||||
{
|
||||
"lr": 0.001,
|
||||
"betas": (0.9, 0.999),
|
||||
"eps": 1e-08,
|
||||
"weight_decay": 0,
|
||||
"bias_correction": True,
|
||||
"params": [
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
5,
|
||||
6,
|
||||
7,
|
||||
8,
|
||||
9,
|
||||
10,
|
||||
11,
|
||||
12,
|
||||
13,
|
||||
14,
|
||||
15,
|
||||
16,
|
||||
17,
|
||||
18,
|
||||
19,
|
||||
20,
|
||||
21,
|
||||
22,
|
||||
23,
|
||||
24,
|
||||
25,
|
||||
26,
|
||||
27,
|
||||
28,
|
||||
29,
|
||||
30,
|
||||
31,
|
||||
32,
|
||||
33,
|
||||
34,
|
||||
35,
|
||||
36,
|
||||
37,
|
||||
38,
|
||||
39,
|
||||
40,
|
||||
41,
|
||||
42,
|
||||
43,
|
||||
44,
|
||||
45,
|
||||
46,
|
||||
47,
|
||||
48,
|
||||
49,
|
||||
50,
|
||||
51,
|
||||
52,
|
||||
53,
|
||||
54,
|
||||
55,
|
||||
56,
|
||||
57,
|
||||
58,
|
||||
59,
|
||||
60,
|
||||
61,
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def gen_model_state_dict():
|
||||
return {
|
||||
"module.weight0": torch.rand((1024, 1024)),
|
||||
"module.weight1": torch.rand((1024, 1024)),
|
||||
"module.weight2": torch.rand((1024, 1024)),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("empty", [True, False])
|
||||
@pytest.mark.parametrize("num_threads", [1, 4])
|
||||
def test_create_pin(empty: bool, num_threads: int):
|
||||
model_state_dict = gen_model_state_dict()
|
||||
model_state_dict_pinned = create_pinned_state_dict(model_state_dict, empty=empty, num_threads=num_threads)
|
||||
for k in model_state_dict.keys():
|
||||
assert model_state_dict_pinned[k].is_pinned()
|
||||
if not empty:
|
||||
assert torch.equal(model_state_dict_pinned[k], model_state_dict[k])
|
||||
optim_state_dict = gen_optim_state_dict()
|
||||
optim_state_dict_pinned = create_pinned_state_dict(optim_state_dict, empty=empty, num_threads=num_threads)
|
||||
for k in optim_state_dict.keys():
|
||||
if k == "state":
|
||||
for idx in optim_state_dict[k].keys():
|
||||
for kk in optim_state_dict[k][idx].keys():
|
||||
assert optim_state_dict_pinned[k][idx][kk].is_pinned()
|
||||
if not empty:
|
||||
assert torch.equal(optim_state_dict_pinned[k][idx][kk], optim_state_dict[k][idx][kk])
|
||||
else:
|
||||
assert optim_state_dict[k] == optim_state_dict_pinned[k]
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
def test_save_load():
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
optimizer_state_dict = {
|
||||
"state": {
|
||||
0: {
|
||||
"step": torch.tensor(1.0),
|
||||
"exp_avg": torch.rand((1024, 1024)),
|
||||
"exp_avg_sq": torch.rand((1024, 1024)),
|
||||
},
|
||||
1: {
|
||||
"step": torch.tensor(1.0),
|
||||
"exp_avg": torch.rand((1024, 1024)),
|
||||
"exp_avg_sq": torch.rand((1024, 1024)),
|
||||
},
|
||||
2: {
|
||||
"step": torch.tensor(1.0),
|
||||
"exp_avg": torch.rand((1024, 1024)),
|
||||
"exp_avg_sq": torch.rand((1024, 1024)),
|
||||
},
|
||||
},
|
||||
"param_groups": [
|
||||
{
|
||||
"lr": 0.001,
|
||||
"betas": (0.9, 0.999),
|
||||
"eps": 1e-08,
|
||||
"weight_decay": 0,
|
||||
"bias_correction": True,
|
||||
"params": [
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
5,
|
||||
6,
|
||||
7,
|
||||
8,
|
||||
9,
|
||||
10,
|
||||
11,
|
||||
12,
|
||||
13,
|
||||
14,
|
||||
15,
|
||||
16,
|
||||
17,
|
||||
18,
|
||||
19,
|
||||
20,
|
||||
21,
|
||||
22,
|
||||
23,
|
||||
24,
|
||||
25,
|
||||
26,
|
||||
27,
|
||||
28,
|
||||
29,
|
||||
30,
|
||||
31,
|
||||
32,
|
||||
33,
|
||||
34,
|
||||
35,
|
||||
36,
|
||||
37,
|
||||
38,
|
||||
39,
|
||||
40,
|
||||
41,
|
||||
42,
|
||||
43,
|
||||
44,
|
||||
45,
|
||||
46,
|
||||
47,
|
||||
48,
|
||||
49,
|
||||
50,
|
||||
51,
|
||||
52,
|
||||
53,
|
||||
54,
|
||||
55,
|
||||
56,
|
||||
57,
|
||||
58,
|
||||
59,
|
||||
60,
|
||||
61,
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
optimizer_state_dict = gen_optim_state_dict()
|
||||
|
||||
optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
|
||||
f_writer = save_nested(optimizer_saved_path, optimizer_state_dict)
|
||||
@ -120,11 +156,7 @@ def test_save_load():
|
||||
load_state_dict_shard = load_flat(optimizer_shard_saved_path)
|
||||
check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"])
|
||||
|
||||
model_state_dict = {
|
||||
"module.weight0": torch.rand((1024, 1024)),
|
||||
"module.weight1": torch.rand((1024, 1024)),
|
||||
"module.weight2": torch.rand((1024, 1024)),
|
||||
}
|
||||
model_state_dict = gen_model_state_dict()
|
||||
model_saved_path = f"{tempdir}/save_model.safetensors"
|
||||
f_writer = save(model_saved_path, model_state_dict)
|
||||
f_writer.sync_before_step()
|
||||
|
@ -15,7 +15,8 @@ from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_ad
|
||||
@parameterize("shard", [False, True])
|
||||
@parameterize("size_per_shard", [16, 128])
|
||||
@parameterize("use_async", [False, True])
|
||||
def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int, use_async: bool):
|
||||
@parameterize("low_cpu_mem_mode", [False, True])
|
||||
def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int, use_async: bool, low_cpu_mem_mode: bool):
|
||||
plugin = TorchDDPPlugin()
|
||||
booster = Booster(plugin=plugin)
|
||||
model = resnet18()
|
||||
@ -61,10 +62,10 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int, use_async: bo
|
||||
new_model, new_optimizer, lr_scheduler=new_scheduler
|
||||
)
|
||||
|
||||
booster.load_model(new_model, model_ckpt_path)
|
||||
booster.load_model(new_model, model_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
|
||||
check_state_dict_equal(model.state_dict(), new_model.state_dict())
|
||||
|
||||
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
|
||||
booster.load_optimizer(new_optimizer, optimizer_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
|
||||
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
|
||||
booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
|
||||
check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict())
|
||||
|
Loading…
Reference in New Issue
Block a user