[booster] torch fsdp fix ckpt (#3788)

Authored by wukong1992 on 2023-05-23 16:58:45 +08:00, committed via GitHub
parent 9265f2d4d7
commit 6b305a99d6
5 changed files with 230 additions and 186 deletions
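
Before the diff: a minimal sketch of how the changed class can be exercised end to end. It assumes GeneralCheckpointIO is importable from colossalai.checkpoint_io and uses a toy model plus an illustrative directory name; the keyword arguments mirror the new signatures shown in the diff below.

    import torch.nn as nn

    from colossalai.checkpoint_io import GeneralCheckpointIO

    # Any nn.Module works; a tiny stand-in keeps the sketch self-contained.
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

    ckpt_io = GeneralCheckpointIO()

    # Write the state dict as multiple shard files plus a JSON index file.
    ckpt_io.save_sharded_model(model,
                               checkpoint_path="./toy_ckpt",
                               max_shard_size=1024,
                               use_safetensors=False)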


@@ -1,26 +1,26 @@
-from pathlib import Path
+import gc
+import logging
+import os
+from functools import reduce
+from pathlib import Path
+from typing import Iterator, Optional, OrderedDict, Tuple
 
 import torch.nn as nn
 from torch.optim import Optimizer
-import logging
-import os
-import gc
-from typing import Optional, Iterator, OrderedDict, Tuple
 
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
-    has_index_file,
-    load_state_dict,
-    save_state_dict,
-    is_safetensors_available,
-    shard_checkpoint,
-    load_shard_state_dict,
-    load_state_dict_into_model,
+    get_base_filenames,
     get_shard_filename,
-    get_base_filenames
-)
+    has_index_file,
+    is_safetensors_available,
+    load_shard_state_dict,
+    load_state_dict,
+    load_state_dict_into_model,
+    save_state_dict,
+    shard_checkpoint,
+)
 
 __all__ = ['GeneralCheckpointIO']
@@ -29,6 +29,7 @@ class GeneralCheckpointIO(CheckpointIO):
     """
     Checkpoint IO
     """
+
     def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
         checkpoint = load_state_dict(checkpoint)
         model.load_state_dict(checkpoint, strict=strict)
@@ -69,19 +70,23 @@ class GeneralCheckpointIO(CheckpointIO):
         # TODO(FrankLeeeee): handle distributed tensors
         save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
 
-    def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor:bool = False,
-                           variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False):
-        """
+    def save_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_path: str,
+                           gather_dtensor: bool = False,
+                           variant: Optional[str] = None,
+                           max_shard_size: int = 1024,
+                           use_safetensors: bool = False):
+        """
         implement this method as it can be supported by Huggingface model,
         save shard model, save model to multiple files
         """
         if os.path.isfile(checkpoint_path):
             logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
             return
 
         Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
 
         # shard checkpoint
         state_dict = model.state_dict()
         state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
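
The keyword-per-line signature above also makes call sites easier to read. For orientation, the expected on-disk result of a call (continuing the sketch from the top), assuming the get_base_filenames/get_shard_filename helpers keep Hugging Face style names; the exact file names below are illustrative, not guaranteed by this diff:

    ckpt_io.save_sharded_model(model, checkpoint_path="./toy_ckpt", max_shard_size=1024)

    # toy_ckpt/
    #   pytorch_model-00001-of-00002.bin    <- one shard of the state dict
    #   pytorch_model-00002-of-00002.bin
    #   pytorch_model.bin.index.json        <- maps each weight name to its
    #                                          shard file, plus "total_size"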
@@ -95,21 +100,22 @@ class GeneralCheckpointIO(CheckpointIO):
             total_size = total_size + shard_pair[1]
             for key in shard.keys():
                 index_file.append_weight_map(key, shard_file)
 
             checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
             save_state_dict(shard, checkpoint_file_path, use_safetensors)
 
         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
-        logging.info(
-            f"The model is going to be split to checkpoint shards. "
-            f"You can find where each parameters has been saved in the "
-            f"index located at {save_index_file}."
-        )
+        logging.info(f"The model is going to be split to checkpoint shards. "
+                     f"You can find where each parameters has been saved in the "
+                     f"index located at {save_index_file}.")
 
-    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False,
-                           use_safetensors: bool = False, load_sub_module: bool = True):
+    def load_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_index_file: Path,
+                           strict: bool = False,
+                           use_safetensors: bool = False,
+                           load_sub_module: bool = True):
         """
         load shard model, load model from multiple files
         """
@@ -119,7 +125,7 @@ class GeneralCheckpointIO(CheckpointIO):
 
         if use_safetensors and not is_safetensors_available():
             raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
 
         # read checkpoint index file
         ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
@@ -134,10 +140,7 @@ class GeneralCheckpointIO(CheckpointIO):
         if strict:
             remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
             if len(remain_keys) > 0:
-                error_msgs = 'Missing key(s) in state_dict: {}. '.format(
-                    ', '.join('"{}"'.format(k) for k in missing_keys))
+                error_msgs = 'Missing key(s) in state_dict: {}. '.format(', '.join(
+                    '"{}"'.format(k) for k in missing_keys))
                 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-                                   self.__class__.__name__, "\n\t".join(error_msgs)))
+                    self.__class__.__name__, "\n\t".join(error_msgs)))
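
A pre-existing wart survives the reformat: error_msgs is a single string here, so "\n\t".join(error_msgs) inserts a tab between every character rather than between messages (upstream torch.nn.Module.load_state_dict keeps error_msgs as a list), and the text reports missing_keys even though remain_keys holds the keys missing from every shard. A possible follow-up, not part of this commit:

    # Keep error_msgs as a list and report the reduced key set instead.
    error_msgs = [
        'Missing key(s) in state_dict: {}. '.format(', '.join(
            '"{}"'.format(k) for k in sorted(remain_keys)))
    ]
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
        self.__class__.__name__, "\n\t".join(error_msgs)))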