[checkpointio] support huggingface from_pretrained for all plugins (#4606)

This commit is contained in:
Baizhou Zhang
2023-09-04 23:25:01 +08:00
committed by GitHub
parent 0a94fcd351
commit e79b1e80e2
4 changed files with 87 additions and 129 deletions

View File

@@ -23,6 +23,7 @@ from .utils import (
load_state_dict,
load_state_dict_into_model,
load_states_into_optimizer,
save_config_file,
save_param_groups,
save_state_dict,
save_state_dict_shards,
@@ -185,6 +186,7 @@ class GeneralCheckpointIO(CheckpointIO):
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint_path, is_master=True)
logging.info(f"The model is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")