Mirror of https://github.com/hpcaitech/ColossalAI.git
[checkpointio] support async model save (#6131)
* [checkpointio] support async model save

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
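For context, a minimal usage sketch of the API this commit adds. Only `save_model(..., use_async=...)` and `synchronize()` come from the diff below; the training-loop wiring, checkpoint path, and the concrete `checkpoint_io` instance are assumptions for illustration:

# Hypothetical usage of the async-save API introduced here; `checkpoint_io`
# is any concrete CheckpointIO subclass, and `model`/`optimizer` come from
# the surrounding training script.
checkpoint_io.save_model(model, "model.safetensors", shard=False, use_async=True)

# ... compute of the next step may overlap with the background write ...

# Per the docstring in the diff, the device-to-host copy must finish
# before the weights are mutated again:
checkpoint_io.synchronize()
optimizer.step()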
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Optional, Union
+from typing import Dict, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -8,6 +8,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 
 from colossalai.interface import ModelWrapper
+from colossalai.logging import get_dist_logger
 
 from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file
 
@@ -58,9 +59,34 @@ class CheckpointIO(ABC):
         >>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
     """
 
+    N_WRITE_ENTRIES: int = 32
+
     # ======================================
     # Public methods
     # ======================================
+    def __init__(self):
+        super().__init__()
+        self.pinned_state_dicts: Dict[int, dict] = {}
+        self.async_writers = []
+
+    def _sync_io(self):
+        for writer in self.async_writers:
+            writer.synchronize()
+            writer.fp.close()
+        self.async_writers.clear()
+
+    def _sync_d2h(self):
+        for writer in self.async_writers:
+            writer.sync_before_step()
+
+    def synchronize(self):
+        """This method must be called before updating the model weights."""
+        self._sync_d2h()
+
+    def __del__(self):
+        self._sync_d2h()
+        self._sync_io()
+
     def load_model(
         self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True
     ) -> Union[nn.Module, ModelWrapper]:
@@ -111,6 +137,7 @@ class CheckpointIO(ABC):
         prefix: str = None,
         size_per_shard: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ):
         """
         Save model to checkpoint.
@@ -138,11 +165,21 @@ class CheckpointIO(ABC):
             size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
             use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
         """
+        self._sync_io()
+        if use_async and not use_safetensors:
+            logger = get_dist_logger()
+            logger.warning(
+                "Async save is only supported when use_safetensors is set to True. "
+                "Setting use_safetensors to True for async save."
+            )
+            use_safetensors = True
 
         if shard:
-            self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
+            self.save_sharded_model(
+                model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async=use_async
+            )
         else:
-            self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
+            self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
 
     def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
         """
@@ -234,6 +271,7 @@ class CheckpointIO(ABC):
         prefix: Optional[str],
         size_per_shard: int,
         use_safetensors: bool,
+        use_async: bool = False,
     ):
         """
         Save model to sharded checkpoint.
@@ -248,7 +286,9 @@ class CheckpointIO(ABC):
         """
 
     @abstractmethod
-    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
+    ):
         """
         Save model to unsharded checkpoint.
 
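Note on the writer objects: this base class never constructs them. It only assumes each entry of `self.async_writers` exposes `synchronize()`, an open file handle `fp` (closed in `_sync_io`), and `sync_before_step()` (awaited in `_sync_d2h` before weights may change again). Below is a minimal mock satisfying that contract; the pinned-buffer and CUDA-event mechanics are illustrative assumptions about how a concrete writer might work, not code from this commit:

import torch

class MockAsyncWriter:
    """Minimal object matching the contract CheckpointIO relies on:
    synchronize(), fp, and sync_before_step(). Assumes a CUDA float
    tensor; all internals here are hypothetical."""

    def __init__(self, tensor: torch.Tensor, path: str):
        self.fp = open(path, "wb")  # closed later by CheckpointIO._sync_io()
        # Stage the tensor in pinned host memory so the device-to-host
        # copy can run asynchronously with respect to the host.
        self.buffer = torch.empty_like(tensor, device="cpu", pin_memory=True)
        self.event = torch.cuda.Event()
        self.buffer.copy_(tensor, non_blocking=True)
        self.event.record()  # marks completion of the D2H copy

    def sync_before_step(self):
        # Block until the D2H copy has finished, so the caller may safely
        # mutate the source tensor (e.g. run optimizer.step()).
        self.event.synchronize()

    def synchronize(self):
        # Finish the D2H copy, then flush the staged bytes to disk.
        self.sync_before_step()
        self.fp.write(self.buffer.numpy().tobytes())
        self.fp.flush()

The `pinned_state_dicts: Dict[int, dict]` cache added in `__init__` suggests that concrete subclasses reuse such pinned host buffers across repeated saves rather than reallocating them each time, though this diff only introduces the container.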