[gemini] fixed the gemini checkpoint io (#3934)

This commit is contained in:
Frank Lee
2023-06-09 09:48:49 +08:00
committed by GitHub
parent a98e16ed07
commit bd1ab98158
3 changed files with 19 additions and 11 deletions

View File

@@ -99,8 +99,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
save_state_dict(shard, checkpoint_file_path, use_safetensors)
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
logging.info(f"The model is going to be split to checkpoint shards. "
# only save the index file on the master rank
if self.coordinator.is_master():
index_file.write_index_file(save_index_file)
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")