[booster] implemented the torch ddp + resnet example (#3232)

* [booster] implemented the torch ddp + resnet example

* polish code
Frank Lee
2023-03-27 10:24:14 +08:00
committed by GitHub
parent 1a229045af
commit 73d3e4d309
22 changed files with 608 additions and 128 deletions


@@ -1,13 +1,15 @@
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Union
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from colossalai.interface import ModelWrapper
__all__ = ['CheckpointIO', 'ShardCheckpointIndexFile']
@@ -37,15 +39,15 @@ class CheckpointIO(ABC):
>>>
>>> # save optimizer to checkpoint
>>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
"""
# ======================================
# Public methods
# ======================================
def load_model(self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
strict: bool = True) -> Union[nn.Module, ModelWrapper]:
"""
Load model from checkpoint.
@@ -59,14 +61,26 @@ class CheckpointIO(ABC):
strict (bool): whether to strictly enforce that the param names in
the checkpoint match the keys returned by this module's `state_dict()`.
"""
ckpt_path = Path(checkpoint)
is_sharded = self.is_sharded_checkpoint(ckpt_path)
origin_model = model
if isinstance(model, ModelWrapper):
model = model.unwrap()
if is_sharded:
self.load_sharded_model(model, ckpt_path, strict)
else:
self.load_unsharded_model(model, ckpt_path, strict)
return origin_model
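`load_model` dispatches on `self.is_sharded_checkpoint`, which is not part of this hunk. A plausible sketch of such a helper, assuming sharded checkpoints are directories or `.index.json` files and unsharded ones are single weight files:

# Hypothetical helper, not part of this diff: one plausible way to decide
# whether a checkpoint path refers to a sharded checkpoint.
from pathlib import Path

def is_sharded_checkpoint_sketch(ckpt_path: Path) -> bool:
    if ckpt_path.is_dir():
        # a directory counts as sharded if it contains an index file such as model.index.json
        return any(p.name.endswith('.index.json') for p in ckpt_path.iterdir())
    # a single file counts as sharded only if it is itself an index file
    return ckpt_path.name.endswith('.index.json')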
def save_model(self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
shard: bool = False,
prefix: str = None,
size_per_shard: int = 1024):
"""
Save model to checkpoint.
@@ -83,17 +97,24 @@ class CheckpointIO(ABC):
Args:
model (nn.Module or ModelWrapper): model to be saved.
checkpoint (str): checkpoint path. The checkpoint path can be:
1. a file path, e.g. 'model.pt'
2. a directory path to save the sharded checkpoint, e.g. './checkpoints/' when shard = True.
shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
that the checkpoint path is a directory path instead of a file path.
prefix (str): prefix for the model checkpoint file name when shard=True. Default: None.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
"""
if isinstance(model, ModelWrapper):
model = model.unwrap()
if shard:
self.save_sharded_model(model, checkpoint, prefix, size_per_shard)
else:
self.save_unsharded_model(model, checkpoint)
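A short usage sketch contrasting the two branches of `save_model` above, reusing the `checkpoint_io` and `model` objects from the earlier sketch; the on-disk layout described in the comments is an assumption based on the `model.index.json` convention in the docstring:

# unsharded: `checkpoint` is a single file path
checkpoint_io.save_model(model, 'model.pt', shard=False)

# sharded: `checkpoint` must be a directory path; the resulting layout is an
# assumption, e.g. ./checkpoints/model.index.json plus several shard files,
# each capped at roughly `size_per_shard` MB
checkpoint_io.save_model(model, './checkpoints/', shard=True, size_per_shard=1024)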
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
"""
Load optimizer from checkpoint.
@@ -102,19 +123,139 @@ class CheckpointIO(ABC):
optimizer (Optimizer): optimizer to be loaded.
checkpoint (str): checkpoint path. The accepted path formats are the same as those for model checkpoints.
"""
ckpt_path = Path(checkpoint)
is_sharded = self.is_sharded_checkpoint(ckpt_path)
if is_sharded:
self.load_sharded_optimizer(optimizer, ckpt_path)
else:
self.load_unsharded_optimizer(optimizer, ckpt_path)
def save_optimizer(self,
optimizer: Optimizer,
checkpoint: str,
shard: bool = False,
prefix: str = None,
size_per_shard: int = 1024):
"""
Save optimizer to checkpoint.
Args:
optimizer (Optimizer): optimizer to be saved.
checkpoint (str): checkpoint path. The checkpoint path can be:
1. a file path, e.g. 'optimizer.pt'
2. a path to a json file which defines the index to the sharded checkpoint for the optimizer
3. a path to a folder containing a unique .index.json file for sharded checkpoint
shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
multiple files. The optimizer shards will be specified by an `optimizer.index.json` file.
prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
"""
if shard:
self.save_sharded_optimizer(optimizer, checkpoint, prefix, size_per_shard)
else:
self.save_unsharded_optimizer(optimizer, checkpoint)
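The optimizer path mirrors the model path. A hedged round-trip sketch, again reusing `checkpoint_io` and `optimizer` from the earlier sketch; the on-disk layout is an assumption:

# sharded save: assumed to write an optimizer.index.json plus state shards under ./opt_ckpt/
checkpoint_io.save_optimizer(optimizer, './opt_ckpt/', shard=True, size_per_shard=512)

# later: load_optimizer detects the sharded layout via is_sharded_checkpoint
# and dispatches to load_sharded_optimizer
checkpoint_io.load_optimizer(optimizer, './opt_ckpt/')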
# ========================================================
# Abstract methods for model loading/saving implementation
# ========================================================
@abstractmethod
def load_sharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
"""
Load model from sharded checkpoint.
Args:
model (nn.Module): model to be loaded.
checkpoint (Path): checkpoint path. It should be a path to the .index.json file or a path to a directory which contains a .index.json file.
"""
pass
@abstractmethod
def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
"""
Load model from unsharded checkpoint.
Args:
model (nn.Module): model to be loaded.
checkpoint (Path): checkpoint path. It should be a single file path pointing to a model weight binary.
strict (bool): whether to strictly enforce that the param names in
the checkpoint match the keys returned by this module's `state_dict()`.
"""
pass
@abstractmethod
def save_sharded_model(self, model: nn.Module, checkpoint: Path, prefix: str, size_per_shard: int):
"""
Save model to sharded checkpoint.
Args:
model (nn.Module): model to be saved.
checkpoint (Path): checkpoint path. It should be a directory path.
prefix (str): prefix for the model checkpoint.
size_per_shard (int): size per shard in MB.
"""
pass
@abstractmethod
def save_unsharded_model(self, model: nn.Module, checkpoint: Path):
"""
Save model to unsharded checkpoint.
Args:
model (nn.Module): model to be saved.
checkpoint (Path): checkpoint path. It should be a single file path pointing to a model weight binary.
"""
pass
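To make the sharded contract above concrete, a minimal sketch of how a subclass might implement `save_sharded_model`: it packs the state dict into shards of at most `size_per_shard` MB and writes an index file. The shard file names and the index schema are assumptions, not the library's actual format:

import json
from pathlib import Path
import torch

def save_sharded_model_sketch(model, checkpoint, prefix, size_per_shard):
    # assumption: `checkpoint` is a directory path; shards are capped at size_per_shard MB
    ckpt_dir = Path(checkpoint)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    limit = size_per_shard * 1024 ** 2
    prefix = prefix or 'model'
    shards, current, current_size = [], {}, 0
    for name, tensor in model.state_dict().items():
        tensor_size = tensor.numel() * tensor.element_size()
        if current and current_size + tensor_size > limit:
            shards.append(current)
            current, current_size = {}, 0
        current[name] = tensor
        current_size += tensor_size
    if current:
        shards.append(current)
    weight_map = {}
    for i, shard in enumerate(shards):
        fname = f'{prefix}-{i:05d}.bin'
        torch.save(shard, ckpt_dir / fname)
        weight_map.update({k: fname for k in shard})
    # hypothetical index schema: maps each parameter name to its shard file
    with open(ckpt_dir / f'{prefix}.index.json', 'w') as f:
        json.dump({'weight_map': weight_map}, f, indent=2)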
# ========================================================
# Abstract methods for optimizer loading/saving implementation
# ========================================================
@abstractmethod
def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
"""
Load optimizer from sharded checkpoint.
Args:
optimizer (Optimizer): optimizer to be loaded.
checkpoint (Path): checkpoint path. It should be a path to the .index.json file or a path to a directory which contains a .index.json file.
prefix (str): prefix for the optimizer checkpoint.
size_per_shard (int): size per shard in MB.
"""
pass
@abstractmethod
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
"""
Load optimizer from unsharded checkpoint.
Args:
optimizer (Optimizer): optimizer to be loaded.
checkpoint (Path): checkpoint path. It should be a single file path pointing to an optimizer state binary.
"""
pass
@abstractmethod
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
"""
Save optimizer to sharded checkpoint.
Args:
optimizer (Optimizer): optimizer to be saved.
checkpoint (Path): checkpoint path. It should be a directory path.
prefix (str): prefix for the optimizer checkpoint.
size_per_shard (int): size per shard in MB.
"""
pass
@abstractmethod
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
"""
Save optimizer to unsharded checkpoint.
Args:
optimizer (Optimizer): optimizer to be saved.
checkpoint (Path): checkpoint path. It should be a single file path pointing to an optimizer state binary.
"""
pass
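Putting the abstract methods together, a minimal illustrative subclass, assuming plain `torch.save`/`torch.load` semantics and no abstract methods beyond those shown in this hunk; this is a sketch, not the library's real implementation:

import torch
import torch.nn as nn
from pathlib import Path
from torch.optim import Optimizer

class NaiveCheckpointIO(CheckpointIO):
    # sketch subclass of the CheckpointIO base class shown above

    def is_sharded_checkpoint(self, ckpt_path: Path) -> bool:
        # assumption: directories and .index.json files denote sharded checkpoints
        return ckpt_path.is_dir() or ckpt_path.name.endswith('.index.json')

    def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
        model.load_state_dict(torch.load(checkpoint), strict=strict)

    def save_unsharded_model(self, model: nn.Module, checkpoint: Path):
        torch.save(model.state_dict(), checkpoint)

    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
        optimizer.load_state_dict(torch.load(checkpoint))

    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
        torch.save(optimizer.state_dict(), checkpoint)

    # sharded variants omitted for brevity; a real subclass must implement them
    def load_sharded_model(self, model, checkpoint, strict):
        raise NotImplementedError

    def save_sharded_model(self, model, checkpoint, prefix, size_per_shard):
        raise NotImplementedError

    def load_sharded_optimizer(self, optimizer, checkpoint, prefix, size_per_shard):
        raise NotImplementedError

    def save_sharded_optimizer(self, optimizer, checkpoint, prefix, size_per_shard):
        raise NotImplementedError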