mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[checkpoint] refactored the API and added safetensors support (#3427)
* [checkpoint] refactored the API and added safetensors support * polish code
This commit is contained in:
@@ -4,42 +4,67 @@ import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from .checkpoint_io_base import CheckpointIO
|
||||
from .index_file import CheckpointIndexFile
|
||||
from .utils import has_index_file, load_state_dict, save_state_dict
|
||||
|
||||
__all__ = ['GeneralCheckpointIO']
|
||||
|
||||
|
||||
class GeneralCheckpointIO(CheckpointIO):
|
||||
|
||||
def load_sharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
|
||||
index_file_path = self.get_sharded_checkpoint_index_file(checkpoint)
|
||||
def load_sharded_model(self, model: nn.Module, index_file_path: Path, strict: bool):
|
||||
# load the index file
|
||||
index_file = CheckpointIndexFile.from_file(index_file_path)
|
||||
|
||||
# iterate over the shard checkpoint files
|
||||
# and load each
|
||||
shard_files = self.get_checkpoint_shard_filenames(index_file_path)
|
||||
for shard_file in shard_files:
|
||||
shard_checkpoint = self.load_state_dict(shard_file)
|
||||
index_file.assert_no_dtensor_checkpoint()
|
||||
checkpoint_file_list, _ = index_file.get_checkpoint_fileanames()
|
||||
for shard_file in checkpoint_file_list:
|
||||
shard_checkpoint = load_state_dict(shard_file)
|
||||
model.load_state_dict(shard_checkpoint, strict=strict)
|
||||
|
||||
def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
|
||||
checkpoint = self.load_state_dict(str(checkpoint))
|
||||
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
|
||||
checkpoint = load_state_dict(checkpoint)
|
||||
model.load_state_dict(checkpoint, strict=strict)
|
||||
|
||||
def save_sharded_model(self, model: nn.Module, checkpoint: Path, prefix: str, size_per_shard: int):
|
||||
def save_sharded_model(self, model: nn.Module, checkpoint: Path, gather_dtensor: bool, prefix: str,
|
||||
size_per_shard: int, use_safetensors: bool):
|
||||
# TODO(FrankLeeeee): implement this method as it can be supported by Huggingface model
|
||||
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
|
||||
|
||||
def save_unsharded_model(self, model: nn.Module, checkpoint: Path):
|
||||
self.save_checkpoint(model.state_dict(), checkpoint)
|
||||
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
|
||||
state_dict = model.state_dict()
|
||||
|
||||
# TODO(FrankLeeeee): add support for gather_dtensor
|
||||
if gather_dtensor:
|
||||
pass
|
||||
|
||||
# save the checkpoint
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors)
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
|
||||
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
|
||||
checkpoint = self.load_state_dict(checkpoint)
|
||||
checkpoint = load_state_dict(checkpoint)
|
||||
optimizer.load_state_dict(checkpoint)
|
||||
|
||||
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
|
||||
def save_sharded_optimizer(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
checkpoint: Path,
|
||||
gather_dtensor: bool,
|
||||
prefix: str,
|
||||
size_per_shard: int,
|
||||
):
|
||||
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
|
||||
|
||||
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
|
||||
self.save_checkpoint(optimizer.state_dict(), checkpoint)
|
||||
def save_unsharded_optimizer(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
checkpoint: Path,
|
||||
gather_dtensor: bool,
|
||||
):
|
||||
# TODO(FrankLeeeee): handle distributed tensors
|
||||
save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
|
||||
|
Reference in New Issue
Block a user