Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 19:40:28 +00:00)
[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -11,7 +11,7 @@ from colossalai.interface import ModelWrapper
 
 from .utils import has_index_file
 
-__all__ = ['CheckpointIO']
+__all__ = ["CheckpointIO"]
 
 
 class CheckpointIO(ABC):
@@ -61,10 +61,9 @@ class CheckpointIO(ABC):
     # ======================================
     # Public methods
     # ======================================
-    def load_model(self,
-                   model: Union[nn.Module, ModelWrapper],
-                   checkpoint: str,
-                   strict: bool = True) -> Union[nn.Module, ModelWrapper]:
+    def load_model(
+        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True
+    ) -> Union[nn.Module, ModelWrapper]:
         """
         Load model from checkpoint.
 
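A usage sketch (illustrative only, not part of this commit's diff): load_model above is the public entry point that dispatches to the sharded or unsharded loader and returns the model. It assumes GeneralCheckpointIO, a concrete CheckpointIO subclass shipped in colossalai.checkpoint_io, and an existing local checkpoint file named model.pt.

import torch.nn as nn

from colossalai.checkpoint_io import GeneralCheckpointIO  # assumed concrete CheckpointIO subclass

model = nn.Linear(16, 4)          # any nn.Module (or ModelWrapper) works here
ckpt_io = GeneralCheckpointIO()

# Returns the (possibly unwrapped) model; strict follows load_state_dict semantics.
model = ckpt_io.load_model(model, "model.pt", strict=True)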
@@ -98,14 +97,16 @@ class CheckpointIO(ABC):
 
         return origin_model
 
-    def save_model(self,
-                   model: Union[nn.Module, ModelWrapper],
-                   checkpoint: str,
-                   shard: bool = False,
-                   gather_dtensor: bool = True,
-                   prefix: str = None,
-                   size_per_shard: int = 1024,
-                   use_safetensors: bool = False):
+    def save_model(
+        self,
+        model: Union[nn.Module, ModelWrapper],
+        checkpoint: str,
+        shard: bool = False,
+        gather_dtensor: bool = True,
+        prefix: str = None,
+        size_per_shard: int = 1024,
+        use_safetensors: bool = False,
+    ):
         """
         Save model to checkpoint.
 
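A brief sketch of how the save_model parameters above combine (illustrative only, not part of this diff; GeneralCheckpointIO and the directory name are assumptions for illustration):

import torch.nn as nn

from colossalai.checkpoint_io import GeneralCheckpointIO  # assumed concrete CheckpointIO subclass

model = nn.Linear(16, 4)
ckpt_io = GeneralCheckpointIO()

# With shard=True the checkpoint path is a directory; weights are split into
# files of at most size_per_shard MB plus an index file describing the shards.
ckpt_io.save_model(
    model,
    "model_ckpt_dir",        # illustrative directory name
    shard=True,
    gather_dtensor=True,
    prefix=None,
    size_per_shard=1024,
    use_safetensors=True,
)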
@@ -157,7 +158,7 @@ class CheckpointIO(ABC):
 
         if Path(checkpoint).is_dir() and not index_file_exists:
             # if the checkpoint is a directory and there is no index file, raise error
-            raise ValueError(f'Cannot find index file in {checkpoint}')
+            raise ValueError(f"Cannot find index file in {checkpoint}")
 
         if index_file_exists:
             # the existence of index file means it is a sharded checkpoint
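The raise above belongs to the sharded-vs-unsharded dispatch: an index file marks a checkpoint directory as sharded. A standalone restatement of that rule (a hypothetical helper written for illustration, not the module's has_index_file utility):

from pathlib import Path


def looks_like_sharded_checkpoint(checkpoint: str) -> bool:
    # Hypothetical helper mirroring the rule in the hunk above: a plain file is
    # an unsharded checkpoint; a directory must contain a *.index.json file,
    # otherwise it is rejected.
    path = Path(checkpoint)
    if path.is_file():
        return False
    if not list(path.glob("*.index.json")):
        raise ValueError(f"Cannot find index file in {checkpoint}")
    return True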
@@ -165,13 +166,15 @@ class CheckpointIO(ABC):
         else:
             self.load_unsharded_optimizer(optimizer, checkpoint)
 
-    def save_optimizer(self,
-                       optimizer: Optimizer,
-                       checkpoint: str,
-                       shard: bool = False,
-                       gather_dtensor=True,
-                       prefix: str = None,
-                       size_per_shard: int = 1024):
+    def save_optimizer(
+        self,
+        optimizer: Optimizer,
+        checkpoint: str,
+        shard: bool = False,
+        gather_dtensor=True,
+        prefix: str = None,
+        size_per_shard: int = 1024,
+    ):
         """
         Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.
 
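An illustrative sharded optimizer save for the save_optimizer signature above (not part of this diff; per the docstring there is no safetensors option here, and GeneralCheckpointIO plus the directory name are assumptions):

import torch
import torch.nn as nn

from colossalai.checkpoint_io import GeneralCheckpointIO  # assumed concrete CheckpointIO subclass

model = nn.Linear(16, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
ckpt_io = GeneralCheckpointIO()

model(torch.randn(8, 16)).sum().backward()
optimizer.step()  # populate optimizer state before saving

# Optimizer states are split across files of at most size_per_shard MB,
# accompanied by an index file; no use_safetensors flag on this method.
ckpt_io.save_optimizer(optimizer, "optim_ckpt_dir", shard=True, size_per_shard=1024)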
@@ -207,7 +210,6 @@ class CheckpointIO(ABC):
             strict (bool): whether to strictly enforce that the param name in
                 the checkpoint match the keys returned by this module's.
         """
-        pass
 
     @abstractmethod
     def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
@@ -220,11 +222,17 @@ class CheckpointIO(ABC):
             strict (bool): whether to strictly enforce that the param name in
                 the checkpoint match the keys returned by this module's.
         """
-        pass
 
     @abstractmethod
-    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
-                           size_per_shard: int, use_safetensors: bool):
+    def save_sharded_model(
+        self,
+        model: nn.Module,
+        checkpoint: str,
+        gather_dtensor: bool,
+        prefix: Optional[str],
+        size_per_shard: int,
+        use_safetensors: bool,
+    ):
         """
         Save model to sharded checkpoint.
 
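The abstract methods above form the contract a storage backend fills in. A hypothetical, partial subclass sketch (not from this commit), assuming CheckpointIO is importable from colossalai.checkpoint_io as its __all__ suggests; only the unsharded model hooks are shown, so the class stays abstract until the remaining hooks are implemented:

import torch
import torch.nn as nn

from colossalai.checkpoint_io import CheckpointIO  # assumed export, per __all__ above


class MyCheckpointIO(CheckpointIO):
    # Hypothetical backend sketch: only two abstract hooks are filled in; the
    # sharded and optimizer hooks would need real implementations.
    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
        model.load_state_dict(torch.load(checkpoint), strict=strict)

    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
        # gather_dtensor / use_safetensors are ignored in this toy sketch.
        torch.save(model.state_dict(), checkpoint)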
@@ -236,7 +244,6 @@ class CheckpointIO(ABC):
             size_per_shard (int): size per shard in MB.
             use_safetensors (bool): whether to use safe tensors.
         """
-        pass
 
     @abstractmethod
     def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
@@ -249,7 +256,6 @@ class CheckpointIO(ABC):
             gather_dtensor (bool): whether to gather the distributed tensor to the first device.
             use_safetensors (bool): whether to use safe tensors.
         """
-        pass
 
     # ========================================================
     # Abstract methods for optimizer loading/saving implementation
@@ -265,7 +271,6 @@ class CheckpointIO(ABC):
             index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
             prefix (str): prefix for the optimizer checkpoint.
         """
-        pass
 
     @abstractmethod
     def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
@@ -276,11 +281,11 @@ class CheckpointIO(ABC):
             optimizer (Optimizer): optimizer to be loaded.
             checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
         """
-        pass
 
     @abstractmethod
-    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
-                               size_per_shard: int):
+    def save_sharded_optimizer(
+        self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
+    ):
         """
         Save optimizer to sharded checkpoint.
 
@@ -291,7 +296,6 @@ class CheckpointIO(ABC):
             prefix (str): prefix for the optimizer checkpoint.
             size_per_shard (int): size per shard in MB.
         """
-        pass
 
     @abstractmethod
     def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
@@ -303,7 +307,6 @@ class CheckpointIO(ABC):
             checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
             gather_dtensor (bool): whether to gather the distributed tensor to the first device.
         """
-        pass
 
     # ============================================
     # methods for loading and saving lr scheduler
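The remaining hunks continue the same pass removal and reformatting. For completeness, a hypothetical stand-alone version of the unsharded optimizer pair declared above (free functions mirroring the abstract signatures, illustrative only, not code from this commit):

from pathlib import Path

import torch
from torch.optim import Optimizer


def save_unsharded_optimizer(optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool) -> None:
    # Toy sketch: a real backend would gather distributed state when gather_dtensor is set.
    torch.save(optimizer.state_dict(), checkpoint)


def load_unsharded_optimizer(optimizer: Optimizer, checkpoint: Path) -> None:
    optimizer.load_state_dict(torch.load(checkpoint))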