mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-23 02:20:49 +00:00
[ckpt] Add async ckpt api (#6136)
* fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix
This commit is contained in:
@@ -371,7 +371,11 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
|
||||
# ======================================
|
||||
|
||||
|
||||
def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
|
||||
def save_state_dict(
|
||||
state_dict: dict,
|
||||
checkpoint_file_path: str,
|
||||
use_safetensors: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Save state dict to checkpoint.
|
||||
|
||||
@@ -581,14 +585,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
|
||||
raise Exception("load the model using `safetensors`, but no file endwith .safetensors")
|
||||
if use_safetensors:
|
||||
from safetensors.torch import load_file as safe_load_file
|
||||
from safetensors.torch import safe_open
|
||||
|
||||
with safe_open(checkpoint_file, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata["format"] != "pt":
|
||||
raise NotImplementedError(
|
||||
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
|
||||
)
|
||||
return safe_load_file(checkpoint_file)
|
||||
else:
|
||||
return torch.load(checkpoint_file, map_location=torch.device("cpu"))
|
||||
|
Reference in New Issue
Block a user