Mirror of https://github.com/hpcaitech/ColossalAI.git

[checkpointio] sharded optimizer checkpoint for DDP plugin (#4002)

commit 822c3d4d66
parent 725af3eeeb
@@ -9,6 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
 from colossalai.checkpoint_io import GeneralCheckpointIO
+from colossalai.interface import ModelWrapper
 
 from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
@@ -165,11 +166,11 @@ class Booster:
         assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
         return self.plugin.no_sync(model)
 
-    def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+    def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
         """Load model from checkpoint.
 
         Args:
-            model (nn.Module): A model boosted by Booster.
+            model (nn.Module or ModelWrapper): A model boosted by Booster.
             checkpoint (str): Path to the checkpoint. It must be a local path.
                 It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
             strict (bool, optional): whether to strictly enforce that the keys
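Note: with this hunk, Booster.load_model accepts either a plain nn.Module or the ModelWrapper returned by booster.boost(...). A minimal usage sketch, assuming the distributed environment has already been initialized (e.g. via colossalai.launch); the toy model, optimizer, and checkpoint path are illustrative, not taken from the diff:

    import torch.nn as nn
    from torch.optim import SGD

    from colossalai.booster import Booster
    from colossalai.booster.plugin import TorchDDPPlugin

    model = nn.Linear(8, 2)                      # hypothetical toy model
    optimizer = SGD(model.parameters(), lr=0.1)

    booster = Booster(plugin=TorchDDPPlugin())
    model, optimizer, *_ = booster.boost(model, optimizer)

    # `model` is now a ModelWrapper; load_model takes it directly, so there is
    # no need to unwrap it before restoring a checkpoint.
    booster.load_model(model, "./model_ckpt.pt")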
@@ -179,24 +180,34 @@ class Booster:
         self.checkpoint_io.load_model(model, checkpoint, strict)
 
     def save_model(self,
-                   model: nn.Module,
+                   model: Union[nn.Module, ModelWrapper],
                    checkpoint: str,
-                   prefix: str = None,
                    shard: bool = False,
-                   size_per_shard: int = 1024):
+                   gather_dtensor: bool = True,
+                   prefix: Optional[str] = None,
+                   size_per_shard: int = 1024,
+                   use_safetensors: bool = False):
         """Save model to checkpoint.
 
         Args:
-            model (nn.Module): A model boosted by Booster.
+            model (nn.Module or ModelWrapper): A model boosted by Booster.
             checkpoint (str): Path to the checkpoint. It must be a local path.
                 It is a file path if ``shard=False``. Otherwise, it is a directory path.
-            prefix (str, optional): A prefix added to parameter and buffer
-                names to compose the keys in state_dict. Defaults to None.
             shard (bool, optional): Whether to save checkpoint a sharded way.
                 If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
+            gather_dtensor (bool, optional): whether to gather the distributed tensor to the first device. Default: True.
+            prefix (str, optional): A prefix added to parameter and buffer
+                names to compose the keys in state_dict. Defaults to None.
             size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
+            use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved.
         """
-        self.checkpoint_io.save_model(model, checkpoint=checkpoint, shard=shard, size_per_shard=size_per_shard)
+        self.checkpoint_io.save_model(model,
+                                      checkpoint=checkpoint,
+                                      shard=shard,
+                                      gather_dtensor=gather_dtensor,
+                                      prefix=prefix,
+                                      size_per_shard=size_per_shard,
+                                      use_safetensors=use_safetensors)
 
     def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
         """Load optimizer from checkpoint.
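Note: the reworked save_model simply forwards the new keyword arguments to the underlying checkpoint_io. A hedged sketch of calling the new signature, continuing the example from the note above (paths and shard size are illustrative):

    # Sharded save: the checkpoint becomes a directory of shard files plus an index file.
    booster.save_model(model,
                       checkpoint="./model_ckpt",
                       shard=True,
                       gather_dtensor=True,
                       prefix=None,
                       size_per_shard=512,
                       use_safetensors=False)

    # Unsharded save: a single file at the given path.
    booster.save_model(model, checkpoint="./model_ckpt.pt", shard=False)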
@@ -205,12 +216,21 @@ class Booster:
             optimizer (Optimizer): An optimizer boosted by Booster.
             checkpoint (str): Path to the checkpoint. It must be a local path.
                 It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
+            prefix (str, optional): A prefix added to parameter and buffer
+                names to compose the keys in state_dict. Defaults to None.
+            size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
         """
         self.checkpoint_io.load_optimizer(optimizer, checkpoint)
 
-    def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
-        """Save optimizer to checkpoint.
-        Warning: Saving sharded optimizer checkpoint is not supported yet.
+    def save_optimizer(self,
+                       optimizer: Optimizer,
+                       checkpoint: str,
+                       shard: bool = False,
+                       gather_dtensor: bool = True,
+                       prefix: Optional[str] = None,
+                       size_per_shard: int = 1024):
+        """
+        Save optimizer to checkpoint.
 
         Args:
             optimizer (Optimizer): An optimizer boosted by Booster.
@@ -218,9 +238,12 @@ class Booster:
                 It is a file path if ``shard=False``. Otherwise, it is a directory path.
             shard (bool, optional): Whether to save checkpoint a sharded way.
                 If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
+            gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
+            prefix (str, optional): A prefix added to parameter and buffer
+                names to compose the keys in state_dict. Defaults to None.
             size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
         """
-        self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard)
+        self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
 
     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """Save lr scheduler to checkpoint.
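Note: sharded optimizer checkpointing is the point of this PR, so the old "not supported yet" warning is gone and save_optimizer now mirrors save_model. A sketch of the round trip, continuing the example above (the directory name is illustrative):

    # Save the optimizer state as a sharded checkpoint directory.
    booster.save_optimizer(optimizer,
                           checkpoint="./optim_ckpt",
                           shard=True,
                           gather_dtensor=True,
                           prefix=None,
                           size_per_shard=512)

    # Restore it later: load_optimizer detects the index file in the directory
    # and dispatches to load_sharded_optimizer (see the CheckpointIO hunks below).
    booster.load_optimizer(optimizer, "./optim_ckpt")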
@@ -52,7 +52,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
     def save_sharded_model(self,
                            model: nn.Module,
                            checkpoint_path: str,
-                           gather_dtensor: bool = False,
+                           gather_dtensor: bool = True,
                            prefix: Optional[str] = None,
                            max_shard_size: int = 1024,
                            use_safetensors: bool = False):
@@ -62,8 +62,12 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)
 
-    def save_sharded_optimier(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
-                              size_per_shard: int):
+    def save_sharded_optimizer(self,
+                               optimizer: Optimizer,
+                               checkpoint: str,
+                               gather_dtensor: bool = True,
+                               prefix: Optional[str] = None,
+                               size_per_shard: int = 1024):
         """
         Save optimizer to checkpoint but only on master process.
         """
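Note: the diff only shows the new save_sharded_optimizer signature and its docstring ("only on master process"); the body is not reproduced here. The general pattern it relies on, sketched below with plain torch.distributed and not taken from the PR, is that DDP replicates the optimizer state on every rank, so a single writer is enough:

    import torch.distributed as dist

    def save_on_rank_zero_only(save_fn, *args, **kwargs):
        # Illustrative helper, not part of the PR: every DDP rank holds an
        # identical replica of the optimizer state, so rank 0 alone writes the
        # files, and a barrier keeps the other ranks from reading them too early.
        if dist.is_available() and dist.is_initialized():
            if dist.get_rank() == 0:
                save_fn(*args, **kwargs)
            dist.barrier()
        else:
            save_fn(*args, **kwargs)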
@@ -148,6 +148,9 @@ class CheckpointIO(ABC):
         Args:
             optimizer (Optimizer): optimizer to be loaded.
             checkpoint (str): checkpoint path. This value is made compatibility with the model checkpoints in the
+            prefix (str, optional): A prefix added to parameter and buffer
+                names to compose the keys in state_dict. Defaults to None.
+            size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
         """
         index_file_exists, index_file_path = has_index_file(checkpoint)
 
@@ -157,7 +160,7 @@ class CheckpointIO(ABC):
 
         if index_file_exists:
             # the existence of index file means it is a sharded checkpoint
-            self.load_sharded_optimizer(optimizer, index_file_path, prefix, size_per_shard)
+            self.load_sharded_optimizer(optimizer, index_file_path, prefix)
         else:
             self.load_unsharded_optimizer(optimizer, checkpoint)
 
@@ -251,7 +254,7 @@ class CheckpointIO(ABC):
     # ========================================================
 
     @abstractmethod
-    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
+    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
         """
         Load optimizer from sharded checkpoint.
 
@@ -259,7 +262,6 @@ class CheckpointIO(ABC):
             optimizer (Optimizer): optimizer to be loaded.
             index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
             prefix (str): prefix for the optimizer checkpoint.
-            size_per_shard (int): size per shard in MB.
         """
         pass
 
@@ -8,6 +8,8 @@ from typing import Iterator, Optional, OrderedDict, Tuple
 import torch.nn as nn
 from torch.optim import Optimizer
 
+from colossalai.interface import OptimizerWrapper
+
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
@@ -50,11 +52,15 @@ class GeneralCheckpointIO(CheckpointIO):
         # save the checkpoint
         save_state_dict(state_dict, checkpoint, use_safetensors)
 
-    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
+    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
         """
         Load sharded optimizer with the given path to index file.
         """
-        optimizer.load_state_dict
+
+        # If optimizer is wrapped, unwrap it.
+        if isinstance(optimizer, OptimizerWrapper):
+            optimizer = optimizer.optim
+
         # Read checkpoint index file.
         ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
 
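Note: the unwrap step matters because a boosted optimizer is an OptimizerWrapper, while the states to restore live on the inner torch optimizer. A small self-contained sketch of the same check (the toy Adam optimizer is illustrative):

    import torch.nn as nn
    from torch.optim import Adam

    from colossalai.interface import OptimizerWrapper   # same import the diff adds

    inner = Adam(nn.Linear(4, 4).parameters(), lr=1e-3)
    maybe_wrapped = OptimizerWrapper(inner)

    # Checkpoint loading operates on the raw torch optimizer, so a wrapped
    # optimizer is unwrapped first -- the isinstance check the new code performs.
    optimizer = maybe_wrapped.optim if isinstance(maybe_wrapped, OptimizerWrapper) else maybe_wrapped
    assert optimizer is inner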
@@ -139,6 +139,12 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
         state_size = 0
         isDTensor = False
         for state_tensor in state.values():
+
+            # When state_tensor is None (e.g., a SGD optimizer with momentum set to 0),
+            # The calculation of tensor size should be skipped to avoid error.
+            if state_tensor is None:
+                continue
+
             # If the states are stored as DTensors, mark isDTensor as true.
             if type(state_tensor) == DTensor:
                 isDTensor = True
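Note: the guard exists because an optimizer state entry can be None rather than a tensor; depending on the PyTorch version, SGD with momentum=0 leaves a 'momentum_buffer' entry whose value is None, and calling .numel() or .element_size() on it would raise. A small standalone sketch of the sizing loop with the same skip:

    import torch
    from torch.optim import SGD

    param = torch.nn.Parameter(torch.randn(4))
    opt = SGD([param], lr=0.1, momentum=0.0)
    param.grad = torch.zeros_like(param)
    opt.step()

    total_bytes = 0
    for per_param_state in opt.state_dict()['state'].values():
        for state_tensor in per_param_state.values():
            if state_tensor is None:          # the guard this hunk adds
                continue
            if torch.is_tensor(state_tensor):
                total_bytes += state_tensor.numel() * state_tensor.element_size()
    print(total_bytes)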
@@ -271,7 +277,7 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
     return id_map
 
 
-def load_states_into_optimizer(optimzier: Optimizer, state_dict: dict, id_map: dict):
+def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict):
     r"""Copies states from `state_dict` into an Optimizer object.
 
     Args:
@@ -311,10 +317,16 @@ def load_states_into_optimizer(optimzier: Optimizer, state_dict: dict, id_map: dict):
         else:
             new_states[k] = v
 
-    optimzier.state.update(new_states)
+    optimizer.state.update(new_states)
 
 
 def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
+    r"""Do the cleaning up work after state_dict has been loaded into optimizer
+
+    Args:
+        optimizer(Optimizer): An optimizer object whose state has just been loaded.
+    """
+
     # Do the cleaning up as in src code of Pytorch.
     optimizer._hook_for_profile()    # To support multiprocessing pickle/unpickle.
     optimizer.defaults.setdefault('differentiable', False)
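Note: the helpers above cooperate when a sharded optimizer checkpoint is loaded: load_param_groups_into_optimizer produces an id_map from saved parameter indices to live parameters, load_states_into_optimizer re-keys each shard's states through that map, and sharded_optimizer_loading_epilogue does the final cleanup. A generic, standalone sketch of the re-keying mechanics (not the ColossalAI helper itself, which does additional bookkeeping); the fake saved_states dict is illustrative:

    import torch
    from torch.optim import Adam

    param = torch.nn.Parameter(torch.randn(4))
    optimizer = Adam([param], lr=1e-3)

    # Pretend this came out of a checkpoint shard: states keyed by saved index.
    saved_states = {0: {'step': torch.tensor(1.0),
                        'exp_avg': torch.zeros(4),
                        'exp_avg_sq': torch.zeros(4)}}

    # id_map plays the role of load_param_groups_into_optimizer's result:
    # saved index -> live parameter object.
    id_map = {0: param}

    new_states = {id_map[k]: v for k, v in saved_states.items() if k in id_map}
    optimizer.state.update(new_states)     # the same final step shown in the diff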
@@ -13,7 +13,8 @@ from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use
 
 
 @parameterize('shard', [True, False])
-def check_torch_ddp_checkpointIO(shard: bool):
+@parameterize('size_per_shard', [16, 128])
+def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
     plugin = TorchDDPPlugin()
     booster = Booster(plugin=plugin)
     model = resnet18()
@@ -38,11 +39,9 @@ def check_torch_ddp_checkpointIO(shard: bool):
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
         lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
-        booster.save_model(model, model_ckpt_path, shard=shard)
-        if not shard:
-            # TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
-            booster.save_optimizer(optimizer, optimizer_ckpt_path)
-            booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
+        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+        booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
         dist.barrier()
 
         new_model = resnet18()
@@ -55,11 +54,10 @@ def check_torch_ddp_checkpointIO(shard: bool):
         booster.load_model(new_model, model_ckpt_path)
         check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
 
-        if not shard:
-            booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-            check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
-            booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
-            check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)
+        booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+        check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
+        booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
+        check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)
 
 
 def run_dist(rank, world_size, port):