Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 02:51:59 +00:00)

Commit: [booster] torch fsdp fix ckpt (#3788)
@@ -1,26 +1,26 @@
-from pathlib import Path
+import gc
+import logging
+import os
+from functools import reduce
+from pathlib import Path
+from typing import Iterator, Optional, OrderedDict, Tuple
 
 import torch.nn as nn
 from torch.optim import Optimizer
-import logging
-import os
-import gc
-from typing import Optional, Iterator, OrderedDict, Tuple
 
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
-    has_index_file,
-    load_state_dict,
-    save_state_dict,
-    is_safetensors_available,
-    shard_checkpoint,
-    load_shard_state_dict,
-    load_state_dict_into_model,
     get_base_filenames,
     get_shard_filename,
-    get_base_filenames
-)
+    has_index_file,
+    is_safetensors_available,
+    load_shard_state_dict,
+    load_state_dict,
+    load_state_dict_into_model,
+    save_state_dict,
+    shard_checkpoint,
+)
 
 
 __all__ = ['GeneralCheckpointIO']
@@ -29,6 +29,7 @@ class GeneralCheckpointIO(CheckpointIO):
     """
     Checkpoint IO
    """
 
     def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
         checkpoint = load_state_dict(checkpoint)
         model.load_state_dict(checkpoint, strict=strict)
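For orientation, a minimal sketch of driving the unsharded path touched above. The `load_unsharded_model` signature is taken from this hunk; the `save_unsharded_model` counterpart and the top-level import path are assumptions about the surrounding module, not shown in this diff.

    import torch.nn as nn

    # assumed import path; GeneralCheckpointIO is the class being patched here
    from colossalai.checkpoint_io import GeneralCheckpointIO

    model = nn.Linear(4, 4)
    ckpt_io = GeneralCheckpointIO()

    # save_unsharded_model is assumed symmetric to the loader shown above:
    # it writes a single flat state-dict file
    ckpt_io.save_unsharded_model(model, "model.bin", gather_dtensor=False, use_safetensors=False)
    # signature from this diff: (model, checkpoint, strict)
    ckpt_io.load_unsharded_model(model, "model.bin", strict=True)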
@@ -69,19 +70,23 @@ class GeneralCheckpointIO(CheckpointIO):
         # TODO(FrankLeeeee): handle distributed tensors
         save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
 
-    def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor:bool = False,
-                           variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False):
+    def save_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_path: str,
+                           gather_dtensor: bool = False,
+                           variant: Optional[str] = None,
+                           max_shard_size: int = 1024,
+                           use_safetensors: bool = False):
         """
         implement this method as it can be supported by Huggingface model,
         save shard model, save model to multiple files
         """
         if os.path.isfile(checkpoint_path):
             logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
             return
 
         Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
 
         # shard checkpoint
         state_dict = model.state_dict()
         state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
@@ -95,21 +100,22 @@ class GeneralCheckpointIO(CheckpointIO):
             total_size = total_size + shard_pair[1]
             for key in shard.keys():
                 index_file.append_weight_map(key, shard_file)
 
             checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
             save_state_dict(shard, checkpoint_file_path, use_safetensors)
 
         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
-        logging.info(
-            f"The model is going to be split to checkpoint shards. "
-            f"You can find where each parameters has been saved in the "
-            f"index located at {save_index_file}."
-        )
+        logging.info(f"The model is going to be split to checkpoint shards. "
+                     f"You can find where each parameters has been saved in the "
+                     f"index located at {save_index_file}.")
 
-    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False,
-                           use_safetensors: bool = False, load_sub_module: bool = True):
+    def load_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_index_file: Path,
+                           strict: bool = False,
+                           use_safetensors: bool = False,
+                           load_sub_module: bool = True):
         """
         load shard model, load model from multiple files
         """
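The loop in this hunk follows the Hugging Face-style sharded-checkpoint scheme: parameters are greedily packed into size-bounded shards, and an index file maps every weight name to the shard file that holds it, plus a "total_size" metadata entry. A minimal, self-contained sketch of that scheme (the helper and variable names here are hypothetical and do not reproduce ColossalAI's actual `shard_checkpoint` or `CheckpointIndexFile`):

    from typing import Dict, List, Tuple

    import torch


    def split_state_dict(state_dict: Dict[str, torch.Tensor],
                         max_shard_size: int) -> List[Tuple[Dict[str, torch.Tensor], int]]:
        """Greedily pack tensors into shards of at most max_shard_size bytes."""
        shards: List[Tuple[Dict[str, torch.Tensor], int]] = []
        current: Dict[str, torch.Tensor] = {}
        current_size = 0
        for name, tensor in state_dict.items():
            tensor_size = tensor.numel() * tensor.element_size()
            # start a new shard if this tensor would overflow the current one
            if current and current_size + tensor_size > max_shard_size:
                shards.append((current, current_size))
                current, current_size = {}, 0
            current[name] = tensor
            current_size += tensor_size
        if current:
            shards.append((current, current_size))
        return shards


    # Build the same kind of index the diff writes: weight name -> shard file,
    # plus a running total of bytes across all shards.
    model_sd = {f"layer{i}.weight": torch.zeros(256, 256) for i in range(4)}
    index = {"metadata": {"total_size": 0}, "weight_map": {}}
    for shard_idx, (shard, size) in enumerate(split_state_dict(model_sd, max_shard_size=512 * 1024)):
        shard_file = f"pytorch_model-{shard_idx:05d}.bin"
        index["metadata"]["total_size"] += size
        for key in shard:
            index["weight_map"][key] = shard_file
        # the diff calls save_state_dict(shard, path, use_safetensors); plain torch.save here
        torch.save(shard, shard_file)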
@@ -119,7 +125,7 @@ class GeneralCheckpointIO(CheckpointIO):
 
         if use_safetensors and not is_safetensors_available():
             raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
 
         # read checkpoint index file
         ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
@@ -134,10 +140,7 @@ class GeneralCheckpointIO(CheckpointIO):
         if strict:
             remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
             if len(remain_keys) > 0:
-                error_msgs = 'Missing key(s) in state_dict: {}. '.format(
-                    ', '.join('"{}"'.format(k) for k in missing_keys))
+                error_msgs = 'Missing key(s) in state_dict: {}. '.format(', '.join(
+                    '"{}"'.format(k) for k in missing_keys))
                 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                     self.__class__.__name__, "\n\t".join(error_msgs)))
-
-
-
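The strict check above collects one missing-key list per shard file; a key is genuinely absent only if no shard supplied it, hence the set intersection folded with reduce. A small illustration with made-up key lists:

    from functools import reduce

    # One missing-key list per loaded shard (hypothetical values): a key that
    # any shard provided is not really missing from the full checkpoint.
    missing_keys = [
        ["fc1.weight", "fc2.weight"],   # keys absent from shard 0
        ["fc1.weight"],                 # keys absent from shard 1
    ]
    remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
    print(remain_keys)  # {'fc1.weight'} -> only fc1.weight was in no shard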