[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00
Committed by: GitHub
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -3,4 +3,4 @@ from .general_checkpoint_io import GeneralCheckpointIO
from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
from .index_file import CheckpointIndexFile
__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO']
__all__ = ["CheckpointIO", "CheckpointIndexFile", "GeneralCheckpointIO", "HybridParallelCheckpointIO"]

View File

@@ -11,7 +11,7 @@ from colossalai.interface import ModelWrapper
from .utils import has_index_file
__all__ = ['CheckpointIO']
__all__ = ["CheckpointIO"]
class CheckpointIO(ABC):
@@ -61,10 +61,9 @@ class CheckpointIO(ABC):
# ======================================
# Public methods
# ======================================
def load_model(self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
strict: bool = True) -> Union[nn.Module, ModelWrapper]:
def load_model(
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True
) -> Union[nn.Module, ModelWrapper]:
"""
Load model from checkpoint.
@@ -98,14 +97,16 @@ class CheckpointIO(ABC):
return origin_model
def save_model(self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
shard: bool = False,
gather_dtensor: bool = True,
prefix: str = None,
size_per_shard: int = 1024,
use_safetensors: bool = False):
def save_model(
self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
shard: bool = False,
gather_dtensor: bool = True,
prefix: str = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
):
"""
Save model to checkpoint.
@@ -157,7 +158,7 @@ class CheckpointIO(ABC):
if Path(checkpoint).is_dir() and not index_file_exists:
# if the checkpoint is a directory and there is no index file, raise error
raise ValueError(f'Cannot find index file in {checkpoint}')
raise ValueError(f"Cannot find index file in {checkpoint}")
if index_file_exists:
# the existence of index file means it is a sharded checkpoint
@@ -165,13 +166,15 @@ class CheckpointIO(ABC):
else:
self.load_unsharded_optimizer(optimizer, checkpoint)
def save_optimizer(self,
optimizer: Optimizer,
checkpoint: str,
shard: bool = False,
gather_dtensor=True,
prefix: str = None,
size_per_shard: int = 1024):
def save_optimizer(
self,
optimizer: Optimizer,
checkpoint: str,
shard: bool = False,
gather_dtensor=True,
prefix: str = None,
size_per_shard: int = 1024,
):
"""
Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.
@@ -207,7 +210,6 @@ class CheckpointIO(ABC):
strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's.
"""
pass
@abstractmethod
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
@@ -220,11 +222,17 @@ class CheckpointIO(ABC):
strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's.
"""
pass
@abstractmethod
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
size_per_shard: int, use_safetensors: bool):
def save_sharded_model(
self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool,
prefix: Optional[str],
size_per_shard: int,
use_safetensors: bool,
):
"""
Save model to sharded checkpoint.
@@ -236,7 +244,6 @@ class CheckpointIO(ABC):
size_per_shard (int): size per shard in MB.
use_safetensors (bool): whether to use safe tensors.
"""
pass
@abstractmethod
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
@@ -249,7 +256,6 @@ class CheckpointIO(ABC):
gather_dtensor (bool): whether to gather the distributed tensor to the first device.
use_safetensors (bool): whether to use safe tensors.
"""
pass
# ========================================================
# Abstract methods for optimizer loading/saving implementation
@@ -265,7 +271,6 @@ class CheckpointIO(ABC):
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
prefix (str): prefix for the optimizer checkpoint.
"""
pass
@abstractmethod
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
@@ -276,11 +281,11 @@ class CheckpointIO(ABC):
optimizer (Optimizer): optimizer to be loaded.
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
"""
pass
@abstractmethod
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
size_per_shard: int):
def save_sharded_optimizer(
self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
):
"""
Save optimizer to sharded checkpoint.
@@ -291,7 +296,6 @@ class CheckpointIO(ABC):
prefix (str): prefix for the optimizer checkpoint.
size_per_shard (int): size per shard in MB.
"""
pass
@abstractmethod
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
@@ -303,7 +307,6 @@ class CheckpointIO(ABC):
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
gather_dtensor (bool): whether to gather the distributed tensor to the first device.
"""
pass
# ============================================
# methods for loading and saving lr scheduler
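As an aside for orientation (not part of this commit's diff): the reformatted CheckpointIO interface above is typically driven as in the minimal sketch below. The concrete GeneralCheckpointIO subclass comes from this package, but the toy model, optimizer, and checkpoint paths are hypothetical, and load_optimizer is assumed to mirror the save_optimizer signature shown here.

# Illustrative usage sketch; the toy model, optimizer, and paths are hypothetical.
import torch.nn as nn
from torch.optim import Adam

from colossalai.checkpoint_io import GeneralCheckpointIO

model = nn.Linear(8, 8)
optimizer = Adam(model.parameters(), lr=1e-3)
ckpt_io = GeneralCheckpointIO()

# save_model/save_optimizer dispatch to the sharded or unsharded implementations
# depending on the `shard` flag, following the signatures reformatted above.
ckpt_io.save_model(model, "./model_ckpt", shard=True, size_per_shard=1024, use_safetensors=False)
ckpt_io.save_optimizer(optimizer, "./optim_ckpt", shard=True, size_per_shard=1024)

# Loading detects a sharded checkpoint through the presence of a *.index.json file.
model = ckpt_io.load_model(model, "./model_ckpt", strict=True)
ckpt_io.load_optimizer(optimizer, "./optim_ckpt")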

View File

@@ -3,9 +3,8 @@ import logging
import os
from functools import reduce
from pathlib import Path
from typing import Iterator, Optional, OrderedDict, Tuple
from typing import Optional
import torch.distributed as dist
import torch.nn as nn
from torch.optim import Optimizer
@@ -16,7 +15,6 @@ from .index_file import CheckpointIndexFile
from .utils import (
get_model_base_filenames,
get_optimizer_base_filenames,
get_shard_filename,
is_safetensors_available,
load_param_groups_into_optimizer,
load_shard_state_dict,
@@ -33,7 +31,7 @@ from .utils import (
unwrap_optimizer,
)
__all__ = ['GeneralCheckpointIO']
__all__ = ["GeneralCheckpointIO"]
class GeneralCheckpointIO(CheckpointIO):
@@ -70,8 +68,10 @@ class GeneralCheckpointIO(CheckpointIO):
# Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
Lacking param group file under current directory.')
raise RuntimeError(
f"Invalid index file path {index_file_path} for an optimizer. \
Lacking param group file under current directory."
)
id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
@@ -123,19 +123,23 @@ class GeneralCheckpointIO(CheckpointIO):
# Save shards of optimizer states.
# In general cases, is_master is set to True to get the right behavior.
total_size = save_state_dict_shards(sharded_state_dict=sharded_state,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=True,
use_safetensors=False)
total_size = save_state_dict_shards(
sharded_state_dict=sharded_state,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=True,
use_safetensors=False,
)
# Wrap up index file.
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
logging.info(f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
logging.info(
f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
checkpoint = load_state_dict(checkpoint)
@@ -150,13 +154,15 @@ class GeneralCheckpointIO(CheckpointIO):
# TODO(FrankLeeeee): handle distributed tensors
save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
def save_sharded_model(self,
model: nn.Module,
checkpoint_path: str,
gather_dtensor: bool = False,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False):
def save_sharded_model(
self,
model: nn.Module,
checkpoint_path: str,
gather_dtensor: bool = False,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
):
"""
implement this method as it can be supported by Huggingface model,
save shard model, save model to multiple files
@@ -175,26 +181,32 @@ class GeneralCheckpointIO(CheckpointIO):
# Save shards of model states.
# In general cases, is_master is set to True to get the right behavior.
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=True,
use_safetensors=use_safetensors)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=True,
use_safetensors=use_safetensors,
)
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint_path, is_master=True)
logging.info(f"The model is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
logging.info(
f"The model is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
def load_sharded_model(self,
model: nn.Module,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True):
def load_sharded_model(
self,
model: nn.Module,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
):
"""
load shard model, load model from multiple files
"""
@@ -219,7 +231,11 @@ class GeneralCheckpointIO(CheckpointIO):
if strict:
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
if len(remain_keys) > 0:
error_msgs = 'Missing key(s) in state_dict: {}. '.format(', '.join(
'"{}"'.format(k) for k in missing_keys))
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
self.__class__.__name__, "\n\t".join(error_msgs)))
error_msgs = "Missing key(s) in state_dict: {}. ".format(
", ".join('"{}"'.format(k) for k in missing_keys)
)
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(
self.__class__.__name__, "\n\t".join(error_msgs)
)
)
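A note on the pattern being reformatted throughout this file (illustration only, not part of the commit): each save path builds a generator of (state_dict_block, block_size) shards, hands it to save_state_dict_shards, and finally writes an index file carrying the accumulated total size. The sketch below isolates that flow; the helper names come from the imports shown in this diff, while the toy sharder and checkpoint directory are made up.

# Illustrative sketch of the shard-writing flow; the toy sharder and paths are hypothetical.
import os
from collections import OrderedDict

import torch

from colossalai.checkpoint_io import CheckpointIndexFile
from colossalai.checkpoint_io.utils import get_model_base_filenames, save_state_dict_shards


def toy_sharder():
    # Mirrors what the real sharders yield: (state_dict_block, block_size).
    block = OrderedDict(weight=torch.zeros(4, 4))
    yield block, block["weight"].numel() * block["weight"].element_size()


checkpoint = "./toy_ckpt"
os.makedirs(checkpoint, exist_ok=True)

weights_name, save_index_file = get_model_base_filenames(None, False)  # prefix, use_safetensors
index_file = CheckpointIndexFile(checkpoint)
total_size = save_state_dict_shards(
    sharded_state_dict=toy_sharder(),
    checkpoint=checkpoint,
    index_file=index_file,
    base_filename=weights_name,
    is_master=True,
    use_safetensors=False,
)
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)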

View File

@@ -1,5 +1,4 @@
import copy
import gc
import logging
import os
from pathlib import Path
@@ -35,9 +34,9 @@ from .utils import (
)
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
class HybridParallelCheckpointIO(GeneralCheckpointIO):
@@ -52,12 +51,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
verbose (bool, optional): Whether to print a logging message when saving/loading has been successfully executed. Defaults to True.
"""
def __init__(self,
dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
zero_stage: int,
verbose: bool = True) -> None:
def __init__(
self,
dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
zero_stage: int,
verbose: bool = True,
) -> None:
super().__init__()
self.dp_group = dp_group
self.pp_group = pp_group
@@ -68,17 +69,16 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
self.dp_size = dist.get_world_size(dp_group)
self.pp_size = dist.get_world_size(pp_group)
self.tp_size = dist.get_world_size(tp_group)
self.use_zero = (zero_stage > 0)
self.use_zero = zero_stage > 0
self.verbose = verbose
self.working_to_master_map = None
self.master_to_working_map = None
self.coordinator = DistCoordinator()
@staticmethod
def _model_sharder(model: nn.Module,
prefix: str = '',
keep_vars: bool = False,
size_per_shard: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
def _model_sharder(
model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024
) -> Iterator[Tuple[OrderedDict, int]]:
# An internal method that breaks state_dict of model into shards within limited size.
state_dict_sharder = StateDictSharder(size_per_shard)
@@ -103,8 +103,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Save extra states.
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(model.__class__, "get_extra_state",
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
if (
getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
is not torch.nn.Module.get_extra_state
):
extra_state = model.get_extra_state()
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
if block is not None:
@@ -114,20 +116,20 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
@staticmethod
def _optimizer_sharder(optimizer: OptimizerWrapper,
use_zero: bool,
dp_group: ProcessGroup,
tp_group: ProcessGroup,
master_to_working_map: Optional[Dict[int, torch.Tensor]] = None,
size_per_shard: int = 1024):
def _optimizer_sharder(
optimizer: OptimizerWrapper,
use_zero: bool,
dp_group: ProcessGroup,
tp_group: ProcessGroup,
master_to_working_map: Optional[Dict[int, torch.Tensor]] = None,
size_per_shard: int = 1024,
):
# An internal method that breaks state_dict of optimizer into shards within limited size.
state_dict_sharder = StateDictSharder(size_per_shard)
param_info = optimizer.param_info
for param, state in optimizer.optim.state.items():
if param is None:
continue
@@ -136,15 +138,17 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
else:
working_param = param
param_id = param_info['param2id'][id(working_param)]
original_shape = param_info['param2shape'][id(working_param)]
state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(state,
working_param,
original_shape=original_shape,
dp_group=dp_group,
tp_group=tp_group,
use_zero=use_zero,
inplace=False)
param_id = param_info["param2id"][id(working_param)]
original_shape = param_info["param2shape"][id(working_param)]
state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
state,
working_param,
original_shape=original_shape,
dp_group=dp_group,
tp_group=tp_group,
use_zero=use_zero,
inplace=False,
)
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
if block is not None:
@@ -153,13 +157,15 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
def save_sharded_model(self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False) -> None:
def save_sharded_model(
self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
) -> None:
"""
Save sharded model checkpoint under the given checkpointing path.
The following files will be created under the path:
@@ -194,24 +200,28 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)
control_saving = (self.tp_rank == 0)
control_saving = self.tp_rank == 0
if self.pp_size == 1:
# When pipeline is not used, save the model shards as in general checkpointIO
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose:
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else:
# When pipeline is used, each stage produces its own shard files and index files.
@@ -228,15 +238,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
use_pp_format=True)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
use_pp_format=True,
)
if control_saving:
assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
assert (
self.dp_rank == 0 and self.tp_rank == 0
), "The saving process should have both dp_rank and tp_rank as 0."
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
else:
@@ -259,9 +273,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
save_config_file(model, checkpoint)
rmtree(tmp_index_file_folder)
if self.verbose:
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}.")
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}."
)
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
"""
@@ -305,11 +321,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
missing_keys = []
load_state_dict_into_model(model,
state_dict,
missing_keys=missing_keys,
strict=strict,
load_sub_module=True)
load_state_dict_into_model(
model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
)
loaded_file.add(filename)
# Load parameters.
@@ -319,15 +333,17 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Load buffers.
non_persistent_buffers = set()
for n, m in model.named_modules():
non_persistent_buffers |= set('.'.join((n, b)) for b in m._non_persistent_buffers_set)
non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
for name, buf in model.named_buffers():
if buf is not None and name not in non_persistent_buffers:
_load(name)
# Load extra states.
extra_state_key = _EXTRA_STATE_KEY_SUFFIX
if getattr(model.__class__, "get_extra_state",
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
if (
getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
is not torch.nn.Module.get_extra_state
):
_load(extra_state_key)
# Update master params if mixed-precision training is enabled.
@@ -352,12 +368,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if self.verbose:
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
def save_sharded_optimizer(self,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024):
def save_sharded_optimizer(
self,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
):
"""
Save sharded optimizer checkpoint under the given checkpointing path.
The following files will be created under the path:
@@ -393,18 +411,21 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
dp_group=self.dp_group,
tp_group=self.tp_group,
master_to_working_map=self.master_to_working_map,
size_per_shard=size_per_shard)
size_per_shard=size_per_shard,
)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)
control_saving = (self.dp_rank == 0 and self.tp_rank == 0)
control_saving = self.dp_rank == 0 and self.tp_rank == 0
if self.pp_size == 1:
# When pipeline is not used, save the optimizer shards as in general checkpointIO
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
)
if control_saving:
# Store param groups.
@@ -415,9 +436,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
if self.verbose:
logging.info(f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
logging.info(
f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else:
# When pipeline is used, each stage produces its own shard files and index files.
@@ -433,15 +456,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True,
)
if control_saving:
assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
assert (
self.dp_rank == 0 and self.tp_rank == 0
), "The saving process should have both dp_rank and tp_rank as 0."
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
else:
@@ -451,7 +478,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# The global master rank integrates the index files and cleans the folder.
if self.pp_rank == 0:
final_index_file = CheckpointIndexFile(checkpoint)
final_index_file.append_meta_data("total_size", 0)
@@ -470,9 +496,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
rmtree(tmp_index_file_folder)
if self.verbose:
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}.")
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}."
)
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
"""
@@ -484,20 +512,21 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
prefix (str): Not used.
"""
def _get_param_id_from_optimizer_param(param: torch.Tensor,
master_to_working_map: Optional[Dict[int, torch.Tensor]] = None):
def _get_param_id_from_optimizer_param(
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
):
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
return optimizer.param_info['param2id'][id(working_param)]
return optimizer.param_info["param2id"][id(working_param)]
# id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
# When Zero is used, the mapped parameter objects should be fp32 master parameters.
# IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
id_map = {}
for pg in optimizer.optim.param_groups:
for param in pg['params']:
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
id_map[param_id] = param
@@ -505,28 +534,30 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
# Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \
Lacking param group file under current directory.')
raise RuntimeError(
f"Invalid index file path {checkpoint_index_file} for an optimizer. \
Lacking param group file under current directory."
)
saved_groups = torch.load(param_group_path)
updated_groups = []
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
# obtain updated param group
new_pg = copy.deepcopy(saved_pg)
new_pg['params'] = old_pg['params'] # The parameters in the same group shouldn't change.
new_pg["params"] = old_pg["params"] # The parameters in the same group shouldn't change.
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({'param_groups': updated_groups})
optimizer.optim.__dict__.update({"param_groups": updated_groups})
# Load saved states to optimizer.
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()
for pg in optimizer.optim.param_groups:
for param in pg['params']:
for param in pg["params"]:
if param is None:
continue
param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
@@ -550,12 +581,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
working_param = self.master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info['param2shape'][id(working_param)]
sharded_state = self.shard_from_complete_optimizer_state(state,
current_shape=working_param.shape,
original_shape=original_shape,
device=device,
inplace=True)
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.shard_from_complete_optimizer_state(
state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
)
optimizer.optim.state[param] = sharded_state
sharded_optimizer_loading_epilogue(optimizer.optim)
@@ -585,8 +614,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor]):
def link_master_and_working_param(
self,
working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor],
):
"""
Create mappings between working params (for forward/backward) and master params (for optimizer update) with passed in mappings.
This mapping can only be created when mixed precision is used.
@@ -604,7 +636,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
self.working_to_master_map[k] = v
else:
raise ValueError(
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!")
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!"
)
self.master_to_working_map = dict()
for k, v in master_to_working_map.items():
@@ -614,12 +647,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
self.master_to_working_map[k] = v
else:
raise ValueError(
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!")
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!"
)
@staticmethod
def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, original_shape: torch.Size,
dp_group: ProcessGroup, tp_group: ProcessGroup, use_zero: bool,
inplace: bool) -> OrderedDict:
def gather_from_sharded_optimizer_state(
state: OrderedDict,
param: torch.Tensor,
original_shape: torch.Size,
dp_group: ProcessGroup,
tp_group: ProcessGroup,
use_zero: bool,
inplace: bool,
) -> OrderedDict:
"""
With given parameter and its optimizer states, gather the complete optimizer state for saving.
@@ -641,14 +681,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != 'step':
if isinstance(v, torch.Tensor) and k != "step":
# First gather Zero shards.
if use_zero:
v = v.cuda()
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
dist.all_gather(gather_tensor, v, group=dp_group)
v = torch.stack(gather_tensor).view(-1)[:param.numel()].reshape_as(param)
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
# Then gather TP shards.
partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
@@ -661,9 +700,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
return state_
def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape: torch.Size,
original_shape: torch.Size, device: torch.device,
inplace: bool) -> OrderedDict:
def shard_from_complete_optimizer_state(
self,
state: OrderedDict,
current_shape: torch.Size,
original_shape: torch.Size,
device: torch.device,
inplace: bool,
) -> OrderedDict:
"""
With complete optimizer states of a specific parameter loaded from checkpoint,
slice out the sharded optimizer states kept by current device.
@@ -681,8 +725,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != 'step':
if isinstance(v, torch.Tensor) and k != "step":
# Shard state along tensor parallel group.
partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
if partition_dim is not None:
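One step in gather_from_sharded_optimizer_state above is worth spelling out: under ZeRO, each data-parallel rank holds a flat (and possibly padded) slice of an optimizer state, so after all_gather the slices are stacked, flattened, trimmed to the parameter's element count, and reshaped back to the parameter's shape. The single-process sketch below (not part of the commit) reproduces only that arithmetic, simulating the gathered slices with a plain list.

# Illustrative sketch of the stack/view/trim/reshape step; torch.distributed is not used here.
import torch

param = torch.arange(10.0).reshape(2, 5)      # working parameter with 10 elements
dp_size = 4
flat = param.reshape(-1)
padded = torch.cat([flat, torch.zeros(2)])    # pad to 12 elements so it splits evenly across 4 ranks
gather_tensor = list(padded.chunk(dp_size))   # stands in for the result of dist.all_gather

v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
assert torch.equal(v, param)                  # the full, unpadded state is recovered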

View File

@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Union
from .utils import is_dtensor_checkpoint
__all__ = ['CheckpointIndexFile']
__all__ = ["CheckpointIndexFile"]
class CheckpointIndexFile:
@@ -50,7 +50,7 @@ class CheckpointIndexFile:
json_path (str): path to the json file.
"""
# load the json file
with open(json_path, 'r') as f:
with open(json_path, "r") as f:
index = json.load(f)
# assign attributes if exists
@@ -75,7 +75,7 @@ class CheckpointIndexFile:
index["weight_map"] = self.weight_map
# export the index file
with open(json_path, 'w') as f:
with open(json_path, "w") as f:
json.dump(index, f, indent=4)
def append_weight_map(self, param_name: str, shard_file: str):

View File

@@ -1,5 +1,4 @@
# coding=utf-8
import copy
import os
import re
from collections import abc as container_abcs
@@ -12,7 +11,7 @@ import torch
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.interface import OptimizerWrapper
from colossalai.tensor.d_tensor import (
is_customized_distributed_tensor,
is_distributed_tensor,
@@ -55,7 +54,6 @@ def is_safetensors_available() -> bool:
bool: whether safetensors is available.
"""
try:
import safetensors
return True
except ImportError:
return False
@@ -71,7 +69,7 @@ def is_dtensor_checkpoint(checkpoint_file_path: str) -> bool:
Returns:
bool: whether the checkpoint file is a dtensor checkpoint.
"""
if checkpoint_file_path.endswith('.*.safetensors') or checkpoint_file_path.endswith('.*.bin'):
if checkpoint_file_path.endswith(".*.safetensors") or checkpoint_file_path.endswith(".*.bin"):
return True
else:
return False
@@ -87,7 +85,7 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
Returns:
bool: whether the checkpoint file is a safetensor checkpoint.
"""
if checkpoint_file_path.endswith('.safetensors'):
if checkpoint_file_path.endswith(".safetensors"):
return True
else:
return False
@@ -113,8 +111,9 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
partition_dim = dim
break
if partition_dim is not None:
assert original_shape[partition_dim] == tp_size * current_shape[partition_dim], \
f"The parameter isn't evenly distributed among tensor parallel group: \
assert (
original_shape[partition_dim] == tp_size * current_shape[partition_dim]
), f"The parameter isn't evenly distributed among tensor parallel group: \
shape before sharding {original_shape}, shape after sharding {current_shape}"
return partition_dim
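For the assertion reformatted above, a worked example may help. The function infers which dimension was sliced by tensor parallelism by comparing the sharded and original shapes; the sketch below is an assumed re-implementation for illustration, not the library code.

# Illustrative sketch of the assumed logic behind search_tp_partition_dim.
import torch


def search_tp_partition_dim_sketch(current_shape, original_shape, tp_size):
    partition_dim = None
    for dim, (cur, orig) in enumerate(zip(current_shape, original_shape)):
        if cur != orig:
            partition_dim = dim
            break
    if partition_dim is not None:
        # Matches the assertion above: the original dim must be tp_size times the sharded dim.
        assert original_shape[partition_dim] == tp_size * current_shape[partition_dim]
    return partition_dim


# A (4096, 1024) weight split across tp_size=4 ranks leaves each rank with (1024, 1024),
# so the recovered partition dimension is 0; identical shapes return None.
assert search_tp_partition_dim_sketch(torch.Size([1024, 1024]), torch.Size([4096, 1024]), 4) == 0
assert search_tp_partition_dim_sketch(torch.Size([1024, 1024]), torch.Size([1024, 1024]), 4) is None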
@@ -124,24 +123,22 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
# Helper classes and functions for saving shard file
# ======================================
def unwrap_optimizer(optimizer: OptimizerWrapper):
'''
"""
Unwrap a wrapped optimizer.
This method should be used before saving/loading it to/from sharded checkpoints.
'''
"""
unwrapped_optim = optimizer.optim
return unwrapped_optim
class StateDictSharder:
def __init__(self, size_per_shard: int) -> None:
self.max_shard_size = size_per_shard
self.current_block = OrderedDict()
self.current_block_size = 0
def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
tensor_size = calculate_tensor_size(tensor)
ret_block = None
ret_block_size = 0
@@ -159,13 +156,11 @@ class StateDictSharder:
return ret_block, ret_block_size
def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]:
# A state might contain more than one tensor.
# e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
state_size = 0
isDTensor = False
for state_tensor in state.values():
# When state_tensor is not of Tensor class,
# e.g., a SGD optimizer with momentum set to 0 can have None as state
# The calculation of tensor size should be skipped to avoid error.
@@ -217,14 +212,16 @@ def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> to
return param_
def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
checkpoint: str,
index_file: "CheckpointIndexFile",
base_filename: str,
is_master: bool,
use_safetensors: bool = False,
use_pp_format: bool = False) -> int:
'''
def save_state_dict_shards(
sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
checkpoint: str,
index_file: "CheckpointIndexFile",
base_filename: str,
is_master: bool,
use_safetensors: bool = False,
use_pp_format: bool = False,
) -> int:
"""
Save the sharded state dict only on the master rank; this method can be used for both model and optimizer states.
Args:
sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
@@ -237,7 +234,7 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
Returns:
int: the total size of shards
'''
"""
total_size = 0
shard_filenames = []
@@ -288,7 +285,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
"""
# Only split state_dict['state']; state_dict['param_group'] is not considered in this function.
states = state_dict['state']
states = state_dict["state"]
state_dict_sharder = StateDictSharder(max_shard_size)
for param_id, state in states.items():
@@ -316,9 +313,11 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
"""
if use_safetensors:
assert is_safetensors_available(), "safetensors is not available."
assert checkpoint_file_path.endswith('.safetensors'), \
"safetensors only supports .safetensors suffix for checkpoint file."
assert checkpoint_file_path.endswith(
".safetensors"
), "safetensors only supports .safetensors suffix for checkpoint file."
from safetensors.torch import save_file as safe_save_file
safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
else:
torch.save(state_dict, checkpoint_file_path)
@@ -336,11 +335,13 @@ def save_param_groups(state_dict: dict, group_file_path: str) -> None:
torch.save(param_groups, group_file_path)
def clean_folder(checkpoint_path: str,
weights_name: str,
shard_filenames: List[str],
is_master: bool = True,
use_pp_format: bool = False):
def clean_folder(
checkpoint_path: str,
weights_name: str,
shard_filenames: List[str],
is_master: bool = True,
use_pp_format: bool = False,
):
"""
Clean the unneeded files in checkpoint directory after shards of state_dict have been saved.
@@ -362,8 +363,12 @@ def clean_folder(checkpoint_path: str,
else:
# When this checkpoint is created by pipeline parallel process, the pattern is a little different.
reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")
if (filename.startswith(weights_no_suffix) and os.path.isfile(full_filename)
and filename not in shard_filenames and reg.fullmatch(filename_no_suffix) is not None):
if (
filename.startswith(weights_no_suffix)
and os.path.isfile(full_filename)
and filename not in shard_filenames
and reg.fullmatch(filename_no_suffix) is not None
):
os.remove(full_filename)
@@ -412,7 +417,7 @@ def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFi
size_per_shard (int): size per shard in MB.
"""
root_path = index_file.root_path
output_root_path = root_path.joinpath('dtensor')
output_root_path = root_path.joinpath("dtensor")
# create directory
output_root_path.mkdir(exist_ok=True)
@@ -432,7 +437,7 @@ def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFi
# update the weight map
# * means all shards
ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors)
ckpt_file_name_in_weight_map = "dtensor/" + generate_dtensor_file_name(name, "*", use_safetensors)
index_file.append_weight_map(name, ckpt_file_name_in_weight_map)
@@ -447,15 +452,14 @@ def get_checkpoint_file_suffix(use_safetensors: bool) -> str:
str: checkpoint file suffix.
"""
if use_safetensors:
return '.safetensors'
return ".safetensors"
else:
return '.bin'
return ".bin"
def generate_checkpoint_shard_file_name(index: int,
total_number: int,
use_safetensors: bool,
prefix: str = None) -> str:
def generate_checkpoint_shard_file_name(
index: int, total_number: int, use_safetensors: bool, prefix: str = None
) -> str:
"""
Generate checkpoint shard file name.
@@ -489,7 +493,7 @@ def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: boo
str: dtensor file name.
"""
suffix = get_checkpoint_file_suffix(use_safetensors)
return f'{param_name}.{index}.{suffix}'
return f"{param_name}.{index}.{suffix}"
# ========================================
@@ -506,21 +510,21 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
if use_safetensors:
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import safe_open
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata["format"] != "pt":
raise NotImplementedError(
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
)
return safe_load_file(checkpoint_file)
else:
return torch.load(checkpoint_file, map_location=torch.device('cpu'))
return torch.load(checkpoint_file, map_location=torch.device("cpu"))
def load_state_dict_into_model(model: nn.Module,
state_dict: torch.Tensor,
missing_keys: List,
strict: bool = False,
load_sub_module: bool = True):
def load_state_dict_into_model(
model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True
):
r"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants.
@@ -536,7 +540,7 @@ def load_state_dict_into_model(model: nn.Module,
error_msgs: List[str] = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
metadata = getattr(state_dict, "_metadata", None)
state_dict = OrderedDict(state_dict)
if metadata is not None:
state_dict._metadata = metadata
@@ -560,10 +564,12 @@ def load_state_dict_into_model(model: nn.Module,
if strict:
if len(unexpected_keys) > 0:
error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
'"{}"'.format(k) for k in unexpected_keys))
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
error_msgs = "Unexpected key(s) in state_dict: {}. ".format(
", ".join('"{}"'.format(k) for k in unexpected_keys)
)
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
)
def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict:
@@ -573,9 +579,9 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
# Load list of param_groups from given file path.
# The params in saved_groups are in the form of integer indices.
saved_groups = torch.load(param_group_path, map_location=torch.device('cpu'))
saved_groups = torch.load(param_group_path, map_location=torch.device("cpu"))
if not isinstance(saved_groups, List):
raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
raise ValueError(f"The param_groups saved at {param_group_path} is not of List type")
# The params in param_groups are in the form of pytorch tensors.
# For more details, please view source code of Optimizer class in pytorch.
@@ -584,26 +590,30 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
# Check the compatibility of saved_groups and param_groups.
if len(param_groups) != len(saved_groups):
raise ValueError("loaded state dict has a different number of original parameter groups")
param_lens = (len(g['params']) for g in param_groups)
saved_lens = (len(g['params']) for g in saved_groups)
param_lens = (len(g["params"]) for g in param_groups)
saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError("loaded state dict contains a parameter group "
"that doesn't match the size of optimizer's group")
raise ValueError(
"loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group"
)
# Creating mapping from id to parameters.
id_map = {
old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups
)), chain.from_iterable((g['params'] for g in param_groups)))
old_id: p
for old_id, p in zip(
chain.from_iterable((g["params"] for g in saved_groups)),
chain.from_iterable((g["params"] for g in param_groups)),
)
}
# Update parameter groups, setting their 'params' value.
def update_group(group, new_group):
new_group['params'] = group['params']
new_group["params"] = group["params"]
return new_group
updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)]
optimizer.__dict__.update({'param_groups': updated_groups})
optimizer.__dict__.update({"param_groups": updated_groups})
return id_map
@@ -628,7 +638,7 @@ def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: d
# Floating-point types are a bit special here. They are the only ones
# that are assumed to always match the type of params.
# Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
if (key != "step"):
if key != "step":
if param.is_floating_point():
value = value.to(param.dtype)
value = value.to(param.device)
@@ -662,8 +672,8 @@ def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
"""
# Do the cleaning up as in src code of Pytorch.
optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle.
optimizer.defaults.setdefault('differentiable', False)
optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle.
optimizer.defaults.setdefault("differentiable", False)
def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
@@ -686,20 +696,20 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
return False, None
elif checkpoint_path.is_dir():
# check if there is only one file ending with .index.json in this directory
index_files = list(checkpoint_path.glob('*.index.*json'))
index_files = list(checkpoint_path.glob("*.index.*json"))
# if we found a .index.json file, make sure there is only one
if len(index_files) > 0:
assert len(
index_files
) == 1, f'Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}'
assert (
len(index_files) == 1
), f"Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}"
if len(index_files) == 1:
return True, index_files[0]
else:
return False, None
else:
raise RuntimeError(f'Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.')
raise RuntimeError(f"Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.")
def load_state_dict(checkpoint_file_path: Path):
@@ -713,14 +723,17 @@ def load_state_dict(checkpoint_file_path: Path):
dict: state dict.
"""
assert not is_dtensor_checkpoint(checkpoint_file_path), \
f'Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline.'
assert not is_dtensor_checkpoint(
checkpoint_file_path
), f"Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline."
if is_safetensor_checkpoint(checkpoint_file_path):
assert is_safetensors_available(), \
f'Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors.'
assert (
is_safetensors_available()
), f"Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors."
# load with safetensors
from safetensors import safe_open
state_dict = {}
with safe_open(checkpoint_file_path, framework="pt", device="cpu") as f:
for k in f.keys():
@@ -729,7 +742,7 @@ def load_state_dict(checkpoint_file_path: Path):
else:
# load with torch
return torch.load(checkpoint_file_path, map_location=torch.device('cpu'))
return torch.load(checkpoint_file_path, map_location=torch.device("cpu"))
def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str: