[checkpointio] General Checkpointing of Sharded Optimizers (#3984)

This commit is contained in:
Baizhou Zhang
2023-06-15 15:21:26 +08:00
committed by GitHub
parent 8bcad73677
commit c9cff7e7fa
8 changed files with 399 additions and 38 deletions

View File

@@ -12,7 +12,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import get_base_filenames, get_shard_filename, save_state_dict
from colossalai.checkpoint_io.utils import get_model_base_filenames, get_shard_filename, save_state_dict
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
@@ -76,14 +76,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
model: GeminiDDP,
checkpoint_path: str,
gather_dtensor: bool = False,
variant: Optional[str] = None,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False):
"""
Save sharded model
"""
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
total_size = 0
index_file = CheckpointIndexFile(checkpoint_path)
for idx, shard_pair in enumerate(state_dict_shard):

View File

@@ -32,7 +32,6 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
"""
Save model to checkpoint but only on master process.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
if self.coordinator.is_master():
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
@@ -54,11 +53,22 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
model: nn.Module,
checkpoint_path: str,
gather_dtensor: bool = False,
variant: Optional[str] = None,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False):
"""
Save model to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_sharded_model(model, checkpoint_path, gather_dtensor, variant, max_shard_size, use_safetensors)
super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
size_per_shard: int):
"""
Save optimizer to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
class TorchDDPModel(ModelWrapper):

View File

@@ -1,9 +1,9 @@
import warnings
from pathlib import Path
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import warnings
from packaging import version
from torch.distributed import ProcessGroup
@@ -69,7 +69,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
size_per_shard: int, use_safetensors: bool):
"""
Save model to checkpoint but only on master process.
@@ -87,13 +87,14 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
"""
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
size_per_shard: int):
"""
Save optimizer to checkpoint but only on master process.
"""
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
"""
Load optimizer from checkpoint but only on master process.
"""