mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
[checkpointio] support huggingface from_pretrained for all plugins (#4606)
This commit is contained in:
@@ -18,6 +18,7 @@ from colossalai.checkpoint_io.utils import (
|
||||
get_optimizer_base_filenames,
|
||||
get_shard_filename,
|
||||
load_shard_state_dict,
|
||||
save_config_file,
|
||||
save_state_dict,
|
||||
save_state_dict_shards,
|
||||
)
|
||||
@@ -111,6 +112,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
||||
if self.coordinator.is_master():
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
save_config_file(model.module, checkpoint_path)
|
||||
logging.info(f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}.")
|
||||
|
Reference in New Issue
Block a user