mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -3,4 +3,4 @@ from .general_checkpoint_io import GeneralCheckpointIO
|
||||
from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
|
||||
from .index_file import CheckpointIndexFile
|
||||
|
||||
__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO']
|
||||
__all__ = ["CheckpointIO", "CheckpointIndexFile", "GeneralCheckpointIO", "HybridParallelCheckpointIO"]
|
||||
|
@@ -11,7 +11,7 @@ from colossalai.interface import ModelWrapper
|
||||
|
||||
from .utils import has_index_file
|
||||
|
||||
__all__ = ['CheckpointIO']
|
||||
__all__ = ["CheckpointIO"]
|
||||
|
||||
|
||||
class CheckpointIO(ABC):
|
||||
@@ -61,10 +61,9 @@ class CheckpointIO(ABC):
|
||||
# ======================================
|
||||
# Public methods
|
||||
# ======================================
|
||||
def load_model(self,
|
||||
model: Union[nn.Module, ModelWrapper],
|
||||
checkpoint: str,
|
||||
strict: bool = True) -> Union[nn.Module, ModelWrapper]:
|
||||
def load_model(
|
||||
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True
|
||||
) -> Union[nn.Module, ModelWrapper]:
|
||||
"""
|
||||
Load model from checkpoint.
|
||||
|
||||
@@ -98,14 +97,16 @@ class CheckpointIO(ABC):
|
||||
|
||||
return origin_model
|
||||
|
||||
def save_model(self,
|
||||
model: Union[nn.Module, ModelWrapper],
|
||||
checkpoint: str,
|
||||
shard: bool = False,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: str = None,
|
||||
size_per_shard: int = 1024,
|
||||
use_safetensors: bool = False):
|
||||
def save_model(
|
||||
self,
|
||||
model: Union[nn.Module, ModelWrapper],
|
||||
checkpoint: str,
|
||||
shard: bool = False,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: str = None,
|
||||
size_per_shard: int = 1024,
|
||||
use_safetensors: bool = False,
|
||||
):
|
||||
"""
|
||||
Save model to checkpoint.
|
||||
|
||||
@@ -157,7 +158,7 @@ class CheckpointIO(ABC):
|
||||
|
||||
if Path(checkpoint).is_dir() and not index_file_exists:
|
||||
# if the checkpoint is a directory and there is no index file, raise error
|
||||
raise ValueError(f'Cannot find index file in {checkpoint}')
|
||||
raise ValueError(f"Cannot find index file in {checkpoint}")
|
||||
|
||||
if index_file_exists:
|
||||
# the existence of index file means it is a sharded checkpoint
|
||||
@@ -165,13 +166,15 @@ class CheckpointIO(ABC):
|
||||
else:
|
||||
self.load_unsharded_optimizer(optimizer, checkpoint)
|
||||
|
||||
def save_optimizer(self,
|
||||
optimizer: Optimizer,
|
||||
checkpoint: str,
|
||||
shard: bool = False,
|
||||
gather_dtensor=True,
|
||||
prefix: str = None,
|
||||
size_per_shard: int = 1024):
|
||||
def save_optimizer(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
checkpoint: str,
|
||||
shard: bool = False,
|
||||
gather_dtensor=True,
|
||||
prefix: str = None,
|
||||
size_per_shard: int = 1024,
|
||||
):
|
||||
"""
|
||||
Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.
|
||||
|
||||
@@ -207,7 +210,6 @@ class CheckpointIO(ABC):
|
||||
strict (bool): whether to strictly enforce that the param name in
|
||||
the checkpoint match the keys returned by this module's.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
|
||||
@@ -220,11 +222,17 @@ class CheckpointIO(ABC):
|
||||
strict (bool): whether to strictly enforce that the param name in
|
||||
the checkpoint match the keys returned by this module's.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
|
||||
size_per_shard: int, use_safetensors: bool):
|
||||
def save_sharded_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool,
|
||||
prefix: Optional[str],
|
||||
size_per_shard: int,
|
||||
use_safetensors: bool,
|
||||
):
|
||||
"""
|
||||
Save model to sharded checkpoint.
|
||||
|
||||
@@ -236,7 +244,6 @@ class CheckpointIO(ABC):
|
||||
size_per_shard (int): size per shard in MB.
|
||||
use_safetensors (bool): whether to use safe tensors.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
|
||||
@@ -249,7 +256,6 @@ class CheckpointIO(ABC):
|
||||
gather_dtensor (bool): whether to gather the distributed tensor to the first device.
|
||||
use_safetensors (bool): whether to use safe tensors.
|
||||
"""
|
||||
pass
|
||||
|
||||
# ========================================================
|
||||
# Abstract methods for optimizer loading/saving implementation
|
||||
@@ -265,7 +271,6 @@ class CheckpointIO(ABC):
|
||||
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
|
||||
prefix (str): prefix for the optimizer checkpoint.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
|
||||
@@ -276,11 +281,11 @@ class CheckpointIO(ABC):
|
||||
optimizer (Optimizer): optimizer to be loaded.
|
||||
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
|
||||
size_per_shard: int):
|
||||
def save_sharded_optimizer(
|
||||
self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
|
||||
):
|
||||
"""
|
||||
Save optimizer to sharded checkpoint.
|
||||
|
||||
@@ -291,7 +296,6 @@ class CheckpointIO(ABC):
|
||||
prefix (str): prefix for the optimizer checkpoint.
|
||||
size_per_shard (int): size per shard in MB.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
|
||||
@@ -303,7 +307,6 @@ class CheckpointIO(ABC):
|
||||
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
|
||||
gather_dtensor (bool): whether to gather the distributed tensor to the first device.
|
||||
"""
|
||||
pass
|
||||
|
||||
# ============================================
|
||||
# methods for loading and saving lr scheduler
|
||||
|
@@ -3,9 +3,8 @@ import logging
|
||||
import os
|
||||
from functools import reduce
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Optional, OrderedDict, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
|
||||
@@ -16,7 +15,6 @@ from .index_file import CheckpointIndexFile
|
||||
from .utils import (
|
||||
get_model_base_filenames,
|
||||
get_optimizer_base_filenames,
|
||||
get_shard_filename,
|
||||
is_safetensors_available,
|
||||
load_param_groups_into_optimizer,
|
||||
load_shard_state_dict,
|
||||
@@ -33,7 +31,7 @@ from .utils import (
|
||||
unwrap_optimizer,
|
||||
)
|
||||
|
||||
__all__ = ['GeneralCheckpointIO']
|
||||
__all__ = ["GeneralCheckpointIO"]
|
||||
|
||||
|
||||
class GeneralCheckpointIO(CheckpointIO):
|
||||
@@ -70,8 +68,10 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||
# Load param_groups
|
||||
param_group_path = ckpt_index_file.get_param_group_filename()
|
||||
if param_group_path is None:
|
||||
raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
|
||||
Lacking param group file under current directory.')
|
||||
raise RuntimeError(
|
||||
f"Invalid index file path {index_file_path} for an optimizer. \
|
||||
Lacking param group file under current directory."
|
||||
)
|
||||
id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
|
||||
|
||||
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
|
||||
@@ -123,19 +123,23 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||
|
||||
# Save shards of optimizer states.
|
||||
# In general cases, is_master is set to True to get the right behavior.
|
||||
total_size = save_state_dict_shards(sharded_state_dict=sharded_state,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=True,
|
||||
use_safetensors=False)
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=sharded_state,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=True,
|
||||
use_safetensors=False,
|
||||
)
|
||||
|
||||
# Wrap up index file.
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
logging.info(f"The optimizer is going to be split to checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}.")
|
||||
logging.info(
|
||||
f"The optimizer is going to be split to checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
|
||||
checkpoint = load_state_dict(checkpoint)
|
||||
@@ -150,13 +154,15 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||
# TODO(FrankLeeeee): handle distributed tensors
|
||||
save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
|
||||
|
||||
def save_sharded_model(self,
|
||||
model: nn.Module,
|
||||
checkpoint_path: str,
|
||||
gather_dtensor: bool = False,
|
||||
prefix: Optional[str] = None,
|
||||
max_shard_size: int = 1024,
|
||||
use_safetensors: bool = False):
|
||||
def save_sharded_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
checkpoint_path: str,
|
||||
gather_dtensor: bool = False,
|
||||
prefix: Optional[str] = None,
|
||||
max_shard_size: int = 1024,
|
||||
use_safetensors: bool = False,
|
||||
):
|
||||
"""
|
||||
implement this method as it can be supported by Huggingface model,
|
||||
save shard model, save model to multiple files
|
||||
@@ -175,26 +181,32 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||
|
||||
# Save shards of optimizer states.
|
||||
# In general cases, is_master is set to True to get the right behavior.
|
||||
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint_path,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=True,
|
||||
use_safetensors=use_safetensors)
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint_path,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=True,
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
save_config_file(model, checkpoint_path, is_master=True)
|
||||
logging.info(f"The model is going to be split to checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}.")
|
||||
logging.info(
|
||||
f"The model is going to be split to checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
|
||||
def load_sharded_model(self,
|
||||
model: nn.Module,
|
||||
checkpoint_index_file: Path,
|
||||
strict: bool = False,
|
||||
use_safetensors: bool = False,
|
||||
load_sub_module: bool = True):
|
||||
def load_sharded_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
checkpoint_index_file: Path,
|
||||
strict: bool = False,
|
||||
use_safetensors: bool = False,
|
||||
load_sub_module: bool = True,
|
||||
):
|
||||
"""
|
||||
load shard model, load model from multiple files
|
||||
"""
|
||||
@@ -219,7 +231,11 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||
if strict:
|
||||
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
|
||||
if len(remain_keys) > 0:
|
||||
error_msgs = 'Missing key(s) in state_dict: {}. '.format(', '.join(
|
||||
'"{}"'.format(k) for k in missing_keys))
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
||||
self.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
error_msgs = "Missing key(s) in state_dict: {}. ".format(
|
||||
", ".join('"{}"'.format(k) for k in missing_keys)
|
||||
)
|
||||
raise RuntimeError(
|
||||
"Error(s) in loading state_dict for {}:\n\t{}".format(
|
||||
self.__class__.__name__, "\n\t".join(error_msgs)
|
||||
)
|
||||
)
|
||||
|
@@ -1,5 +1,4 @@
|
||||
import copy
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
@@ -35,9 +34,9 @@ from .utils import (
|
||||
)
|
||||
|
||||
try:
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
|
||||
except ImportError:
|
||||
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
|
||||
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
|
||||
|
||||
|
||||
class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
@@ -52,12 +51,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
verbose (bool, optional): Whether to print logging massage when saving/loading has been succesfully executed. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dp_group: ProcessGroup,
|
||||
pp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
zero_stage: int,
|
||||
verbose: bool = True) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
dp_group: ProcessGroup,
|
||||
pp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
zero_stage: int,
|
||||
verbose: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dp_group = dp_group
|
||||
self.pp_group = pp_group
|
||||
@@ -68,17 +69,16 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
self.dp_size = dist.get_world_size(dp_group)
|
||||
self.pp_size = dist.get_world_size(pp_group)
|
||||
self.tp_size = dist.get_world_size(tp_group)
|
||||
self.use_zero = (zero_stage > 0)
|
||||
self.use_zero = zero_stage > 0
|
||||
self.verbose = verbose
|
||||
self.working_to_master_map = None
|
||||
self.master_to_working_map = None
|
||||
self.coordinator = DistCoordinator()
|
||||
|
||||
@staticmethod
|
||||
def _model_sharder(model: nn.Module,
|
||||
prefix: str = '',
|
||||
keep_vars: bool = False,
|
||||
size_per_shard: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
def _model_sharder(
|
||||
model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024
|
||||
) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
# An internel method that breaks state_dict of model into shards within limited size.
|
||||
|
||||
state_dict_sharder = StateDictSharder(size_per_shard)
|
||||
@@ -103,8 +103,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
|
||||
# Save extra states.
|
||||
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
||||
if getattr(model.__class__, "get_extra_state",
|
||||
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
|
||||
if (
|
||||
getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
|
||||
is not torch.nn.Module.get_extra_state
|
||||
):
|
||||
extra_state = model.get_extra_state()
|
||||
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
|
||||
if block is not None:
|
||||
@@ -114,20 +116,20 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
|
||||
|
||||
@staticmethod
|
||||
def _optimizer_sharder(optimizer: OptimizerWrapper,
|
||||
use_zero: bool,
|
||||
dp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
master_to_working_map: Optional[Dict[int, torch.Tensor]] = None,
|
||||
size_per_shard: int = 1024):
|
||||
|
||||
def _optimizer_sharder(
|
||||
optimizer: OptimizerWrapper,
|
||||
use_zero: bool,
|
||||
dp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
master_to_working_map: Optional[Dict[int, torch.Tensor]] = None,
|
||||
size_per_shard: int = 1024,
|
||||
):
|
||||
# An internel method that breaks state_dict of optimizer into shards within limited size.
|
||||
|
||||
state_dict_sharder = StateDictSharder(size_per_shard)
|
||||
param_info = optimizer.param_info
|
||||
|
||||
for param, state in optimizer.optim.state.items():
|
||||
|
||||
if param is None:
|
||||
continue
|
||||
|
||||
@@ -136,15 +138,17 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
else:
|
||||
working_param = param
|
||||
|
||||
param_id = param_info['param2id'][id(working_param)]
|
||||
original_shape = param_info['param2shape'][id(working_param)]
|
||||
state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(state,
|
||||
working_param,
|
||||
original_shape=original_shape,
|
||||
dp_group=dp_group,
|
||||
tp_group=tp_group,
|
||||
use_zero=use_zero,
|
||||
inplace=False)
|
||||
param_id = param_info["param2id"][id(working_param)]
|
||||
original_shape = param_info["param2shape"][id(working_param)]
|
||||
state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
|
||||
state,
|
||||
working_param,
|
||||
original_shape=original_shape,
|
||||
dp_group=dp_group,
|
||||
tp_group=tp_group,
|
||||
use_zero=use_zero,
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
|
||||
if block is not None:
|
||||
@@ -153,13 +157,15 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
# Return the last block in sharder.
|
||||
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
|
||||
|
||||
def save_sharded_model(self,
|
||||
model: nn.Module,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024,
|
||||
use_safetensors: bool = False) -> None:
|
||||
def save_sharded_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024,
|
||||
use_safetensors: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Save sharded model checkpoint under the given checkpointing path.
|
||||
The following files will be created under the path:
|
||||
@@ -194,24 +200,28 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
|
||||
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
control_saving = (self.tp_rank == 0)
|
||||
control_saving = self.tp_rank == 0
|
||||
|
||||
if self.pp_size == 1:
|
||||
# When pipeline is not used, save the model shards as in general checkpointIO
|
||||
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=control_saving,
|
||||
use_safetensors=use_safetensors)
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=control_saving,
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
if control_saving:
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
save_config_file(model, checkpoint)
|
||||
if self.verbose:
|
||||
logging.info(f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}.")
|
||||
logging.info(
|
||||
f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
|
||||
else:
|
||||
# When pipeline is used, each stage produces its own shard files and index files.
|
||||
@@ -228,15 +238,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
|
||||
save_index_file = os.path.join("tmp_index_files", save_index_file)
|
||||
|
||||
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=control_saving,
|
||||
use_safetensors=use_safetensors,
|
||||
use_pp_format=True)
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=control_saving,
|
||||
use_safetensors=use_safetensors,
|
||||
use_pp_format=True,
|
||||
)
|
||||
if control_saving:
|
||||
assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
|
||||
assert (
|
||||
self.dp_rank == 0 and self.tp_rank == 0
|
||||
), "The saving process should have both dp_rank and tp_rank as 0."
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
else:
|
||||
@@ -259,9 +273,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
save_config_file(model, checkpoint)
|
||||
rmtree(tmp_index_file_folder)
|
||||
if self.verbose:
|
||||
logging.info(f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {final_index_file_path}.")
|
||||
logging.info(
|
||||
f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {final_index_file_path}."
|
||||
)
|
||||
|
||||
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
|
||||
"""
|
||||
@@ -305,11 +321,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
|
||||
missing_keys = []
|
||||
|
||||
load_state_dict_into_model(model,
|
||||
state_dict,
|
||||
missing_keys=missing_keys,
|
||||
strict=strict,
|
||||
load_sub_module=True)
|
||||
load_state_dict_into_model(
|
||||
model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
|
||||
)
|
||||
loaded_file.add(filename)
|
||||
|
||||
# Load parameters.
|
||||
@@ -319,15 +333,17 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
# Load buffers.
|
||||
non_persistent_buffers = set()
|
||||
for n, m in model.named_modules():
|
||||
non_persistent_buffers |= set('.'.join((n, b)) for b in m._non_persistent_buffers_set)
|
||||
non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
|
||||
for name, buf in model.named_buffers():
|
||||
if buf is not None and name not in non_persistent_buffers:
|
||||
_load(name)
|
||||
|
||||
# Load extra states.
|
||||
extra_state_key = _EXTRA_STATE_KEY_SUFFIX
|
||||
if getattr(model.__class__, "get_extra_state",
|
||||
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
|
||||
if (
|
||||
getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
|
||||
is not torch.nn.Module.get_extra_state
|
||||
):
|
||||
_load(extra_state_key)
|
||||
|
||||
# Update master params if mixed-precision training is enabled.
|
||||
@@ -352,12 +368,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
if self.verbose:
|
||||
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
|
||||
|
||||
def save_sharded_optimizer(self,
|
||||
optimizer: OptimizerWrapper,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024):
|
||||
def save_sharded_optimizer(
|
||||
self,
|
||||
optimizer: OptimizerWrapper,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024,
|
||||
):
|
||||
"""
|
||||
Save sharded optimizer checkpoint under the given checkpointing path.
|
||||
The following files will be created under the path:
|
||||
@@ -393,18 +411,21 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
dp_group=self.dp_group,
|
||||
tp_group=self.tp_group,
|
||||
master_to_working_map=self.master_to_working_map,
|
||||
size_per_shard=size_per_shard)
|
||||
size_per_shard=size_per_shard,
|
||||
)
|
||||
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
control_saving = (self.dp_rank == 0 and self.tp_rank == 0)
|
||||
control_saving = self.dp_rank == 0 and self.tp_rank == 0
|
||||
|
||||
if self.pp_size == 1:
|
||||
# When pipeline is not used, save the optimizer shards as in general checkpointIO
|
||||
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=control_saving)
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=control_saving,
|
||||
)
|
||||
|
||||
if control_saving:
|
||||
# Store param groups.
|
||||
@@ -415,9 +436,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
if self.verbose:
|
||||
logging.info(f"The optimizer is going to be split to checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}.")
|
||||
logging.info(
|
||||
f"The optimizer is going to be split to checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
|
||||
else:
|
||||
# When pipeline is used, each stage produces its own shard files and index files.
|
||||
@@ -433,15 +456,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
|
||||
save_index_file = os.path.join("tmp_index_files", save_index_file)
|
||||
|
||||
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=control_saving,
|
||||
use_pp_format=True)
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=control_saving,
|
||||
use_pp_format=True,
|
||||
)
|
||||
|
||||
if control_saving:
|
||||
assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
|
||||
assert (
|
||||
self.dp_rank == 0 and self.tp_rank == 0
|
||||
), "The saving process should have both dp_rank and tp_rank as 0."
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
else:
|
||||
@@ -451,7 +478,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
|
||||
# The global master rank integrates the index files and clean the folder.
|
||||
if self.pp_rank == 0:
|
||||
|
||||
final_index_file = CheckpointIndexFile(checkpoint)
|
||||
final_index_file.append_meta_data("total_size", 0)
|
||||
|
||||
@@ -470,9 +496,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
rmtree(tmp_index_file_folder)
|
||||
|
||||
if self.verbose:
|
||||
logging.info(f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {final_index_file_path}.")
|
||||
logging.info(
|
||||
f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {final_index_file_path}."
|
||||
)
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
|
||||
"""
|
||||
@@ -484,20 +512,21 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
prefix (str): Not used.
|
||||
"""
|
||||
|
||||
def _get_param_id_from_optimizer_param(param: torch.Tensor,
|
||||
master_to_working_map: Optional[Dict[int, torch.Tensor]] = None):
|
||||
def _get_param_id_from_optimizer_param(
|
||||
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
|
||||
):
|
||||
if master_to_working_map is not None:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
return optimizer.param_info['param2id'][id(working_param)]
|
||||
return optimizer.param_info["param2id"][id(working_param)]
|
||||
|
||||
# id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
|
||||
# When Zero is used, the mapped parameter objects should be fp32 master parameters.
|
||||
# IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
|
||||
id_map = {}
|
||||
for pg in optimizer.optim.param_groups:
|
||||
for param in pg['params']:
|
||||
for param in pg["params"]:
|
||||
param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
|
||||
id_map[param_id] = param
|
||||
|
||||
@@ -505,28 +534,30 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
|
||||
ckpt_root_path = ckpt_index_file.root_path
|
||||
weight_map = ckpt_index_file.weight_map
|
||||
weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
|
||||
weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
|
||||
|
||||
# Load param_groups
|
||||
param_group_path = ckpt_index_file.get_param_group_filename()
|
||||
if param_group_path is None:
|
||||
raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \
|
||||
Lacking param group file under current directory.')
|
||||
raise RuntimeError(
|
||||
f"Invalid index file path {checkpoint_index_file} for an optimizer. \
|
||||
Lacking param group file under current directory."
|
||||
)
|
||||
saved_groups = torch.load(param_group_path)
|
||||
|
||||
updated_groups = []
|
||||
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
|
||||
# obtain updated param group
|
||||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg['params'] = old_pg['params'] # The parameters in the same group shouln't change.
|
||||
new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change.
|
||||
updated_groups.append(new_pg)
|
||||
optimizer.optim.__dict__.update({'param_groups': updated_groups})
|
||||
optimizer.optim.__dict__.update({"param_groups": updated_groups})
|
||||
|
||||
# Load saved states to optimizer.
|
||||
# Keep a record of loaded files so that file will not be repeatedly loaded.
|
||||
loaded_file = set()
|
||||
for pg in optimizer.optim.param_groups:
|
||||
for param in pg['params']:
|
||||
for param in pg["params"]:
|
||||
if param is None:
|
||||
continue
|
||||
param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
|
||||
@@ -550,12 +581,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
working_param = self.master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
original_shape = optimizer.param_info['param2shape'][id(working_param)]
|
||||
sharded_state = self.shard_from_complete_optimizer_state(state,
|
||||
current_shape=working_param.shape,
|
||||
original_shape=original_shape,
|
||||
device=device,
|
||||
inplace=True)
|
||||
original_shape = optimizer.param_info["param2shape"][id(working_param)]
|
||||
sharded_state = self.shard_from_complete_optimizer_state(
|
||||
state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
|
||||
)
|
||||
optimizer.optim.state[param] = sharded_state
|
||||
|
||||
sharded_optimizer_loading_epilogue(optimizer.optim)
|
||||
@@ -585,8 +614,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
if self.coordinator.is_master():
|
||||
super().save_lr_scheduler(lr_scheduler, checkpoint)
|
||||
|
||||
def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
|
||||
master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor]):
|
||||
def link_master_and_working_param(
|
||||
self,
|
||||
working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
|
||||
master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor],
|
||||
):
|
||||
"""
|
||||
Create mappings between working params (for forward/backward) and master params (for optimizer update) with passed in mappings.
|
||||
This mapping can only be created when mixied precision is used.
|
||||
@@ -604,7 +636,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
self.working_to_master_map[k] = v
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!")
|
||||
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!"
|
||||
)
|
||||
|
||||
self.master_to_working_map = dict()
|
||||
for k, v in master_to_working_map.items():
|
||||
@@ -614,12 +647,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
self.master_to_working_map[k] = v
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!")
|
||||
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, original_shape: torch.Size,
|
||||
dp_group: ProcessGroup, tp_group: ProcessGroup, use_zero: bool,
|
||||
inplace: bool) -> OrderedDict:
|
||||
def gather_from_sharded_optimizer_state(
|
||||
state: OrderedDict,
|
||||
param: torch.Tensor,
|
||||
original_shape: torch.Size,
|
||||
dp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
use_zero: bool,
|
||||
inplace: bool,
|
||||
) -> OrderedDict:
|
||||
"""
|
||||
With given parameter and its optimizer states, gather the complete optimizer state for saving.
|
||||
|
||||
@@ -641,14 +681,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
state_ = state if inplace else copy.deepcopy(state)
|
||||
|
||||
for k, v in state_.items():
|
||||
if isinstance(v, torch.Tensor) and k != 'step':
|
||||
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
# First gather Zero shards.
|
||||
if use_zero:
|
||||
v = v.cuda()
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
|
||||
dist.all_gather(gather_tensor, v, group=dp_group)
|
||||
v = torch.stack(gather_tensor).view(-1)[:param.numel()].reshape_as(param)
|
||||
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
|
||||
|
||||
# Then gather TP shards.
|
||||
partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
|
||||
@@ -661,9 +700,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
|
||||
return state_
|
||||
|
||||
def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape: torch.Size,
|
||||
original_shape: torch.Size, device: torch.device,
|
||||
inplace: bool) -> OrderedDict:
|
||||
def shard_from_complete_optimizer_state(
|
||||
self,
|
||||
state: OrderedDict,
|
||||
current_shape: torch.Size,
|
||||
original_shape: torch.Size,
|
||||
device: torch.device,
|
||||
inplace: bool,
|
||||
) -> OrderedDict:
|
||||
"""
|
||||
With complete optimizer states of a specific parameter loaded from checkpoint,
|
||||
slice out the sharded optimizer states kept by current device.
|
||||
@@ -681,8 +725,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
state_ = state if inplace else copy.deepcopy(state)
|
||||
|
||||
for k, v in state_.items():
|
||||
if isinstance(v, torch.Tensor) and k != 'step':
|
||||
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
# Shard state along tensor parallel group.
|
||||
partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
|
||||
if partition_dim is not None:
|
||||
|
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Union
|
||||
|
||||
from .utils import is_dtensor_checkpoint
|
||||
|
||||
__all__ = ['CheckpointIndexFile']
|
||||
__all__ = ["CheckpointIndexFile"]
|
||||
|
||||
|
||||
class CheckpointIndexFile:
|
||||
@@ -50,7 +50,7 @@ class CheckpointIndexFile:
|
||||
json_path (str): path to the json file.
|
||||
"""
|
||||
# load the json file
|
||||
with open(json_path, 'r') as f:
|
||||
with open(json_path, "r") as f:
|
||||
index = json.load(f)
|
||||
|
||||
# assign attributes if exists
|
||||
@@ -75,7 +75,7 @@ class CheckpointIndexFile:
|
||||
index["weight_map"] = self.weight_map
|
||||
|
||||
# export the index file
|
||||
with open(json_path, 'w') as f:
|
||||
with open(json_path, "w") as f:
|
||||
json.dump(index, f, indent=4)
|
||||
|
||||
def append_weight_map(self, param_name: str, shard_file: str):
|
||||
|
@@ -1,5 +1,4 @@
|
||||
# coding=utf-8
|
||||
import copy
|
||||
import os
|
||||
import re
|
||||
from collections import abc as container_abcs
|
||||
@@ -12,7 +11,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.tensor.d_tensor import (
|
||||
is_customized_distributed_tensor,
|
||||
is_distributed_tensor,
|
||||
@@ -55,7 +54,6 @@ def is_safetensors_available() -> bool:
|
||||
bool: whether safetensors is available.
|
||||
"""
|
||||
try:
|
||||
import safetensors
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
@@ -71,7 +69,7 @@ def is_dtensor_checkpoint(checkpoint_file_path: str) -> bool:
|
||||
Returns:
|
||||
bool: whether the checkpoint file is a dtensor checkpoint.
|
||||
"""
|
||||
if checkpoint_file_path.endswith('.*.safetensors') or checkpoint_file_path.endswith('.*.bin'):
|
||||
if checkpoint_file_path.endswith(".*.safetensors") or checkpoint_file_path.endswith(".*.bin"):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
@@ -87,7 +85,7 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
|
||||
Returns:
|
||||
bool: whether the checkpoint file is a safetensor checkpoint.
|
||||
"""
|
||||
if checkpoint_file_path.endswith('.safetensors'):
|
||||
if checkpoint_file_path.endswith(".safetensors"):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
@@ -113,8 +111,9 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
|
||||
partition_dim = dim
|
||||
break
|
||||
if partition_dim is not None:
|
||||
assert original_shape[partition_dim] == tp_size * current_shape[partition_dim], \
|
||||
f"The parameter isn't evenly distributed among tensor parallel group: \
|
||||
assert (
|
||||
original_shape[partition_dim] == tp_size * current_shape[partition_dim]
|
||||
), f"The parameter isn't evenly distributed among tensor parallel group: \
|
||||
shape before sharding {original_shape}, shape after sharding {current_shape}"
|
||||
|
||||
return partition_dim
|
||||
@@ -124,24 +123,22 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
|
||||
# Helper classes and functions for saving shard file
|
||||
# ======================================
|
||||
def unwrap_optimizer(optimizer: OptimizerWrapper):
|
||||
'''
|
||||
"""
|
||||
Unwrap a wrapped optimizer.
|
||||
This method should be used before saving/loading it to/from sharded checkpoints.
|
||||
'''
|
||||
"""
|
||||
|
||||
unwrapped_optim = optimizer.optim
|
||||
return unwrapped_optim
|
||||
|
||||
|
||||
class StateDictSharder:
|
||||
|
||||
def __init__(self, size_per_shard: int) -> None:
|
||||
self.max_shard_size = size_per_shard
|
||||
self.current_block = OrderedDict()
|
||||
self.current_block_size = 0
|
||||
|
||||
def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
|
||||
|
||||
tensor_size = calculate_tensor_size(tensor)
|
||||
ret_block = None
|
||||
ret_block_size = 0
|
||||
@@ -159,13 +156,11 @@ class StateDictSharder:
|
||||
return ret_block, ret_block_size
|
||||
|
||||
def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]:
|
||||
|
||||
# A state might contain more than one tensors.
|
||||
# e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
|
||||
state_size = 0
|
||||
isDTensor = False
|
||||
for state_tensor in state.values():
|
||||
|
||||
# When state_tensor is not of Tensor class,
|
||||
# e.g., a SGD optimizer with momentum set to 0 can have None as state
|
||||
# The calculation of tensor size should be skipped to avoid error.
|
||||
@@ -217,14 +212,16 @@ def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> to
|
||||
return param_
|
||||
|
||||
|
||||
def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
|
||||
checkpoint: str,
|
||||
index_file: "CheckpointIndexFile",
|
||||
base_filename: str,
|
||||
is_master: bool,
|
||||
use_safetensors: bool = False,
|
||||
use_pp_format: bool = False) -> int:
|
||||
'''
|
||||
def save_state_dict_shards(
|
||||
sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
|
||||
checkpoint: str,
|
||||
index_file: "CheckpointIndexFile",
|
||||
base_filename: str,
|
||||
is_master: bool,
|
||||
use_safetensors: bool = False,
|
||||
use_pp_format: bool = False,
|
||||
) -> int:
|
||||
"""
|
||||
Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
|
||||
Args:
|
||||
sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
|
||||
@@ -237,7 +234,7 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
|
||||
|
||||
Returns:
|
||||
int: the total size of shards
|
||||
'''
|
||||
"""
|
||||
|
||||
total_size = 0
|
||||
shard_filenames = []
|
||||
@@ -288,7 +285,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
|
||||
"""
|
||||
|
||||
# Only split state_dict['state']; state_dict['param_group'] is not considered in this function.
|
||||
states = state_dict['state']
|
||||
states = state_dict["state"]
|
||||
state_dict_sharder = StateDictSharder(max_shard_size)
|
||||
|
||||
for param_id, state in states.items():
|
||||
@@ -316,9 +313,11 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
|
||||
"""
|
||||
if use_safetensors:
|
||||
assert is_safetensors_available(), "safetensors is not available."
|
||||
assert checkpoint_file_path.endswith('.safetensors'), \
|
||||
"safetensors only supports .safetensors suffix for checkpoint file."
|
||||
assert checkpoint_file_path.endswith(
|
||||
".safetensors"
|
||||
), "safetensors only supports .safetensors suffix for checkpoint file."
|
||||
from safetensors.torch import save_file as safe_save_file
|
||||
|
||||
safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
|
||||
else:
|
||||
torch.save(state_dict, checkpoint_file_path)
|
||||
@@ -336,11 +335,13 @@ def save_param_groups(state_dict: dict, group_file_path: str) -> None:
|
||||
torch.save(param_groups, group_file_path)
|
||||
|
||||
|
||||
def clean_folder(checkpoint_path: str,
|
||||
weights_name: str,
|
||||
shard_filenames: List[str],
|
||||
is_master: bool = True,
|
||||
use_pp_format: bool = False):
|
||||
def clean_folder(
|
||||
checkpoint_path: str,
|
||||
weights_name: str,
|
||||
shard_filenames: List[str],
|
||||
is_master: bool = True,
|
||||
use_pp_format: bool = False,
|
||||
):
|
||||
"""
|
||||
Clean the unneeded files in checkpoint directory after shards of state_dict have been saved.
|
||||
|
||||
@@ -362,8 +363,12 @@ def clean_folder(checkpoint_path: str,
|
||||
else:
|
||||
# When this checkpoint is created by pipeline parallel process, the pattern is a little different.
|
||||
reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")
|
||||
if (filename.startswith(weights_no_suffix) and os.path.isfile(full_filename)
|
||||
and filename not in shard_filenames and reg.fullmatch(filename_no_suffix) is not None):
|
||||
if (
|
||||
filename.startswith(weights_no_suffix)
|
||||
and os.path.isfile(full_filename)
|
||||
and filename not in shard_filenames
|
||||
and reg.fullmatch(filename_no_suffix) is not None
|
||||
):
|
||||
os.remove(full_filename)
|
||||
|
||||
|
||||
@@ -412,7 +417,7 @@ def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFi
|
||||
size_per_shard (int): size per shard in MB.
|
||||
"""
|
||||
root_path = index_file.root_path
|
||||
output_root_path = root_path.joinpath('dtensor')
|
||||
output_root_path = root_path.joinpath("dtensor")
|
||||
|
||||
# create directory
|
||||
output_root_path.mkdir(exist_ok=True)
|
||||
@@ -432,7 +437,7 @@ def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFi
|
||||
|
||||
# update the weight map
|
||||
# * means all shards
|
||||
ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors)
|
||||
ckpt_file_name_in_weight_map = "dtensor/" + generate_dtensor_file_name(name, "*", use_safetensors)
|
||||
index_file.append_weight_map(name, ckpt_file_name_in_weight_map)
|
||||
|
||||
|
||||
@@ -447,15 +452,14 @@ def get_checkpoint_file_suffix(use_safetensors: bool) -> str:
|
||||
str: checkpoint file suffix.
|
||||
"""
|
||||
if use_safetensors:
|
||||
return '.safetensors'
|
||||
return ".safetensors"
|
||||
else:
|
||||
return '.bin'
|
||||
return ".bin"
|
||||
|
||||
|
||||
def generate_checkpoint_shard_file_name(index: int,
|
||||
total_number: int,
|
||||
use_safetensors: bool,
|
||||
prefix: str = None) -> str:
|
||||
def generate_checkpoint_shard_file_name(
|
||||
index: int, total_number: int, use_safetensors: bool, prefix: str = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate checkpoint shard file name.
|
||||
|
||||
@@ -489,7 +493,7 @@ def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: boo
|
||||
str: dtensor file name.
|
||||
"""
|
||||
suffix = get_checkpoint_file_suffix(use_safetensors)
|
||||
return f'{param_name}.{index}.{suffix}'
|
||||
return f"{param_name}.{index}.{suffix}"
|
||||
|
||||
|
||||
# ========================================
|
||||
@@ -506,21 +510,21 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
|
||||
if use_safetensors:
|
||||
from safetensors.torch import load_file as safe_load_file
|
||||
from safetensors.torch import safe_open
|
||||
|
||||
with safe_open(checkpoint_file, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata["format"] != "pt":
|
||||
raise NotImplementedError(
|
||||
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
|
||||
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
|
||||
)
|
||||
return safe_load_file(checkpoint_file)
|
||||
else:
|
||||
return torch.load(checkpoint_file, map_location=torch.device('cpu'))
|
||||
return torch.load(checkpoint_file, map_location=torch.device("cpu"))
|
||||
|
||||
|
||||
def load_state_dict_into_model(model: nn.Module,
|
||||
state_dict: torch.Tensor,
|
||||
missing_keys: List,
|
||||
strict: bool = False,
|
||||
load_sub_module: bool = True):
|
||||
def load_state_dict_into_model(
|
||||
model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True
|
||||
):
|
||||
r"""Copies parameters and buffers from :attr:`state_dict` into
|
||||
this module and its descendants.
|
||||
|
||||
@@ -536,7 +540,7 @@ def load_state_dict_into_model(model: nn.Module,
|
||||
error_msgs: List[str] = []
|
||||
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
metadata = getattr(state_dict, '_metadata', None)
|
||||
metadata = getattr(state_dict, "_metadata", None)
|
||||
state_dict = OrderedDict(state_dict)
|
||||
if metadata is not None:
|
||||
state_dict._metadata = metadata
|
||||
@@ -560,10 +564,12 @@ def load_state_dict_into_model(model: nn.Module,
|
||||
|
||||
if strict:
|
||||
if len(unexpected_keys) > 0:
|
||||
error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
|
||||
'"{}"'.format(k) for k in unexpected_keys))
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
||||
model.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
error_msgs = "Unexpected key(s) in state_dict: {}. ".format(
|
||||
", ".join('"{}"'.format(k) for k in unexpected_keys)
|
||||
)
|
||||
raise RuntimeError(
|
||||
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
|
||||
)
|
||||
|
||||
|
||||
def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict:
|
||||
@@ -573,9 +579,9 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
|
||||
|
||||
# Load list of param_groups from given file path.
|
||||
# The params in saved_groups are in the form of integer indices.
|
||||
saved_groups = torch.load(param_group_path, map_location=torch.device('cpu'))
|
||||
saved_groups = torch.load(param_group_path, map_location=torch.device("cpu"))
|
||||
if not isinstance(saved_groups, List):
|
||||
raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
|
||||
raise ValueError(f"The param_groups saved at {param_group_path} is not of List type")
|
||||
|
||||
# The params in param_groups are in the form of pytorch tensors.
|
||||
# For more details, please view source code of Optimizer class in pytorch.
|
||||
@@ -584,26 +590,30 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
|
||||
# Check the compatibility of saved_groups and param_groups.
|
||||
if len(param_groups) != len(saved_groups):
|
||||
raise ValueError("loaded state dict has a different number of original parameter groups")
|
||||
param_lens = (len(g['params']) for g in param_groups)
|
||||
saved_lens = (len(g['params']) for g in saved_groups)
|
||||
param_lens = (len(g["params"]) for g in param_groups)
|
||||
saved_lens = (len(g["params"]) for g in saved_groups)
|
||||
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
|
||||
raise ValueError("loaded state dict contains a parameter group "
|
||||
"that doesn't match the size of optimizer's group")
|
||||
raise ValueError(
|
||||
"loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group"
|
||||
)
|
||||
|
||||
# Creating mapping from id to parameters.
|
||||
id_map = {
|
||||
old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups
|
||||
)), chain.from_iterable((g['params'] for g in param_groups)))
|
||||
old_id: p
|
||||
for old_id, p in zip(
|
||||
chain.from_iterable((g["params"] for g in saved_groups)),
|
||||
chain.from_iterable((g["params"] for g in param_groups)),
|
||||
)
|
||||
}
|
||||
|
||||
# Update parameter groups, setting their 'params' value.
|
||||
def update_group(group, new_group):
|
||||
new_group['params'] = group['params']
|
||||
new_group["params"] = group["params"]
|
||||
return new_group
|
||||
|
||||
updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)]
|
||||
|
||||
optimizer.__dict__.update({'param_groups': updated_groups})
|
||||
optimizer.__dict__.update({"param_groups": updated_groups})
|
||||
return id_map
|
||||
|
||||
|
||||
@@ -628,7 +638,7 @@ def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: d
|
||||
# Floating-point types are a bit special here. They are the only ones
|
||||
# that are assumed to always match the type of params.
|
||||
# Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
|
||||
if (key != "step"):
|
||||
if key != "step":
|
||||
if param.is_floating_point():
|
||||
value = value.to(param.dtype)
|
||||
value = value.to(param.device)
|
||||
@@ -662,8 +672,8 @@ def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
|
||||
"""
|
||||
|
||||
# Do the cleaning up as in src code of Pytorch.
|
||||
optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle.
|
||||
optimizer.defaults.setdefault('differentiable', False)
|
||||
optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle.
|
||||
optimizer.defaults.setdefault("differentiable", False)
|
||||
|
||||
|
||||
def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
|
||||
@@ -686,20 +696,20 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
|
||||
return False, None
|
||||
elif checkpoint_path.is_dir():
|
||||
# check if there is only one a file ending with .index.json in this directory
|
||||
index_files = list(checkpoint_path.glob('*.index.*json'))
|
||||
index_files = list(checkpoint_path.glob("*.index.*json"))
|
||||
|
||||
# if we found a .index.json file, make sure there is only one
|
||||
if len(index_files) > 0:
|
||||
assert len(
|
||||
index_files
|
||||
) == 1, f'Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}'
|
||||
assert (
|
||||
len(index_files) == 1
|
||||
), f"Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}"
|
||||
|
||||
if len(index_files) == 1:
|
||||
return True, index_files[0]
|
||||
else:
|
||||
return False, None
|
||||
else:
|
||||
raise RuntimeError(f'Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.')
|
||||
raise RuntimeError(f"Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.")
|
||||
|
||||
|
||||
def load_state_dict(checkpoint_file_path: Path):
|
||||
@@ -713,14 +723,17 @@ def load_state_dict(checkpoint_file_path: Path):
|
||||
dict: state dict.
|
||||
"""
|
||||
|
||||
assert not is_dtensor_checkpoint(checkpoint_file_path), \
|
||||
f'Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline.'
|
||||
assert not is_dtensor_checkpoint(
|
||||
checkpoint_file_path
|
||||
), f"Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline."
|
||||
|
||||
if is_safetensor_checkpoint(checkpoint_file_path):
|
||||
assert is_safetensors_available(), \
|
||||
f'Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors.'
|
||||
assert (
|
||||
is_safetensors_available()
|
||||
), f"Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors."
|
||||
# load with safetensors
|
||||
from safetensors import safe_open
|
||||
|
||||
state_dict = {}
|
||||
with safe_open(checkpoint_file_path, framework="pt", device="cpu") as f:
|
||||
for k in f.keys():
|
||||
@@ -729,7 +742,7 @@ def load_state_dict(checkpoint_file_path: Path):
|
||||
|
||||
else:
|
||||
# load with torch
|
||||
return torch.load(checkpoint_file_path, map_location=torch.device('cpu'))
|
||||
return torch.load(checkpoint_file_path, map_location=torch.device("cpu"))
|
||||
|
||||
|
||||
def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
|
||||
|
Reference in New Issue
Block a user