[checkpointio] sharded optimizer checkpoint for DDP plugin (#4002)

commit 822c3d4d66 (parent 725af3eeeb)
Baizhou Zhang authored 2023-06-16 14:14:05 +08:00; committed by GitHub
6 changed files with 79 additions and 34 deletions
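For context, this change extends the DDP plugin so optimizer checkpoints can be saved and loaded in shards through the Booster API. A minimal usage sketch follows, assuming a distributed environment initialized via colossalai.launch_from_torch; the keyword names shard and size_per_shard reflect the Booster checkpoint API of this period and should be treated as assumptions, not something shown in this diff.

import colossalai
import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# Assumption: launched under torchrun so that colossalai.launch_from_torch
# can read rank/world size from the environment.
colossalai.launch_from_torch(config={})

plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)

model = nn.Linear(16, 16)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# Save optimizer state as multiple shard files plus an index file.
booster.save_optimizer(optimizer, "ckpt/optimizer", shard=True, size_per_shard=1024)

# Reload through the same plugin: the index file drives shard-by-shard loading.
booster.load_optimizer(optimizer, "ckpt/optimizer")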

colossalai/checkpoint_io/general_checkpoint_io.py

@@ -8,6 +8,8 @@ from typing import Iterator, Optional, OrderedDict, Tuple
 import torch.nn as nn
 from torch.optim import Optimizer
 
+from colossalai.interface import OptimizerWrapper
+
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
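The new OptimizerWrapper import is the key to the second hunk below: the Booster path hands checkpoint IO a wrapped optimizer rather than a bare torch.optim.Optimizer, so the loader must reach the inner object before touching its state dict. Here is a toy sketch of that delegation pattern, with ToyOptimizerWrapper and unwrap as illustrative stand-ins rather than ColossalAI API:

import torch
from torch.optim import Optimizer

class ToyOptimizerWrapper:
    # Mirrors the shape of colossalai.interface.OptimizerWrapper:
    # it delegates to an inner torch Optimizer stored on .optim.
    def __init__(self, optim: Optimizer):
        self.optim = optim

    def step(self) -> None:
        self.optim.step()

def unwrap(optimizer):
    # Checkpoint code wants the inner Optimizer so that
    # load_state_dict() sees torch's native state layout.
    return optimizer.optim if isinstance(optimizer, ToyOptimizerWrapper) else optimizer

params = [torch.nn.Parameter(torch.zeros(4))]
wrapped = ToyOptimizerWrapper(torch.optim.SGD(params, lr=0.1))
assert isinstance(unwrap(wrapped), Optimizer)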
@@ -50,11 +52,15 @@ class GeneralCheckpointIO(CheckpointIO):
 
         # save the checkpoint
         save_state_dict(state_dict, checkpoint, use_safetensors)
 
-    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
+    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
         """
         Load sharded optimizer with the given path to index file.
         """
-        optimizer.load_state_dict
+        # If optimizer is wrapped, unwrap it.
+        if isinstance(optimizer, OptimizerWrapper):
+            optimizer = optimizer.optim
+
         # Read checkpoint index file.
         ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
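Past the point where this excerpt cuts off, the method walks the index file and loads every shard it references. The sketch below shows that flow against an assumed JSON index layout; the weight_map/param_groups schema and the helper name load_sharded_optimizer_sketch are illustrative assumptions, since the real implementation goes through CheckpointIndexFile and ColossalAI's loading utilities.

import json
import os
import torch
from torch.optim import Optimizer

def load_sharded_optimizer_sketch(optimizer: Optimizer, index_file_path: str) -> None:
    # Assumed index format: {"weight_map": {"<param_id>": "<shard file>", ...},
    #                        "param_groups": [...]}.
    with open(index_file_path) as f:
        index = json.load(f)

    ckpt_dir = os.path.dirname(index_file_path)
    state: dict = {}
    # Load each distinct shard file once and merge its per-parameter states.
    for shard_name in sorted(set(index["weight_map"].values())):
        shard = torch.load(os.path.join(ckpt_dir, shard_name), map_location="cpu")
        state.update(shard)

    # Rebuild the dict that torch.optim.Optimizer.load_state_dict expects.
    optimizer.load_state_dict({"state": state, "param_groups": index["param_groups"]})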