diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 4a7efc165..ce01ad111 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -12,7 +12,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO -from colossalai.checkpoint_io.utils import get_base_filenames, get_shard_filename, save_state_dict +from colossalai.checkpoint_io.utils import get_model_base_filenames, get_shard_filename, save_state_dict from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.utils import get_current_device @@ -76,14 +76,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO): model: GeminiDDP, checkpoint_path: str, gather_dtensor: bool = False, - variant: Optional[str] = None, + prefix: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False): """ Save sharded model """ state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32) - weights_name, save_index_file = get_base_filenames(variant, use_safetensors) + weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) total_size = 0 index_file = CheckpointIndexFile(checkpoint_path) for idx, shard_pair in enumerate(state_dict_shard): diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index b317ccf48..a18073db6 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -32,7 +32,6 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): """ Save model to checkpoint but only on master process. """ - # the model should be unwrapped in self.load_model via ModelWrapper.unwrap if self.coordinator.is_master(): super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors) @@ -54,11 +53,22 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): model: nn.Module, checkpoint_path: str, gather_dtensor: bool = False, - variant: Optional[str] = None, + prefix: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False): + """ + Save model to checkpoint but only on master process. + """ if self.coordinator.is_master(): - super().save_sharded_model(model, checkpoint_path, gather_dtensor, variant, max_shard_size, use_safetensors) + super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors) + + def save_sharded_optimier(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, + size_per_shard: int): + """ + Save optimizer to checkpoint but only on master process. 
+ """ + if self.coordinator.is_master(): + super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard) class TorchDDPModel(ModelWrapper): diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index 8d534ea4c..ebd03b6ea 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -1,9 +1,9 @@ +import warnings from pathlib import Path from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union import torch import torch.nn as nn -import warnings from packaging import version from torch.distributed import ProcessGroup @@ -69,7 +69,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True) utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False) - def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str], + def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str], size_per_shard: int, use_safetensors: bool): """ Save model to checkpoint but only on master process. @@ -87,13 +87,14 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): """ raise NotImplementedError("Sharded model checkpoint is not supported yet.") - def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, + size_per_shard: int): """ Save optimizer to checkpoint but only on master process. """ raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.") - def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int): + def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int): """ Load optimizer to checkpoint but only on master process. """ diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index fbc8fc542..9d513043f 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -103,7 +103,7 @@ class CheckpointIO(ABC): checkpoint: str, shard: bool = False, gather_dtensor: bool = True, - variant: str = None, + prefix: str = None, size_per_shard: int = 1024, use_safetensors: bool = False): """ @@ -128,7 +128,7 @@ class CheckpointIO(ABC): multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure that the checkpoint path is a directory path instead of a file path. gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True. - variant (str): If specified, weights are saved in the format pytorch_model..bin. Default: None. + prefix (str): If specified, weights are saved in the format pytorch_model..bin. Default: None. size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True. use_safetensors (bool): whether to use safe tensors. Default: False. 
If set to True, the checkpoint will be saved """ @@ -137,11 +137,11 @@ class CheckpointIO(ABC): model = model.unwrap() if shard: - self.save_sharded_model(model, checkpoint, gather_dtensor, variant, size_per_shard, use_safetensors) + self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors) else: self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors) - def load_optimizer(self, optimizer: Optimizer, checkpoint: str): + def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024): """ Load optimizer from checkpoint. @@ -157,7 +157,7 @@ class CheckpointIO(ABC): if index_file_exists: # the existence of index file means it is a sharded checkpoint - self.load_sharded_optimizer(optimizer, index_file_path) + self.load_sharded_optimizer(optimizer, index_file_path, prefix, size_per_shard) else: self.load_unsharded_optimizer(optimizer, checkpoint) @@ -218,7 +218,7 @@ class CheckpointIO(ABC): pass @abstractmethod - def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str], + def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str], size_per_shard: int, use_safetensors: bool): """ Save model to sharded checkpoint. diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 2cc9c3faa..d8e133313 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -11,15 +11,21 @@ from torch.optim import Optimizer from .checkpoint_io_base import CheckpointIO from .index_file import CheckpointIndexFile from .utils import ( - get_base_filenames, + get_model_base_filenames, + get_optimizer_base_filenames, get_shard_filename, has_index_file, is_safetensors_available, + load_param_groups_into_optimizer, load_shard_state_dict, load_state_dict, load_state_dict_into_model, + load_states_into_optimizer, + save_param_groups, save_state_dict, - shard_checkpoint, + shard_model_checkpoint, + shard_optimizer_checkpoint, + sharded_optimizer_loading_epilogue, ) __all__ = ['GeneralCheckpointIO'] @@ -44,12 +50,30 @@ class GeneralCheckpointIO(CheckpointIO): # save the checkpoint save_state_dict(state_dict, checkpoint, use_safetensors) - def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int): - raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.") + def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int): + """ + Load sharded optimizer with the given path to index file. + """ + optimizer.load_state_dict + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(index_file_path) - def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): - checkpoint = load_state_dict(checkpoint) - optimizer.load_state_dict(checkpoint) + # Load param_groups + param_group_path = ckpt_index_file.get_param_group_filename() + if param_group_path is None: + raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. 
\ + Lacking param group file under current directory.') + id_map = load_param_groups_into_optimizer(optimizer, param_group_path) + + checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() + + for shard_file in checkpoint_files: + state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False) + load_states_into_optimizer(optimizer, state_dict, id_map) + del state_dict + gc.collect() + + sharded_optimizer_loading_epilogue(optimizer) def save_sharded_optimizer( self, @@ -59,7 +83,54 @@ class GeneralCheckpointIO(CheckpointIO): prefix: str, size_per_shard: int, ): - raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.") + """ + Save sharded optimizer checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names + - A group file (pytorch_optim_group.bin) recording information of param_groups + - Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way + """ + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # Offload optimizer states. States are broken into shards within max_shard_size. + state_dict = optimizer.state_dict() + sharded_state = shard_optimizer_checkpoint(state_dict, max_shard_size=size_per_shard) + + # Preparing file paths and index file. + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + index_file = CheckpointIndexFile(checkpoint) + + # Store the information of param groups to param_group_file. + index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + save_param_groups(state_dict, group_file_path) + + # Save shards of optimizer states. + total_size = 0 + for idx, shard_pair in enumerate(sharded_state): + shard, current_size = shard_pair + shard_file = get_shard_filename(states_name, idx) + total_size = total_size + current_size + for param_id in shard.keys(): + index_file.append_weight_map(str(param_id), shard_file) + + checkpoint_file_path = os.path.join(checkpoint, shard_file) + save_state_dict(shard, checkpoint_file_path, use_safetensors=False) + + # Wrap up index file. + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + logging.info(f"The optimizer is going to be split to checkpoint shards. 
" + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.") + + def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): + checkpoint = load_state_dict(checkpoint) + optimizer.load_state_dict(checkpoint) def save_unsharded_optimizer( self, @@ -74,7 +145,7 @@ class GeneralCheckpointIO(CheckpointIO): model: nn.Module, checkpoint_path: str, gather_dtensor: bool = False, - variant: Optional[str] = None, + prefix: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False): """ @@ -89,9 +160,9 @@ class GeneralCheckpointIO(CheckpointIO): # shard checkpoint state_dict = model.state_dict() - state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size) + state_dict_shard = shard_model_checkpoint(state_dict, max_shard_size=max_shard_size) - weights_name, save_index_file = get_base_filenames(variant, use_safetensors) + weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) total_size = 0 index_file = CheckpointIndexFile(checkpoint_path) for idx, shard_pair in enumerate(state_dict_shard): @@ -128,7 +199,7 @@ class GeneralCheckpointIO(CheckpointIO): # read checkpoint index file ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) - checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames() + checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() missing_keys = [] for shard_file in checkpoint_files: diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py index a41cc482e..388cf3fbe 100644 --- a/colossalai/checkpoint_io/index_file.py +++ b/colossalai/checkpoint_io/index_file.py @@ -111,7 +111,7 @@ class CheckpointIndexFile: return True return False - def get_checkpoint_fileanames(self) -> List[str]: + def get_checkpoint_filenames(self) -> List[str]: """ Get the set of checkpoint filenames in the weight map. @@ -159,6 +159,18 @@ class CheckpointIndexFile: """ return list(self.weight_map.keys()) + def get_param_group_filename(self) -> Union[str, None]: + """ + Get the file name of param_group file if this is a checkpoint for optimizer. + Returns: + str: param_group file name + """ + filename = self.metadata.get("param_groups", None) + if filename: + return str(self.root_path.joinpath(filename)) + else: + return None + def write_index_file(self, save_index_file): """ Write index file. 
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 435feda4a..21b70343b 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -1,17 +1,24 @@ # coding=utf-8 import re +from collections import abc as container_abcs +from collections import defaultdict +from itertools import chain from pathlib import Path from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple import torch import torch.nn as nn +from torch.optim import Optimizer from colossalai.tensor.d_tensor.d_tensor import DTensor SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin" +STATES_NAME = "pytorch_optim.bin" SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" +STATES_INDEX_NAME = "pytorch_optim.bin.index.json" +GROUP_FILE_NAME = "pytorch_optim_group.bin" # ====================================== # General helper functions @@ -81,7 +88,7 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: # ====================================== # Helper functions for saving shard file # ====================================== -def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: +def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: """ Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a given size. @@ -110,6 +117,50 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> It yield current_block, current_block_size +def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: + """ + Splits an optimizer state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a + given size. + """ + + # Only split state_dict['state']; state_dict['param_group'] is not considered in this function. + states = state_dict['state'] + + current_block = {} + current_block_size = 0 + + for param_id, state in states.items(): + + ret_block = None + ret_block_size = 0 + + # A state might contain more than one tensors. + # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq' + state_size = 0 + isDTensor = False + for state_tensor in state.values(): + # If the states are stored as DTensors, mark isDTensor as true. + if type(state_tensor) == DTensor: + isDTensor = True + state_size += calculate_tensor_size(state_tensor) + + if not isDTensor: + + if current_block_size + state_size > max_shard_size: + ret_block = current_block + ret_block_size = current_block_size + current_block = {} + current_block_size = 0 + + current_block[param_id] = state + current_block_size += state_size + + if ret_block != None: + yield ret_block, ret_block_size + + yield current_block, current_block_size + + def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False): """ load shard state dict into model @@ -179,6 +230,96 @@ def load_state_dict_into_model(model: nn.Module, model.__class__.__name__, "\n\t".join(error_msgs))) +def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict: + """ + Load information of param_groups into an initialized optimizer. + """ + + # Load list of param_groups from given file path. + # The params in saved_groups are in the form of integer indices. 
+ saved_groups = torch.load(param_group_path) + if not isinstance(saved_groups, List): + raise ValueError(f'The param_groups saved at {param_group_path} is not of List type') + + # The params in param_groups are in the form of pytorch tensors. + # For more details, please view source code of Optimizer class in pytorch. + param_groups = optimizer.param_groups + + # Check the compatibility of saved_groups and param_groups. + if len(param_groups) != len(saved_groups): + raise ValueError("loaded state dict has a different number of original parameter groups") + param_lens = (len(g['params']) for g in param_groups) + saved_lens = (len(g['params']) for g in saved_groups) + if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): + raise ValueError("loaded state dict contains a parameter group " + "that doesn't match the size of optimizer's group") + + # Creating mapping from id to parameters. + id_map = { + old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups + )), chain.from_iterable((g['params'] for g in param_groups))) + } + + # Update parameter groups, setting their 'params' value. + def update_group(group, new_group): + new_group['params'] = group['params'] + return new_group + + updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)] + + optimizer.__dict__.update({'param_groups': updated_groups}) + return id_map + + +def load_states_into_optimizer(optimzier: Optimizer, state_dict: dict, id_map: dict): + r"""Copies states from `state_dict` into an Optimizer object. + + Args: + optimizer(Optimizer): An initialized Optimizer object to be loaded + state_dict(dict): a mapping from tensor index (an integer) + to its states to be loaded (a mapping from state name to a tensor). + id_map(dict): a mapping from tensor index (an integer) + to its corresponding parameter (a tensor) whose states will be updated. + """ + + def cast(param, value, key=None): + r"""Make a deep copy of value, casting all tensors to device of param.""" + if isinstance(value, torch.Tensor): + # Floating-point types are a bit special here. They are the only ones + # that are assumed to always match the type of params. + # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424 + if (key != "step"): + if param.is_floating_point(): + value = value.to(param.dtype) + value = value.to(param.device) + return value + elif isinstance(value, dict): + return {k: cast(param, v, key=k) for k, v in value.items()} + elif isinstance(value, container_abcs.Iterable): + return type(value)(cast(param, v) for v in value) + else: + return value + + # Copy state assigned to params (and cast tensors to appropriate types). + # State that is not assigned to params is copied as is (needed for + # backward compatibility). + new_states = defaultdict(dict) + for k, v in state_dict.items(): + if k in id_map: + param = id_map[k] + new_states[param] = cast(param, v) + else: + new_states[k] = v + + optimzier.state.update(new_states) + + +def sharded_optimizer_loading_epilogue(optimizer: Optimizer): + # Do the cleaning up as in src code of Pytorch. + optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle. 
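+    # Newer PyTorch optimizers expect a 'differentiable' entry in defaults; add it if it is missing.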
+ optimizer.defaults.setdefault('differentiable', False) + + # ====================================== # Helper functions for saving state dict # ====================================== @@ -203,6 +344,18 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors torch.save(state_dict, checkpoint_file_path) +def save_param_groups(state_dict: dict, group_file_path: str) -> None: + """ + Save information of param_groups to given file path. + + Args: + state_dict (dict): state dict. + group_file_path (str): path to the group file. + """ + param_groups = state_dict["param_groups"] + torch.save(param_groups, group_file_path) + + def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None: """ Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains @@ -392,28 +545,44 @@ def load_state_dict(checkpoint_file_path: Path): return torch.load(checkpoint_file_path) -def add_variant(weights_name: str, variant: Optional[str] = None) -> str: - if variant is not None and len(variant) > 0: +def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str: + if prefix is not None and len(prefix) > 0: splits = weights_name.split(".") - splits = splits[:-1] + [variant] + splits[-1:] + splits = splits[:-1] + [prefix] + splits[-1:] weights_name = ".".join(splits) return weights_name -def get_base_filenames(variant: str = None, use_safetensors: bool = False): +def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False): """ - generate base weight filenames + generate base model weight filenames """ weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME - weights_name = add_variant(weights_name, variant) + weights_name = add_prefix(weights_name, prefix) save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME - save_index_file = add_variant(save_index_file, variant) + save_index_file = add_prefix(save_index_file, prefix) return weights_name, save_index_file +def get_optimizer_base_filenames(prefix: str = None): + """ + generate base optimizer state filenames + """ + states_name = STATES_NAME + states_name = add_prefix(states_name, prefix) + + save_index_file = STATES_INDEX_NAME + save_index_file = add_prefix(save_index_file, prefix) + + param_group_file = GROUP_FILE_NAME + param_group_file = add_prefix(param_group_file, prefix) + + return states_name, save_index_file, param_group_file + + def get_shard_filename(weights_name: str, idx: int): """ get shard file name diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index 9e973bb23..88e3673c1 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -60,7 +60,7 @@ def test_unsharded_checkpoint(use_safetensors: bool): @pytest.mark.parametrize('use_safetensors', [True, False]) -def test_sharded_checkpoint(use_safetensors: bool): +def test_sharded_model_checkpoint(use_safetensors: bool): # create a model and optimizer model = resnet18() optimizer = Adam(model.parameters(), lr=0.001) @@ -100,3 +100,101 @@ def test_sharded_checkpoint(use_safetensors: bool): # check for model and optimizer state dict recursively check_state_dict_equal(model.state_dict(), new_model.state_dict()) check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) + + +def test_sharded_optimizer_checkpoint(): + + # create a model and optimizer + model = 
resnet18() + optimizer = Adam(model.parameters(), lr=0.001) + + # create test data sample + x = torch.randn(1, 3, 224, 224) + + # run fwd and bwd + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + + # create temp directories for checkpoint + model_ckpt_dir = tempfile.TemporaryDirectory() + optimizer_ckpt_dir = tempfile.TemporaryDirectory() + + # save the model and optimizer + ckpt_io = GeneralCheckpointIO() + + ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False) + ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10) + + # create new model + new_model = resnet18() + new_optimizer = Adam(new_model.parameters(), lr=0.001) + + ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True) + ckpt_io.load_optimizer(new_optimizer, str(optimizer_ckpt_dir.name)) + + # check for model and optimizer state dict recursively + check_state_dict_equal(model.state_dict(), new_model.state_dict()) + check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) + + # continue running fwd and bwd + for _ in range(5): + y = new_model(x) + loss = y.sum() + loss.backward() + new_optimizer.step() + + # save the newly got optimizer + ckpt_io.save_model(new_model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False) + ckpt_io.save_optimizer(new_optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10) + + # create another new model + new_new_model = resnet18() + new_new_optimizer = Adam(new_new_model.parameters(), lr=0.001) + + ckpt_io.load_model(new_new_model, str(model_ckpt_dir.name), strict=True) + ckpt_io.load_optimizer(new_new_optimizer, str(optimizer_ckpt_dir.name)) + + # check for model and optimizer state dict recursively + check_state_dict_equal(new_model.state_dict(), new_new_model.state_dict()) + check_state_dict_equal(new_optimizer.state_dict(), new_new_optimizer.state_dict()) + + +def test_sharded_optimizer_multiple_param_groups(): + + # create a model and optimizer + model = resnet18() + optimizer = Adam([{'params': model.layer1.parameters()}, \ + {'params': model.layer2.parameters(), 'lr': 0.002}], lr=0.001) + + # create test data sample + x = torch.randn(1, 3, 224, 224) + + # run fwd and bwd + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + + # create temp directories for checkpoint + model_ckpt_dir = tempfile.TemporaryDirectory() + optimizer_ckpt_dir = tempfile.TemporaryDirectory() + + # save the model and optimizer + ckpt_io = GeneralCheckpointIO() + + ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False) + ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10) + + # create new model + new_model = resnet18() + new_optimizer = Adam([{'params': new_model.layer1.parameters()}, \ + {'params': new_model.layer2.parameters(), 'lr': 0.002}], lr=0.001) + + ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True) + ckpt_io.load_optimizer(new_optimizer, str(optimizer_ckpt_dir.name)) + + # check for model and optimizer state dict recursively + check_state_dict_equal(model.state_dict(), new_model.state_dict()) + check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
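The new tests exercise the round trip through `GeneralCheckpointIO`; the sketch below shows the same flow one level lower, calling the helpers added to `colossalai/checkpoint_io/utils.py` directly. It is not part of this PR: shards are kept in memory instead of being written with `save_state_dict`, and the toy model and file name are made up, purely so the data flow stays visible and self-contained.

```python
# A minimal sketch of the flow that save_sharded_optimizer / load_sharded_optimizer
# wire together, using the new helpers from colossalai.checkpoint_io.utils.
import torch
from torch.optim import Adam

from colossalai.checkpoint_io.utils import (
    load_param_groups_into_optimizer,
    load_states_into_optimizer,
    save_param_groups,
    shard_optimizer_checkpoint,
    sharded_optimizer_loading_epilogue,
)

# Build a small optimizer that actually has state to shard.
model = torch.nn.Linear(8, 8)
optimizer = Adam(model.parameters(), lr=1e-3)
model(torch.randn(4, 8)).sum().backward()
optimizer.step()

state_dict = optimizer.state_dict()

# Saving side: persist param_groups (written to the current directory here),
# then break the optimizer states into shards.
save_param_groups(state_dict, "pytorch_optim_group.bin")
shards = [shard for shard, _size in shard_optimizer_checkpoint(state_dict, max_shard_size=1024)]

# Loading side: restore into a freshly constructed optimizer.
new_model = torch.nn.Linear(8, 8)
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
id_map = load_param_groups_into_optimizer(new_optimizer, "pytorch_optim_group.bin")
for shard in shards:
    load_states_into_optimizer(new_optimizer, shard, id_map)
sharded_optimizer_loading_epilogue(new_optimizer)
```

In the real implementation each shard is additionally written to its own `pytorch_optim-000XX.bin` file and registered in the index file, which is what the sharded-optimizer tests above rely on.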