Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-19 20:23:41 +00:00)

Commit af06d162cf (parent 836992438f)

[checkpointio] support non blocking pin load (#6172)

* [checkpointio] support non blocking pin load
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  (for more information, see https://pre-commit.ci)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
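
The diff below threads two new keyword arguments, low_cpu_mem_mode and num_threads, through every model/optimizer loading path, and moves optimizer-state copies to non-blocking transfers from pinned host memory. As a rough usage sketch (the booster, model, optimizer, and checkpoint paths here are illustrative and not part of this commit):

    # Assumed setup: `booster`, `model`, and `optimizer` were created and boosted beforehand.
    # Disabling low_cpu_mem_mode loads checkpoints into pinned RAM buffers (pinned by 4
    # worker threads) so host-to-device copies can be issued with non_blocking=True.
    booster.load_model(model, "./ckpt/model", low_cpu_mem_mode=False, num_threads=4)
    booster.load_optimizer(optimizer, "./ckpt/optimizer", low_cpu_mem_mode=False, num_threads=4)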
@@ -288,7 +288,14 @@ class Booster:
         return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config)

-    def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
+    def load_model(
+        self,
+        model: Union[nn.Module, ModelWrapper],
+        checkpoint: str,
+        strict: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ) -> None:
         """Load model from checkpoint.

         Args:
@@ -298,8 +305,12 @@ class Booster:
             strict (bool, optional): whether to strictly enforce that the keys
                 in :attr:`state_dict` match the keys returned by this module's
                 :meth:`~torch.nn.Module.state_dict` function. Defaults to True.
+            low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
+            num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
         """
-        self.checkpoint_io.load_model(model, checkpoint, strict)
+        self.checkpoint_io.load_model(
+            model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+        )

     def save_model(
         self,
@@ -338,18 +349,25 @@ class Booster:
             use_async=use_async,
         )

-    def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
+    def load_optimizer(
+        self,
+        optimizer: Optimizer,
+        checkpoint: str,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ) -> None:
         """Load optimizer from checkpoint.

         Args:
             optimizer (Optimizer): An optimizer boosted by Booster.
             checkpoint (str): Path to the checkpoint. It must be a local path.
                 It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
-            prefix (str, optional): A prefix added to parameter and buffer
-                names to compose the keys in state_dict. Defaults to None.
-            size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
+            low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
+            num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
         """
-        self.checkpoint_io.load_optimizer(optimizer, checkpoint)
+        self.checkpoint_io.load_optimizer(
+            optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+        )

     def save_optimizer(
         self,
@@ -1,4 +1,3 @@
-import gc
 import os
 import random
 from pathlib import Path
@@ -97,13 +96,22 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         else:
             save_state_dict(state_dict, checkpoint, use_safetensors)

-    def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
+    def load_unsharded_model(
+        self,
+        model: GeminiDDP,
+        checkpoint: str,
+        strict: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load model from checkpoint with automatic unwrapping.
         The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
         """
         assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
-        super().load_unsharded_model(model, checkpoint, strict=strict)
+        super().load_unsharded_model(
+            model, checkpoint, strict=strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+        )

     def save_unsharded_optimizer(
         self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool, use_async: bool = False
@@ -131,13 +139,17 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         else:
             save_state_dict(state_dict, checkpoint, use_safetensors=False)

-    def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
+    def load_unsharded_optimizer(
+        self, optimizer: GeminiOptimizer, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         """
         Loading unsharded optimizer from checkpoint file.
         For each process, only loading optimizer states of parameters it controls.
         """
         assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
-        super().load_unsharded_optimizer(optimizer, checkpoint)
+        super().load_unsharded_optimizer(
+            optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+        )

     def save_sharded_model(
         self,
@@ -206,13 +218,27 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         )

     def load_sharded_model(
-        self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
+        self,
+        model: GeminiDDP,
+        checkpoint_index_file: Path,
+        strict: bool = False,
+        use_safetensors: bool = False,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
     ):
         """
         Load shard model, load model from multiple files.
         """
         assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
-        return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
+        return super().load_sharded_model(
+            model,
+            checkpoint_index_file,
+            strict,
+            use_safetensors,
+            load_sub_module=False,
+            low_cpu_mem_mode=low_cpu_mem_mode,
+            num_threads=num_threads,
+        )

     def save_sharded_optimizer(
         self,
@@ -289,7 +315,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
             ranks=[0],
         )

-    def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
+    def load_sharded_optimizer(
+        self,
+        optimizer: GeminiOptimizer,
+        checkpoint_index_file: Path,
+        prefix: str,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Loading sharded optimizer from checkpoint folder, with index file given.
         For each process, only loading optimizer states of parameters it controls.
@@ -322,9 +355,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
                 state_dict_shard = load_flat(shard_file)
             else:
                 state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            if not low_cpu_mem_mode:
+                state_dict_shard = create_pinned_state_dict(state_dict_shard, empty=False, num_threads=num_threads)
             optimizer.load_param_states(state_dict_shard)
-            del state_dict_shard
-            gc.collect()

         optimizer.optimizer_loading_epilogue()
@@ -20,6 +20,7 @@ from torch.utils.data import DataLoader
 from colossalai.accelerator import get_accelerator
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
 from colossalai.checkpoint_io.utils import (
+    create_pinned_state_dict,
     get_optimizer_base_filenames,
     get_shard_filename,
     load_param_groups_into_optimizer,
@@ -145,7 +146,9 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         else:
             save_state_dict(state_dict, checkpoint, use_safetensors=False)

-    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
+    def load_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         use_async = checkpoint.endswith(".safetensors")
         if use_async:
             from colossalai.utils.safetensors import load_flat
@@ -153,6 +156,8 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             checkpoint = load_flat(checkpoint)
         else:
             checkpoint = load_state_dict(checkpoint)
+        if not low_cpu_mem_mode:
+            checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads)
         optimizer.load_state_dict(checkpoint)

     def save_sharded_optimizer(
@@ -239,7 +244,14 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             ranks=[0],
         )

-    def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
+    def load_sharded_optimizer(
+        self,
+        optimizer: OptimizerWrapper,
+        index_file_path: str,
+        prefix: str,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """Load sharded optimizer with the given path to index file.

         Args:
@@ -283,14 +295,28 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
                     if padding_size > 0:
                         v = torch.nn.functional.pad(v, [0, padding_size])
                     v_list = v.split(v.numel() // self.coordinator.world_size)
-                    state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
+                    state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()
+                    if low_cpu_mem_mode:
+                        state_dict[param_idx][k] = state_dict[param_idx][k].clone()
+
+        if not low_cpu_mem_mode:
+            state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
         load_states_into_optimizer(optimizer, state_dict, id_map)
         sharded_optimizer_loading_epilogue(optimizer)

-    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
+    def load_unsharded_model(
+        self,
+        model: ModelWrapper,
+        checkpoint: str,
+        strict: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
         model._force_wait_all_gather()
-        super().load_unsharded_model(model, checkpoint, strict)
+        super().load_unsharded_model(
+            model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+        )
         model.update_master_params()

     def load_sharded_model(
@@ -300,10 +326,20 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         strict: bool = False,
         use_safetensors: bool = False,
         load_sub_module: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
     ):
         assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
         model._force_wait_all_gather()
-        super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
+        super().load_sharded_model(
+            model,
+            checkpoint_index_file,
+            strict,
+            use_safetensors,
+            load_sub_module,
+            low_cpu_mem_mode=low_cpu_mem_mode,
+            num_threads=num_threads,
+        )
         model.update_master_params()

     def save_unsharded_model(
@@ -26,12 +26,21 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         self.coordinator = DistCoordinator()
         self.logger = get_dist_logger()

-    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
+    def load_unsharded_model(
+        self,
+        model: ModelWrapper,
+        checkpoint: str,
+        strict: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load model from checkpoint.
         """
         assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
-        super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)
+        super().load_unsharded_model(
+            model.unwrap(), checkpoint, strict=strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+        )

     def save_unsharded_model(
         self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
@@ -45,12 +54,16 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
             model.unwrap(), checkpoint, gather_dtensor, use_safetensors, use_async=use_async
         )

-    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
+    def load_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         """
         Load optimizer from checkpoint.
         """
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
-        super().load_unsharded_optimizer(optimizer, checkpoint)
+        super().load_unsharded_optimizer(
+            optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+        )

     def save_unsharded_optimizer(
         self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
@@ -101,12 +114,22 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         strict: bool = False,
         use_safetensors: bool = False,
         load_sub_module: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
     ):
         """
         Load model from sharded checkpoint.
         """
         assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
-        super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module)
+        super().load_sharded_model(
+            model.unwrap(),
+            checkpoint_index_file,
+            strict,
+            use_safetensors,
+            load_sub_module,
+            low_cpu_mem_mode=low_cpu_mem_mode,
+            num_threads=num_threads,
+        )

     def save_sharded_optimizer(
         self,
@@ -131,12 +154,16 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         optimizer: Optimizer,
         index_file_path: str,
         prefix: Optional[str] = None,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
     ):
         """
         Load optimizer from sharded checkpoint.
         """
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
-        super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
+        super().load_sharded_optimizer(
+            optimizer.unwrap(), index_file_path, prefix, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+        )

     def save_lora_as_pretrained(
         self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
@@ -43,13 +43,17 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         self.coordinator = DistCoordinator()
         self.logger = get_dist_logger()

-    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool):
+    def load_unsharded_model(
+        self, model: ModelWrapper, checkpoint: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
         model = model.unwrap()
         checkpoint = utils.load_state_dict(checkpoint)
         model.load_state_dict(checkpoint)

-    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path):
+    def load_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
         if checkpoint.endswith(".safetensors"):
             checkpoint = load_flat(checkpoint, seperator=".")
@@ -232,6 +236,8 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         strict: bool = False,
         use_safetensors: bool = False,
         load_sub_module: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
     ):
         """
         Load model to checkpoint but only on master process.
@@ -354,7 +360,14 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
             f"index located at {save_index_file}."
         )

-    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
+    def load_sharded_optimizer(
+        self,
+        optimizer: Optimizer,
+        index_file_path: str,
+        size_per_shard: int,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load optimizer to checkpoint but only on master process.
         """
@@ -85,7 +85,12 @@ class CheckpointIO(ABC):
         self._sync_io()

     def load_model(
-        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True
+        self,
+        model: Union[nn.Module, ModelWrapper],
+        checkpoint: str,
+        strict: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
     ) -> Union[nn.Module, ModelWrapper]:
         """
         Load model from checkpoint.
@@ -100,6 +105,8 @@ class CheckpointIO(ABC):
                 Distributed tensors cannot be loaded directly unless gathered offline via our CLI.
             strict (bool): whether to strictly enforce that the param name in
                 the checkpoint match the keys returned by this module's.
+            low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
+            num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
         """
         # since we only support loaded sharded and unsharded weight format
         # containing no distributed tensors, dtensor -> full tensor conversion
@@ -111,17 +118,25 @@ class CheckpointIO(ABC):
         origin_model = model

         if index_file_exists:
-            self.load_sharded_model(model, index_file_path, strict)
+            self.load_sharded_model(
+                model, index_file_path, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+            )
         else:
             path = Path(checkpoint, SAFE_WEIGHTS_NAME)
             if path.is_file():
-                self.load_unsharded_model(model, str(path), strict)
+                self.load_unsharded_model(
+                    model, str(path), strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+                )
             else:
                 path = Path(checkpoint, WEIGHTS_NAME)
                 if path.is_file():
-                    self.load_unsharded_model(model, str(path), strict)
+                    self.load_unsharded_model(
+                        model, str(path), strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+                    )
                 else:
-                    self.load_unsharded_model(model, checkpoint, strict)
+                    self.load_unsharded_model(
+                        model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+                    )

         return origin_model

@@ -178,7 +193,14 @@ class CheckpointIO(ABC):
         else:
             self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)

-    def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
+    def load_optimizer(
+        self,
+        optimizer: Optimizer,
+        checkpoint: str,
+        prefix: str = None,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load optimizer from checkpoint.

@@ -187,7 +209,8 @@ class CheckpointIO(ABC):
             checkpoint (str): checkpoint path. This value is made compatibility with the model checkpoints in the
             prefix (str, optional): A prefix added to parameter and buffer
                 names to compose the keys in state_dict. Defaults to None.
-            size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
+            low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
+            num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
         """

         index_file_exists, index_file_path = has_index_file(checkpoint)
@@ -198,9 +221,13 @@ class CheckpointIO(ABC):

         if index_file_exists:
             # the existence of index file means it is a sharded checkpoint
-            self.load_sharded_optimizer(optimizer, index_file_path, prefix)
+            self.load_sharded_optimizer(
+                optimizer, index_file_path, prefix, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+            )
         else:
-            self.load_unsharded_optimizer(optimizer, checkpoint)
+            self.load_unsharded_optimizer(
+                optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+            )

     def save_optimizer(
         self,
@@ -238,7 +265,9 @@ class CheckpointIO(ABC):
     # Abstract methods for model loading/saving implementation
     # ========================================================
     @abstractmethod
-    def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: bool):
+    def load_sharded_model(
+        self, model: nn.Module, index_file_path: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         """
         Load model from sharded checkpoint.

@@ -247,10 +276,14 @@ class CheckpointIO(ABC):
             index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
             strict (bool): whether to strictly enforce that the param name in
                 the checkpoint match the keys returned by this module's.
+            low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
+            num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
         """

     @abstractmethod
-    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
+    def load_unsharded_model(
+        self, model: nn.Module, checkpoint: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         """
         Load model from unsharded checkpoint.

@@ -259,6 +292,8 @@ class CheckpointIO(ABC):
             checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
             strict (bool): whether to strictly enforce that the param name in
                 the checkpoint match the keys returned by this module's.
+            low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
+            num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
         """

     @abstractmethod
@@ -303,7 +338,14 @@ class CheckpointIO(ABC):
     # ========================================================

     @abstractmethod
-    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
+    def load_sharded_optimizer(
+        self,
+        optimizer: Optimizer,
+        index_file_path: str,
+        prefix: str,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load optimizer from sharded checkpoint.

@@ -311,16 +353,22 @@ class CheckpointIO(ABC):
             optimizer (Optimizer): optimizer to be loaded.
             index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
             prefix (str): prefix for the optimizer checkpoint.
+            low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
+            num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
         """

     @abstractmethod
-    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+    def load_unsharded_optimizer(
+        self, optimizer: Optimizer, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         """
         Load optimizer from unsharded checkpoint.

         Args:
             optimizer (Optimizer): optimizer to be loaded.
             checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
+            low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
+            num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
         """

     @abstractmethod
@@ -1,4 +1,3 @@
-import gc
 import logging
 import os
 from functools import reduce
@@ -40,8 +39,17 @@ class GeneralCheckpointIO(CheckpointIO):
     Checkpoint IO
     """

-    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
+    def load_unsharded_model(
+        self,
+        model: nn.Module,
+        checkpoint: str,
+        strict: bool,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         checkpoint = load_state_dict(checkpoint)
+        if not low_cpu_mem_mode:
+            checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads)
         model.load_state_dict(checkpoint, strict=strict)

     def save_unsharded_model(
@@ -60,7 +68,14 @@ class GeneralCheckpointIO(CheckpointIO):
         # save the checkpoint
         save_state_dict(state_dict, checkpoint, use_safetensors)

-    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
+    def load_sharded_optimizer(
+        self,
+        optimizer: Optimizer,
+        index_file_path: str,
+        prefix: str,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load sharded optimizer with the given path to index file.
         """
@@ -84,6 +99,8 @@ class GeneralCheckpointIO(CheckpointIO):
                 state_dict = load_flat(shard_file)
             else:
                 state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            if not low_cpu_mem_mode:
+                state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
             load_states_into_optimizer(optimizer, state_dict, id_map)

         sharded_optimizer_loading_epilogue(optimizer)
@@ -158,11 +175,15 @@ class GeneralCheckpointIO(CheckpointIO):
             f"index located at {save_index_file}."
         )

-    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+    def load_unsharded_optimizer(
+        self, optimizer: Optimizer, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         if checkpoint.endswith(".safetensors"):
             checkpoint = load_flat(checkpoint)
         else:
             checkpoint = load_state_dict(checkpoint)
+        if not low_cpu_mem_mode:
+            checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads)
         optimizer.load_state_dict(checkpoint)

     def save_unsharded_optimizer(
@@ -256,6 +277,8 @@ class GeneralCheckpointIO(CheckpointIO):
         strict: bool = False,
         use_safetensors: bool = False,
         load_sub_module: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
     ):
         """
         load shard model, load model from multiple files
@@ -274,9 +297,9 @@ class GeneralCheckpointIO(CheckpointIO):

         for shard_file in checkpoint_files:
             state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
+            if not low_cpu_mem_mode:
+                state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
             load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module)
-            del state_dict
-            gc.collect()

         if strict:
             remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
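
The pattern added throughout the CheckpointIO classes above is the same everywhere: read the state dict on the CPU, optionally copy it into pinned host memory, then load it into the model or optimizer. A simplified, self-contained sketch of that flow (torch.load and the dict comprehension stand in for ColossalAI's load_state_dict and create_pinned_state_dict helpers; pinning page-locked memory needs an accelerator-enabled build):

    import torch


    def load_unsharded(model: torch.nn.Module, path: str, low_cpu_mem_mode: bool = True) -> None:
        # Read the checkpoint into ordinary (pageable) CPU memory first.
        state_dict = torch.load(path, map_location="cpu")
        if not low_cpu_mem_mode:
            # Copy every tensor into page-locked memory so later device transfers
            # can be issued with non_blocking=True.
            state_dict = {k: v.pin_memory() for k, v in state_dict.items()}
        model.load_state_dict(state_dict)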
@@ -355,7 +355,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             f"index located at {final_index_file_path}."
         )

-    def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False):
+    def load_sharded_model(
+        self,
+        model: ModelWrapper,
+        checkpoint_index_file: Path,
+        strict: bool = False,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load sharded model with the given path to index file of checkpoint folder.

@@ -403,6 +410,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):

             file_path = os.path.join(ckpt_root_path, filename)
             state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
+            if not low_cpu_mem_mode:
+                state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)

             load_state_dict_into_model(
                 model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
@@ -632,7 +641,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             f"index located at {final_index_file_path}."
         )

-    def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
+    def load_sharded_optimizer(
+        self,
+        optimizer: OptimizerWrapper,
+        checkpoint_index_file: str,
+        prefix: str = "",
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load sharded optimizer with the given path to index file of checkpoint folder.

@@ -706,6 +722,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                     state_dict = load_flat(file_path)
                 else:
                     state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
+                if not low_cpu_mem_mode:
+                    state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
                 load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
                 loaded_file.add(filename)

@@ -789,7 +807,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         else:
             save_state_dict(complete_state_dict, checkpoint, use_safetensors)

-    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False):
+    def load_unsharded_model(
+        self,
+        model: ModelWrapper,
+        checkpoint: str,
+        strict: bool = False,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load model from a single file with the given path of checkpoint.

@@ -812,6 +837,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # has been implemented by _load_from_state_dict method of ParallelModule in Shardformer,
         # model.load_state_dict can be directly called.
         state_dict = load_state_dict(checkpoint)
+        if not low_cpu_mem_mode:
+            state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
         model.load_state_dict(state_dict, strict=strict)

         # Update master params if mixed-precision training is enabled.
@@ -912,7 +939,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         else:
             save_state_dict(state_dict, checkpoint, use_safetensors=False)

-    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
+    def load_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         """
         Load optimizer from a file with given path.

@@ -940,6 +969,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             state_dict = load_flat(checkpoint)
         else:
             state_dict = load_state_dict(checkpoint)
+        if not low_cpu_mem_mode:
+            state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)

         # Load param_groups.
         updated_groups = []
@@ -510,7 +510,14 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
             f"index located at {final_index_file_path}."
         )

-    def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
+    def load_sharded_optimizer(
+        self,
+        optimizer: OptimizerWrapper,
+        checkpoint_index_file: str,
+        prefix: str = "",
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load sharded optimizer with the given path to index file of checkpoint folder.

@@ -795,7 +802,14 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
         dist.barrier()

     # Copied from colossalai.moe
-    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False):
+    def load_unsharded_optimizer(
+        self,
+        optimizer: OptimizerWrapper,
+        checkpoint: str,
+        strict: bool = False,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load optimizer from a file with given path.
@@ -1,18 +1,20 @@
 # coding=utf-8
+import concurrent.futures
 import os
 import re
 from collections import abc as container_abcs
 from collections import defaultdict
 from itertools import chain
 from pathlib import Path
-from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple
+from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union

 import torch
 import torch.nn as nn
 from packaging.version import Version
 from torch.optim import Optimizer
-from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

+from colossalai.accelerator import get_accelerator
 from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
     is_distributed_tensor,
@@ -791,7 +793,7 @@ def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: d
             if key != "step":
                 if param.is_floating_point():
                     value = value.to(param.dtype)
-                value = value.to(param.device)
+                value = value.to(param.device, non_blocking=True)
             return value
         elif isinstance(value, dict):
             return {k: cast(param, v, key=k) for k, v in value.items()}
@@ -811,6 +813,7 @@ def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: d
         elif not strict:
             new_states[k] = v

+    get_accelerator().synchronize()
     optimizer.state.update(new_states)
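
The two edits above are what make the "non blocking" part of the commit work: .to(device, non_blocking=True) only overlaps the transfer when the source tensor sits in pinned host memory, and the accelerator has to be synchronized before the freshly copied states are read, hence the added get_accelerator().synchronize(). A generic PyTorch sketch of the same pattern (illustrative only, not taken from this diff):

    import torch

    if torch.cuda.is_available():
        src = torch.rand(1024, 1024).pin_memory()  # page-locked CPU buffer
        dst = src.to("cuda", non_blocking=True)    # async H2D copy that can overlap with CPU work
        torch.cuda.synchronize()                   # wait for the copy before reading dst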
@@ -945,8 +948,27 @@ def get_shard_filename(weights_name: str, idx: int):
     return shard_file


-def create_pinned_state_dict(state_dict: Dict[str, torch.Tensor]):
-    pin_mem = dict()
-    for name, tensor in state_dict.items():
-        pin_mem[name] = torch.empty_like(tensor, pin_memory=True, device="cpu")
-    return pin_mem
+def _pin_tensor(tensor: torch.Tensor, empty: bool = True) -> torch.Tensor:
+    if empty:
+        return torch.empty_like(tensor, pin_memory=True, device="cpu")
+    return tensor.pin_memory()
+
+
+def create_pinned_state_dict(
+    state_dict: Union[Dict[str, torch.Tensor], Dict[int, Dict[str, torch.Tensor]]],
+    empty: bool = True,
+    num_threads: int = 1,
+) -> Dict[str, torch.Tensor]:
+    if num_threads == 1:
+        return tree_map(lambda x: _pin_tensor(x, empty=empty) if isinstance(x, torch.Tensor) else x, state_dict)
+    else:
+        with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
+            elems, spec = tree_flatten(state_dict)
+            future_to_idx = {}
+            for i, elem in enumerate(elems):
+                if isinstance(elem, torch.Tensor):
+                    future_to_idx[executor.submit(_pin_tensor, elem, empty)] = i
+            for future in concurrent.futures.as_completed(future_to_idx):
+                idx = future_to_idx[future]
+                elems[idx] = future.result()
+            return tree_unflatten(elems, spec)
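
A quick sketch of how the reworked helper is meant to be used (the shapes and thread count are illustrative, and the CUDA check is an assumption since pinning page-locked memory requires an accelerator-enabled build):

    import torch

    from colossalai.checkpoint_io.utils import create_pinned_state_dict

    if torch.cuda.is_available():
        state_dict = {"weight": torch.rand(1024, 1024), "bias": torch.rand(1024)}
        # empty=False copies the existing values into pinned buffers; empty=True only
        # allocates pinned buffers of matching shape. num_threads parallelizes the pinning.
        pinned = create_pinned_state_dict(state_dict, empty=False, num_threads=4)
        assert all(t.is_pinned() for t in pinned.values())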
@@ -90,8 +90,16 @@ def exam_state_dict_with_origin(
 @parameterize("tp_size", [1, 2])
 @parameterize("zero_size", [2])
 @parameterize("use_async", [False, True])
+@parameterize("low_cpu_mem_mode", [True, False])
 def exam_state_dict(
-    placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int, use_async: bool
+    placement_config,
+    shard: bool,
+    model_name: str,
+    size_per_shard: int,
+    tp_size: int,
+    zero_size: int,
+    use_async: bool,
+    low_cpu_mem_mode: bool,
 ):
     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     criterion = lambda x: x.mean()
@@ -147,12 +155,12 @@ def exam_state_dict(
         booster.checkpoint_io._sync_io()
     dist.barrier()

-    booster.load_model(new_model, model_ckpt_path)
+    booster.load_model(new_model, model_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
     check_state_dict_equal(
         model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), ignore_dtype=True
     )

-    booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+    booster.load_optimizer(new_optimizer, optimizer_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
     check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False))
     for group in new_optimizer.param_groups:
         assert group["lr"] == 0.1
@@ -43,8 +43,11 @@ else:
 @parameterize("size_per_shard", [32])
 @parameterize("test_config", TEST_CONFIGS)
 @parameterize("use_async", [False, True])
+@parameterize("low_cpu_mem_mode", [False, True])
 @clear_cache_before_run()
-def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool):
+def exam_state_dict(
+    shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool, low_cpu_mem_mode: bool
+):
     (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
         iter(model_zoo.get_sub_registry(model_name).values())
     )
@@ -102,9 +105,9 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
     new_optimizer = Adam(new_model.parameters(), lr=1e-3)
     new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

-    booster.load_model(new_model, model_ckpt_path)
+    booster.load_model(new_model, model_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
     check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict())
-    booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+    booster.load_optimizer(new_optimizer, optimizer_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
     check_state_dict_equal(optimizer.unwrap().state_dict(), new_optimizer.unwrap().state_dict())
     dist.barrier()
@ -29,7 +29,8 @@ from tests.kit.model_zoo import model_zoo
|
|||||||
@parameterize("shard", [False, True])
|
@parameterize("shard", [False, True])
|
||||||
@parameterize("offload", [False, True])
|
@parameterize("offload", [False, True])
|
||||||
@parameterize("use_async", [False, True])
|
@parameterize("use_async", [False, True])
|
||||||
def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, use_async: bool):
|
@parameterize("low_cpu_mem_mode", [False, True])
|
||||||
|
def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, use_async: bool, low_cpu_mem_mode: bool):
|
||||||
plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload)
|
plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload)
|
||||||
booster = Booster(plugin=plugin)
|
booster = Booster(plugin=plugin)
|
||||||
model = resnet18()
|
model = resnet18()
|
||||||
@@ -70,7 +71,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us
     new_optimizer = HybridAdam((new_model.parameters()), lr=0.001)
     new_model, new_optimizer, _, _, _ = booster.boost(new_model, new_optimizer)

-    booster.load_model(new_model, model_ckpt_path)
+    booster.load_model(new_model, model_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
     check_state_dict_equal(model.state_dict(), new_model.state_dict())
     # check master weight
     assert isinstance(new_optimizer, LowLevelZeroOptimizer)
@@ -85,7 +86,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us
             working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device)
         )

-    booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+    booster.load_optimizer(new_optimizer, optimizer_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
     check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict())

     torch.cuda.empty_cache()
@@ -1,108 +1,144 @@
 import tempfile

+import pytest
 import torch
 from safetensors.torch import load_file

+from colossalai.checkpoint_io.utils import create_pinned_state_dict
 from colossalai.testing import check_state_dict_equal, clear_cache_before_run
 from colossalai.utils import get_current_device
 from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested


+def gen_optim_state_dict():
+    return {
+        "state": {
+            0: {
+                "step": torch.tensor(1.0),
+                "exp_avg": torch.rand((1024, 1024)),
+                "exp_avg_sq": torch.rand((1024, 1024)),
+            },
+            1: {
+                "step": torch.tensor(1.0),
+                "exp_avg": torch.rand((1024, 1024)),
+                "exp_avg_sq": torch.rand((1024, 1024)),
+            },
+            2: {
+                "step": torch.tensor(1.0),
+                "exp_avg": torch.rand((1024, 1024)),
+                "exp_avg_sq": torch.rand((1024, 1024)),
+            },
+        },
+        "param_groups": [
+            {
+                "lr": 0.001,
+                "betas": (0.9, 0.999),
+                "eps": 1e-08,
+                "weight_decay": 0,
+                "bias_correction": True,
+                "params": [
+                    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+                    16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
+                    32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
+                    48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
+                ],
+            }
+        ],
+    }
+
+
+def gen_model_state_dict():
+    return {
+        "module.weight0": torch.rand((1024, 1024)),
+        "module.weight1": torch.rand((1024, 1024)),
+        "module.weight2": torch.rand((1024, 1024)),
+    }
+
+
+@pytest.mark.parametrize("empty", [True, False])
+@pytest.mark.parametrize("num_threads", [1, 4])
+def test_create_pin(empty: bool, num_threads: int):
+    model_state_dict = gen_model_state_dict()
+    model_state_dict_pinned = create_pinned_state_dict(model_state_dict, empty=empty, num_threads=num_threads)
+    for k in model_state_dict.keys():
+        assert model_state_dict_pinned[k].is_pinned()
+        if not empty:
+            assert torch.equal(model_state_dict_pinned[k], model_state_dict[k])
+    optim_state_dict = gen_optim_state_dict()
+    optim_state_dict_pinned = create_pinned_state_dict(optim_state_dict, empty=empty, num_threads=num_threads)
+    for k in optim_state_dict.keys():
+        if k == "state":
+            for idx in optim_state_dict[k].keys():
+                for kk in optim_state_dict[k][idx].keys():
+                    assert optim_state_dict_pinned[k][idx][kk].is_pinned()
+                    if not empty:
+                        assert torch.equal(optim_state_dict_pinned[k][idx][kk], optim_state_dict[k][idx][kk])
+        else:
+            assert optim_state_dict[k] == optim_state_dict_pinned[k]
+
+
 @clear_cache_before_run()
 def test_save_load():
     with tempfile.TemporaryDirectory() as tempdir:
-        optimizer_state_dict = {
-            "state": {
-                0: {
-                    "step": torch.tensor(1.0),
-                    "exp_avg": torch.rand((1024, 1024)),
-                    "exp_avg_sq": torch.rand((1024, 1024)),
-                },
-                1: {
-                    "step": torch.tensor(1.0),
-                    "exp_avg": torch.rand((1024, 1024)),
-                    "exp_avg_sq": torch.rand((1024, 1024)),
-                },
-                2: {
-                    "step": torch.tensor(1.0),
-                    "exp_avg": torch.rand((1024, 1024)),
-                    "exp_avg_sq": torch.rand((1024, 1024)),
-                },
-            },
-            "param_groups": [
-                {
-                    "lr": 0.001,
-                    "betas": (0.9, 0.999),
-                    "eps": 1e-08,
-                    "weight_decay": 0,
-                    "bias_correction": True,
-                    "params": [
-                        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
-                        16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
-                        32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
-                        48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
-                    ],
-                }
-            ],
-        }
+        optimizer_state_dict = gen_optim_state_dict()

         optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
         f_writer = save_nested(optimizer_saved_path, optimizer_state_dict)
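The new test_create_pin above asserts that create_pinned_state_dict returns a structure-preserving copy of a (possibly nested) state dict whose tensors live in page-locked (pinned) host memory, optionally left uninitialised (empty=True) and optionally filled by several worker threads (num_threads). As a rough sketch of that behaviour, and not the actual colossalai.checkpoint_io.utils implementation (the multi-threaded copy is omitted), a recursive helper could look like this:

import torch


def pinned_state_dict_sketch(obj, empty: bool = True):
    # Mirror the (possibly nested) dict structure, replacing tensors with pinned-memory copies.
    if isinstance(obj, dict):
        return {k: pinned_state_dict_sketch(v, empty) for k, v in obj.items()}
    if isinstance(obj, torch.Tensor):
        if empty:
            # Reserve page-locked host memory of the right shape/dtype without copying data.
            return torch.empty(obj.shape, dtype=obj.dtype, device="cpu", pin_memory=True)
        # Copy the data into pinned host memory so later host-to-device transfers can be non-blocking.
        return obj.detach().to("cpu").pin_memory()
    # Non-tensor leaves (e.g. param_groups entries) are returned unchanged.
    return obj

The empty=True variant is presumably what a reusable RAM cache would rely on: pinned buffers are allocated once and then refilled from each newly read checkpoint shard.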
@@ -120,11 +156,7 @@ def test_save_load():
         load_state_dict_shard = load_flat(optimizer_shard_saved_path)
         check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"])

-        model_state_dict = {
-            "module.weight0": torch.rand((1024, 1024)),
-            "module.weight1": torch.rand((1024, 1024)),
-            "module.weight2": torch.rand((1024, 1024)),
-        }
+        model_state_dict = gen_model_state_dict()
         model_saved_path = f"{tempdir}/save_model.safetensors"
         f_writer = save(model_saved_path, model_state_dict)
         f_writer.sync_before_step()
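The hunk above only swaps the inline dicts for the new gen_* helpers; the save/load pattern itself is unchanged. Condensed, and assuming tempdir and the imports shown earlier in this diff, it looks like:

# Sketch of the flat round trip exercised by test_save_load (paths and values as in the test).
model_state_dict = gen_model_state_dict()
model_saved_path = f"{tempdir}/save_model.safetensors"
f_writer = save(model_saved_path, model_state_dict)  # asynchronous safetensors write
f_writer.sync_before_step()                          # wait for the pending write, as the test does
# The test then reads the data back (safetensors load_file, or load_flat for flattened
# optimizer shards) and compares against the original with check_state_dict_equal.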
@@ -15,7 +15,8 @@ from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_ad
 @parameterize("shard", [False, True])
 @parameterize("size_per_shard", [16, 128])
 @parameterize("use_async", [False, True])
-def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int, use_async: bool):
+@parameterize("low_cpu_mem_mode", [False, True])
+def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int, use_async: bool, low_cpu_mem_mode: bool):
     plugin = TorchDDPPlugin()
     booster = Booster(plugin=plugin)
     model = resnet18()
@@ -61,10 +62,10 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int, use_async: bo
         new_model, new_optimizer, lr_scheduler=new_scheduler
     )

-    booster.load_model(new_model, model_ckpt_path)
+    booster.load_model(new_model, model_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
     check_state_dict_equal(model.state_dict(), new_model.state_dict())

-    booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+    booster.load_optimizer(new_optimizer, optimizer_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
     check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
     booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
     check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict())