[checkpointio] support huggingface from_pretrained for all plugins (#4606)

Authored by Baizhou Zhang on 2023-09-04 23:25:01 +08:00, committed by GitHub
parent 0a94fcd351
commit e79b1e80e2
4 changed files with 87 additions and 129 deletions


@@ -18,6 +18,7 @@ from colossalai.checkpoint_io.utils import (
     get_optimizer_base_filenames,
     get_shard_filename,
     load_shard_state_dict,
+    save_config_file,
     save_state_dict,
     save_state_dict_shards,
 )
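For context: the newly imported save_config_file helper is what writes the HuggingFace config.json next to the weight shards, which is what later lets from_pretrained rebuild the model architecture. The actual helper lives in colossalai.checkpoint_io.utils; the sketch below is only an illustration of the idea, and it assumes the model is a transformers.PreTrainedModel.

# Illustrative sketch only, not the library's actual implementation:
# persist the HuggingFace config alongside the sharded weights so that
# AutoModel.from_pretrained(checkpoint_dir) can reconstruct the architecture.
import os
from transformers import PreTrainedModel

def save_config_file_sketch(model, checkpoint_dir: str) -> None:
    # Only HuggingFace models carry a config object to serialize.
    if not isinstance(model, PreTrainedModel):
        return
    os.makedirs(checkpoint_dir, exist_ok=True)
    # PretrainedConfig.save_pretrained writes config.json into the directory.
    model.config.save_pretrained(checkpoint_dir)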
@@ -111,6 +112,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
+            save_config_file(model.module, checkpoint_path)
             logging.info(f"The model is split into checkpoint shards. "
                          f"You can find where each parameters has been saved in the "
                          f"index located at {save_index_file}.")