mirror of https://github.com/hpcaitech/ColossalAI.git
[checkpointio] sharded optimizer checkpoint for DDP plugin (#4002)
@@ -8,6 +8,8 @@ from typing import Iterator, Optional, OrderedDict, Tuple
 import torch.nn as nn
 from torch.optim import Optimizer
 
+from colossalai.interface import OptimizerWrapper
+
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
@@ -50,11 +52,15 @@ class GeneralCheckpointIO(CheckpointIO):
         # save the checkpoint
         save_state_dict(state_dict, checkpoint, use_safetensors)
 
-    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
+    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
         """
         Load sharded optimizer with the given path to index file.
         """
-        optimizer.load_state_dict
+        # If optimizer is wrapped, unwrap it.
+        if isinstance(optimizer, OptimizerWrapper):
+            optimizer = optimizer.optim
+
+        # Read checkpoint index file.
+        ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
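
The hunk ends where the page truncates, just after the index file is read. For orientation, here is a minimal, self-contained sketch (not ColossalAI's actual helpers) of the pattern this method is building toward: read the index that maps optimizer states to shard files, load and merge each shard, then hand the result to Optimizer.load_state_dict. The index layout assumed below (a "weight_map" dict plus a separate "param_groups" file) is illustrative only.

import json
import os
from typing import Dict

import torch
from torch.optim import Optimizer


def load_sharded_optimizer_sketch(optimizer: Optimizer, index_file_path: str) -> None:
    """Hypothetical loader: merge optimizer-state shards listed in a JSON index."""
    ckpt_dir = os.path.dirname(index_file_path)
    with open(index_file_path) as f:
        index = json.load(f)

    # Gather per-parameter states from every shard file named in the index.
    states: Dict[int, dict] = {}
    for shard_name in sorted(set(index["weight_map"].values())):
        shard = torch.load(os.path.join(ckpt_dir, shard_name), map_location="cpu")
        states.update(shard)

    # Param groups are stored once, outside the shards, in this assumed layout.
    param_groups = torch.load(os.path.join(ckpt_dir, index["param_groups"]))
    optimizer.load_state_dict({"state": states, "param_groups": param_groups})

# Example call (hypothetical paths):
# load_sharded_optimizer_sketch(adam, "ckpt_optim/optimizer.index.json")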
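
At the user level, this commit is what lets the DDP plugin round-trip sharded optimizer checkpoints. Below is a hedged usage sketch; the argument names such as shard and size_per_shard follow the CheckpointIO interface this diff touches and may differ across versions. Run it under torchrun so the distributed environment is initialized.

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch(config={})
booster = Booster(plugin=TorchDDPPlugin())

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters())
model, optimizer, *_ = booster.boost(model, optimizer)

# Save the optimizer state as shards plus an index file ...
booster.save_optimizer(optimizer, "ckpt_optim", shard=True, size_per_shard=1024)

# ... and restore it later; the plugin's CheckpointIO resolves the index
# file and dispatches to load_sharded_optimizer.
booster.load_optimizer(optimizer, "ckpt_optim")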