[checkpointio] sharded optimizer checkpoint for DDP plugin (#4002)

Author: Baizhou Zhang, 2023-06-16 14:14:05 +08:00 (committed by GitHub)
parent 725af3eeeb
commit 822c3d4d66
6 changed files with 79 additions and 34 deletions

View File

@@ -9,6 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader

 from colossalai.checkpoint_io import GeneralCheckpointIO
+from colossalai.interface import ModelWrapper

 from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
@@ -165,11 +166,11 @@ class Booster:
         assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
         return self.plugin.no_sync(model)

-    def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+    def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
         """Load model from checkpoint.

         Args:
-            model (nn.Module): A model boosted by Booster.
+            model (nn.Module or ModelWrapper): A model boosted by Booster.
             checkpoint (str): Path to the checkpoint. It must be a local path.
                 It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
             strict (bool, optional): whether to strictly enforce that the keys
@@ -179,24 +180,34 @@ class Booster:
         self.checkpoint_io.load_model(model, checkpoint, strict)

     def save_model(self,
-                   model: nn.Module,
+                   model: Union[nn.Module, ModelWrapper],
                    checkpoint: str,
-                   prefix: str = None,
                    shard: bool = False,
-                   size_per_shard: int = 1024):
+                   gather_dtensor: bool = True,
+                   prefix: Optional[str] = None,
+                   size_per_shard: int = 1024,
+                   use_safetensors: bool = False):
         """Save model to checkpoint.

         Args:
-            model (nn.Module): A model boosted by Booster.
+            model (nn.Module or ModelWrapper): A model boosted by Booster.
             checkpoint (str): Path to the checkpoint. It must be a local path.
                 It is a file path if ``shard=False``. Otherwise, it is a directory path.
-            prefix (str, optional): A prefix added to parameter and buffer
-                names to compose the keys in state_dict. Defaults to None.
             shard (bool, optional): Whether to save checkpoint a sharded way.
                 If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
+            gather_dtensor (bool, optional): whether to gather the distributed tensor to the first device. Default: True.
+            prefix (str, optional): A prefix added to parameter and buffer
+                names to compose the keys in state_dict. Defaults to None.
             size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
+            use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved.
         """
-        self.checkpoint_io.save_model(model, checkpoint=checkpoint, shard=shard, size_per_shard=size_per_shard)
+        self.checkpoint_io.save_model(model,
+                                      checkpoint=checkpoint,
+                                      shard=shard,
+                                      gather_dtensor=gather_dtensor,
+                                      prefix=prefix,
+                                      size_per_shard=size_per_shard,
+                                      use_safetensors=use_safetensors)

     def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
         """Load optimizer from checkpoint.
@@ -205,12 +216,21 @@ class Booster:
             optimizer (Optimizer): An optimizer boosted by Booster.
             checkpoint (str): Path to the checkpoint. It must be a local path.
                 It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
+            prefix (str, optional): A prefix added to parameter and buffer
+                names to compose the keys in state_dict. Defaults to None.
+            size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
         """
         self.checkpoint_io.load_optimizer(optimizer, checkpoint)

-    def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
-        """Save optimizer to checkpoint.
-        Warning: Saving sharded optimizer checkpoint is not supported yet.
+    def save_optimizer(self,
+                       optimizer: Optimizer,
+                       checkpoint: str,
+                       shard: bool = False,
+                       gather_dtensor: bool = True,
+                       prefix: Optional[str] = None,
+                       size_per_shard: int = 1024):
+        """
+        Save optimizer to checkpoint.

         Args:
             optimizer (Optimizer): An optimizer boosted by Booster.
@@ -218,9 +238,12 @@
                 It is a file path if ``shard=False``. Otherwise, it is a directory path.
             shard (bool, optional): Whether to save checkpoint a sharded way.
                 If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
+            gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
+            prefix (str, optional): A prefix added to parameter and buffer
+                names to compose the keys in state_dict. Defaults to None.
             size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
         """
-        self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard)
+        self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)

     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """Save lr scheduler to checkpoint.

View File

@@ -52,7 +52,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
     def save_sharded_model(self,
                            model: nn.Module,
                            checkpoint_path: str,
-                           gather_dtensor: bool = False,
+                           gather_dtensor: bool = True,
                            prefix: Optional[str] = None,
                            max_shard_size: int = 1024,
                            use_safetensors: bool = False):
@@ -62,8 +62,12 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)

-    def save_sharded_optimier(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
-                              size_per_shard: int):
+    def save_sharded_optimizer(self,
+                               optimizer: Optimizer,
+                               checkpoint: str,
+                               gather_dtensor: bool = True,
+                               prefix: Optional[str] = None,
+                               size_per_shard: int = 1024):
         """
         Save optimizer to checkpoint but only on master process.
         """

View File

@@ -148,6 +148,9 @@ class CheckpointIO(ABC):
         Args:
             optimizer (Optimizer): optimizer to be loaded.
             checkpoint (str): checkpoint path. This value is made compatibility with the model checkpoints in the
+            prefix (str, optional): A prefix added to parameter and buffer
+                names to compose the keys in state_dict. Defaults to None.
+            size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
         """
         index_file_exists, index_file_path = has_index_file(checkpoint)
@@ -157,7 +160,7 @@ class CheckpointIO(ABC):
         if index_file_exists:
             # the existence of index file means it is a sharded checkpoint
-            self.load_sharded_optimizer(optimizer, index_file_path, prefix, size_per_shard)
+            self.load_sharded_optimizer(optimizer, index_file_path, prefix)
         else:
             self.load_unsharded_optimizer(optimizer, checkpoint)
@@ -251,7 +254,7 @@ class CheckpointIO(ABC):
     # ========================================================

     @abstractmethod
-    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
+    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
         """
         Load optimizer from sharded checkpoint.
@@ -259,7 +262,6 @@ class CheckpointIO(ABC):
             optimizer (Optimizer): optimizer to be loaded.
             index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
             prefix (str): prefix for the optimizer checkpoint.
-            size_per_shard (int): size per shard in MB.
         """
         pass
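
For context, the dispatch in load_optimizer above means a caller can point the same entry point at either checkpoint layout; a small hedged sketch with hypothetical placeholder paths:

import torch.nn as nn
from torch.optim import SGD
from colossalai.checkpoint_io import GeneralCheckpointIO

model = nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.1)
ckpt_io = GeneralCheckpointIO()

# Sharded checkpoint: a directory containing a *.index.json file plus shard files.
ckpt_io.load_optimizer(optimizer, "./ckpt/optimizer")      # hypothetical directory
# Unsharded checkpoint: a single state-dict file.
ckpt_io.load_optimizer(optimizer, "./ckpt/optimizer.bin")  # hypothetical file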

View File

@@ -8,6 +8,8 @@ from typing import Iterator, Optional, OrderedDict, Tuple
 import torch.nn as nn
 from torch.optim import Optimizer

+from colossalai.interface import OptimizerWrapper
+
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
@@ -50,11 +52,15 @@ class GeneralCheckpointIO(CheckpointIO):
         # save the checkpoint
         save_state_dict(state_dict, checkpoint, use_safetensors)

-    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
+    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
         """
         Load sharded optimizer with the given path to index file.
         """
-        optimizer.load_state_dict
+        # If optimizer is wrapped, unwrap it.
+        if isinstance(optimizer, OptimizerWrapper):
+            optimizer = optimizer.optim

         # Read checkpoint index file.
         ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)

View File

@@ -139,6 +139,12 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
         state_size = 0
         isDTensor = False
         for state_tensor in state.values():
+
+            # When state_tensor is None (e.g., a SGD optimizer with momentum set to 0),
+            # The calculation of tensor size should be skipped to avoid error.
+            if state_tensor is None:
+                continue
+
             # If the states are stored as DTensors, mark isDTensor as true.
             if type(state_tensor) == DTensor:
                 isDTensor = True
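
To see why a None entry can appear, the snippet below reproduces the case mentioned in the new comment: with plain SGD (momentum set to 0), some PyTorch versions keep momentum_buffer as None in the optimizer state, so computing its size would fail. The snippet is illustrative only:

import torch
import torch.nn as nn
from torch.optim import SGD

model = nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.1, momentum=0.0)

model(torch.randn(8, 4)).sum().backward()
optimizer.step()

for state in optimizer.state.values():
    # Depending on the PyTorch version, this prints {'momentum_buffer': None},
    # which is exactly the None state tensor that shard_optimizer_checkpoint now skips.
    print(state)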
@@ -271,7 +277,7 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
     return id_map

-def load_states_into_optimizer(optimzier: Optimizer, state_dict: dict, id_map: dict):
+def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict):
     r"""Copies states from `state_dict` into an Optimizer object.

     Args:
@@ -311,10 +317,16 @@ def load_states_into_optimizer(optimzier: Optimizer, state_dict: dict, id_map: dict):
         else:
             new_states[k] = v

-    optimzier.state.update(new_states)
+    optimizer.state.update(new_states)

 def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
+    r"""Do the cleaning up work after state_dict has been loaded into optimizer
+
+    Args:
+        optimizer(Optimizer): An optimizer object whose state has just been loaded.
+    """
     # Do the cleaning up as in src code of Pytorch.
     optimizer._hook_for_profile()    # To support multiprocessing pickle/unpickle.
     optimizer.defaults.setdefault('differentiable', False)
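
Taken together, the helpers in this file support a loading flow along the following lines. This is a minimal sketch, assuming each shard can be read with torch.load (the library may use its own shard-reading helper) and using hypothetical paths:

import torch
from torch.optim import Optimizer
from colossalai.checkpoint_io.utils import (
    load_param_groups_into_optimizer,
    load_states_into_optimizer,
    sharded_optimizer_loading_epilogue,
)

def load_sharded_optimizer_sketch(optimizer: Optimizer, param_group_path: str, shard_paths: list):
    # Restore param groups first; id_map maps saved parameter ids to live parameters.
    id_map = load_param_groups_into_optimizer(optimizer, param_group_path)

    # Load each shard's states into the optimizer.
    for shard_path in shard_paths:
        state_dict = torch.load(shard_path)    # assumption: shards are plain torch-serialized dicts
        load_states_into_optimizer(optimizer, state_dict, id_map)

    # Final bookkeeping, mirroring what PyTorch does at the end of load_state_dict().
    sharded_optimizer_loading_epilogue(optimizer)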

View File

@@ -13,7 +13,8 @@ from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use

 @parameterize('shard', [True, False])
-def check_torch_ddp_checkpointIO(shard: bool):
+@parameterize('size_per_shard', [16, 128])
+def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
     plugin = TorchDDPPlugin()
     booster = Booster(plugin=plugin)
     model = resnet18()
@@ -38,10 +39,8 @@ def check_torch_ddp_checkpointIO(shard: bool):
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
         lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
-        booster.save_model(model, model_ckpt_path, shard=shard)
-        if not shard:
-            # TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
-            booster.save_optimizer(optimizer, optimizer_ckpt_path)
+        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
         booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
         dist.barrier()
@@ -55,7 +54,6 @@ def check_torch_ddp_checkpointIO(shard: bool):
         booster.load_model(new_model, model_ckpt_path)
         check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)

-        if not shard:
-            booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-            check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
+        booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+        check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
         booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)