Mirror of https://github.com/hpcaitech/ColossalAI.git
[booster] implemented the torch ddp + resnet example (#3232)

* [booster] implemented the torch ddp + resnet example
* polish code
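This commit turns `CheckpointIO` into a template-method base class: the public `load_model` / `save_model` / `load_optimizer` / `save_optimizer` entry points now unwrap `ModelWrapper` instances and dispatch between sharded and unsharded checkpoints, while subclasses supply the low-level primitives. A rough usage sketch based only on the docstrings in this diff, assuming a concrete subclass such as the hypothetical `NaiveCheckpointIO` sketched after the diff:

    import torch.nn as nn
    from torch.optim import SGD

    model = nn.Linear(8, 2)
    optimizer = SGD(model.parameters(), lr=0.1)

    # hypothetical concrete subclass; see the NaiveCheckpointIO sketch after the diff
    checkpoint_io = NaiveCheckpointIO()

    # unsharded: single-file checkpoints
    checkpoint_io.save_model(model, 'model.pt')
    checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')

    # sharded: a directory plus a model.index.json index file, 1024 MB per shard by default
    checkpoint_io.save_model(model, './checkpoints/', shard=True, size_per_shard=1024)

    # loading detects sharded vs. unsharded from the checkpoint path
    model = checkpoint_io.load_model(model, 'model.pt', strict=True)
    checkpoint_io.load_optimizer(optimizer, 'optimizer.pt')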
@@ -1,13 +1,15 @@
 import json
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any
+from typing import Any, Union

 import torch
 import torch.nn as nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

+from colossalai.interface import ModelWrapper
+
 __all__ = ['CheckpointIO', 'ShardCheckpointIndexFile']
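The newly imported `colossalai.interface.ModelWrapper` is what lets the public methods below accept either a bare `nn.Module` or a wrapped model (e.g. a DDP-wrapped model from the booster) and unwrap it before delegating. A simplified stand-in showing the one-method contract this diff relies on (not the real class):

    import torch.nn as nn

    class ModelWrapper(nn.Module):  # simplified stand-in, not the real colossalai.interface.ModelWrapper
        def __init__(self, module: nn.Module):
            super().__init__()
            self.module = module

        def unwrap(self) -> nn.Module:
            # return the inner module whose state_dict is checkpointed
            return self.module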
@@ -37,15 +39,15 @@ class CheckpointIO(ABC):
         >>>
         >>> # save optimizer to checkpoint
         >>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')

     """

     # ======================================
-    # Abstract methods for implementation
+    # Public methods
     # ======================================

-    @abstractmethod
-    def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+    def load_model(self,
+                   model: Union[nn.Module, ModelWrapper],
+                   checkpoint: str,
+                   strict: bool = True) -> Union[nn.Module, ModelWrapper]:
         """
         Load model from checkpoint.
@@ -59,14 +61,26 @@ class CheckpointIO(ABC):
             strict (bool): whether to strictly enforce that the parameter names in
                 the checkpoint match the keys returned by this module's state dict.
         """
-        pass
+        ckpt_path = Path(checkpoint)
+        is_sharded = self.is_sharded_checkpoint(ckpt_path)
+
+        origin_model = model
+
+        if isinstance(model, ModelWrapper):
+            model = model.unwrap()
+
+        if is_sharded:
+            self.load_sharded_model(model, ckpt_path, strict)
+        else:
+            self.load_unsharded_model(model, ckpt_path, strict)
+
+        return origin_model

-    @abstractmethod
     def save_model(self,
-                   model: nn.Module,
+                   model: Union[nn.Module, ModelWrapper],
                    checkpoint: str,
-                   prefix: str = None,
                    shard: bool = False,
+                   prefix: str = None,
                    size_per_shard: int = 1024):
         """
         Save model to checkpoint.
@@ -83,17 +97,24 @@ class CheckpointIO(ABC):
         Args:
             model (nn.Module): model to be saved.
-            checkpoint: checkpoint path. The checkpoint path can be:
+            checkpoint (str): checkpoint path. The checkpoint path can be:
                 1. a file path, e.g. 'model.pt'
                 2. a directory path to save the sharded checkpoint, e.g. './checkpoints/' when shard = True.
-            shard: whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
+            shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
                 multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
                 that the checkpoint path is a directory path instead of a file path.
-            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
+            prefix (str): prefix for the model checkpoint file name when shard = True. Default: None.
+            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
         """
-        pass
-
-    @abstractmethod
+        if isinstance(model, ModelWrapper):
+            model = model.unwrap()
+
+        if shard:
+            self.save_sharded_model(model, checkpoint, prefix, size_per_shard)
+        else:
+            self.save_unsharded_model(model, checkpoint)
+
     def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
         """
         Load optimizer from checkpoint.
@@ -102,19 +123,139 @@ class CheckpointIO(ABC):
             optimizer (Optimizer): optimizer to be loaded.
             checkpoint (str): checkpoint path. The path format is kept compatible with that of the model checkpoints.
         """
-        pass
+        ckpt_path = Path(checkpoint)
+        is_sharded = self.is_sharded_checkpoint(ckpt_path)
+
+        if is_sharded:
+            self.load_sharded_optimizer(optimizer, ckpt_path)
+        else:
+            self.load_unsharded_optimizer(optimizer, ckpt_path)

-    @abstractmethod
-    def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
+    def save_optimizer(self,
+                       optimizer: Optimizer,
+                       checkpoint: str,
+                       shard: bool = False,
+                       prefix: str = None,
+                       size_per_shard: int = 1024):
         """
         Save optimizer to checkpoint.

         Args:
             optimizer (Optimizer): optimizer to be saved.
-            checkpoint: checkpoint path. The checkpoint path can be:
+            checkpoint (str): checkpoint path. The checkpoint path can be:
                 1. a file path, e.g. 'optimizer.pt'
                 2. a path to a json file which defines the index to the sharded checkpoint for the optimizer
                 3. a path to a folder containing a unique .index.json file for the sharded checkpoint
+            shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
+                multiple files. The optimizer shards will be specified by an `optimizer.index.json` file.
+            prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
+            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
         """
+        if shard:
+            self.save_sharded_optimizer(optimizer, checkpoint, prefix, size_per_shard)
+        else:
+            self.save_unsharded_optimizer(optimizer, checkpoint)
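Both public load methods route through `self.is_sharded_checkpoint(ckpt_path)`, whose body is not part of this hunk. A plausible sketch, written as it might appear inside the class and assuming the `.index.json` convention described in the docstrings above (hypothetical, not the committed implementation):

    from pathlib import Path

    def is_sharded_checkpoint(self, checkpoint: Path) -> bool:
        # a sharded checkpoint is an .index.json file, or a directory
        # containing exactly one such index file
        if checkpoint.is_file():
            return checkpoint.name.endswith('.index.json')
        if checkpoint.is_dir():
            return len(list(checkpoint.glob('*.index.json'))) == 1
        return False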
+    # ========================================================
+    # Abstract methods for model loading/saving implementation
+    # ========================================================
+
+    @abstractmethod
+    def load_sharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
+        """
+        Load model from sharded checkpoint.
+
+        Args:
+            model (nn.Module): model to be loaded.
+            checkpoint (Path): checkpoint path. It should be a path to the .index.json file or to a directory which contains a .index.json file.
+        """
+        pass
+
+    @abstractmethod
+    def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
+        """
+        Load model from unsharded checkpoint.
+
+        Args:
+            model (nn.Module): model to be loaded.
+            checkpoint (Path): checkpoint path. It should be a single file path pointing to a model weight binary.
+            strict (bool): whether to strictly enforce that the parameter names in
+                the checkpoint match the keys returned by this module's state dict.
+        """
+        pass
+
+    @abstractmethod
+    def save_sharded_model(self, model: nn.Module, checkpoint: Path, prefix: str, size_per_shard: int):
+        """
+        Save model to sharded checkpoint.
+
+        Args:
+            model (nn.Module): model to be saved.
+            checkpoint (Path): checkpoint path. It should be a directory path.
+            prefix (str): prefix for the model checkpoint.
+            size_per_shard (int): size per shard in MB.
+        """
+        pass
+
+    @abstractmethod
+    def save_unsharded_model(self, model: nn.Module, checkpoint: Path):
+        """
+        Save model to unsharded checkpoint.
+
+        Args:
+            model (nn.Module): model to be saved.
+            checkpoint (Path): checkpoint path. It should be a single file path pointing to a model weight binary.
+        """
+        pass
+
+    # ============================================================
+    # Abstract methods for optimizer loading/saving implementation
+    # ============================================================
+
+    @abstractmethod
+    def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
+        """
+        Load optimizer from sharded checkpoint.
+
+        Args:
+            optimizer (Optimizer): optimizer to be loaded.
+            checkpoint (Path): checkpoint path. It should be a path to the .index.json file or to a directory which contains a .index.json file.
+            prefix (str): prefix for the optimizer checkpoint.
+            size_per_shard (int): size per shard in MB.
+        """
+        pass
+
+    @abstractmethod
+    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+        """
+        Load optimizer from unsharded checkpoint.
+
+        Args:
+            optimizer (Optimizer): optimizer to be loaded.
+            checkpoint (Path): checkpoint path. It should be a single file path pointing to an optimizer state binary.
+        """
+        pass
+
+    @abstractmethod
+    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
+        """
+        Save optimizer to sharded checkpoint.
+
+        Args:
+            optimizer (Optimizer): optimizer to be saved.
+            checkpoint (Path): checkpoint path. It should be a directory path.
+            prefix (str): prefix for the optimizer checkpoint.
+            size_per_shard (int): size per shard in MB.
+        """
+        pass
+
+    @abstractmethod
+    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+        """
+        Save optimizer to unsharded checkpoint.
+
+        Args:
+            optimizer (Optimizer): optimizer to be saved.
+            checkpoint (Path): checkpoint path. It should be a single file path pointing to an optimizer state binary.
+        """
+        pass
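With the dispatch logic hoisted into the base class, a concrete subclass only fills in the eight primitives above. A minimal single-process sketch using plain `torch.save`/`torch.load` (hypothetical; the sharded variants are stubbed out, and `is_sharded_checkpoint` follows the sketch given earlier since its body is not in this diff):

    from pathlib import Path

    import torch
    import torch.nn as nn
    from torch.optim import Optimizer

    class NaiveCheckpointIO(CheckpointIO):

        def is_sharded_checkpoint(self, checkpoint: Path) -> bool:
            # assumed helper (see the earlier sketch); *.index.json marks a sharded checkpoint
            return checkpoint.is_dir() or checkpoint.name.endswith('.index.json')

        def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
            model.load_state_dict(torch.load(checkpoint), strict=strict)

        def save_unsharded_model(self, model: nn.Module, checkpoint: Path):
            torch.save(model.state_dict(), checkpoint)

        def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
            optimizer.load_state_dict(torch.load(checkpoint))

        def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
            torch.save(optimizer.state_dict(), checkpoint)

        # sharded primitives stubbed out for brevity
        def load_sharded_model(self, model, checkpoint, strict):
            raise NotImplementedError

        def save_sharded_model(self, model, checkpoint, prefix, size_per_shard):
            raise NotImplementedError

        def load_sharded_optimizer(self, optimizer, checkpoint, prefix, size_per_shard):
            raise NotImplementedError

        def save_sharded_optimizer(self, optimizer, checkpoint, prefix, size_per_shard):
            raise NotImplementedError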