[checkpointio] Unsharded Optimizer Checkpoint for Gemini Plugin (#4141)

* [checkpointio] unsharded optimizer checkpoint for Gemini plugin

* [checkpointio] unsharded optimizer checkpoint for Gemini using all_gather
Author: Baizhou Zhang
Date: 2023-07-07 16:33:06 +08:00 (committed by GitHub)
Parent: fee32a3b78
Commit: 58913441a1
9 changed files with 684 additions and 83 deletions
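The commit message says the unsharded optimizer checkpoint for Gemini is assembled with all_gather. As a rough illustration of that idea only (not the actual GeminiDDP code; the helper name and the assumption of equally sized, padded 1-D shards are mine):

```python
import torch
import torch.distributed as dist

def gather_full_state_tensor(local_shard: torch.Tensor, full_numel: int) -> torch.Tensor:
    """All-gather equally sized 1-D shards from every rank, then trim any
    padding so the result has exactly `full_numel` elements."""
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(local_shard) for _ in range(world_size)]
    dist.all_gather(gathered, local_shard)
    return torch.cat(gathered)[:full_numel]
```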


@@ -28,6 +28,7 @@ from .utils import (
     shard_model_checkpoint,
     shard_optimizer_checkpoint,
     sharded_optimizer_loading_epilogue,
+    unwrap_optimizer,
 )

 __all__ = ['GeneralCheckpointIO']
@@ -59,7 +60,7 @@ class GeneralCheckpointIO(CheckpointIO):
         # If optimizer is wrapped, unwrap it.
         if isinstance(optimizer, OptimizerWrapper):
-            optimizer = optimizer.optim
+            optimizer = unwrap_optimizer(optimizer)

         # Read checkpoint index file.
         ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
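The new unwrap_optimizer helper replaces the direct optimizer.optim access; its body is not shown in this hunk. A plausible sketch, assuming it simply peels nested wrapper layers (the import path is an assumption):

```python
from colossalai.interface import OptimizerWrapper  # import path assumed

def unwrap_optimizer(wrapped_optimizer: OptimizerWrapper):
    """Peel off (possibly nested) OptimizerWrapper layers and return the
    underlying torch.optim.Optimizer."""
    unwrapped = wrapped_optimizer
    while isinstance(unwrapped, OptimizerWrapper):
        unwrapped = unwrapped.optim
    return unwrapped
```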
@@ -96,6 +97,11 @@ class GeneralCheckpointIO(CheckpointIO):
         - A group file (pytorch_optim_group.bin) recording information of param_groups
         - Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
         """
+        # If optimizer is wrapped, unwrap it.
+        if isinstance(optimizer, OptimizerWrapper):
+            optimizer = unwrap_optimizer(optimizer)
+
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
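The docstring spells out the on-disk layout of a sharded optimizer checkpoint. A hedged sketch of how those files could be merged back into a plain torch-style state dict (the exact contents of the group file and the final dict shape are assumptions; the library's own loading path goes through the index file instead):

```python
import os
import torch

def merge_optimizer_shards(checkpoint_dir: str) -> dict:
    """Reassemble the layout described above: pytorch_optim_group.bin holds
    param_groups, pytorch_optim-000XX.bin files hold the state shards."""
    states = {}
    for name in sorted(os.listdir(checkpoint_dir)):
        if name.startswith("pytorch_optim-") and name.endswith(".bin"):
            states.update(torch.load(os.path.join(checkpoint_dir, name), map_location="cpu"))
    param_groups = torch.load(os.path.join(checkpoint_dir, "pytorch_optim_group.bin"),
                              map_location="cpu")
    return {"state": states, "param_groups": param_groups}
```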
@@ -121,9 +127,8 @@ class GeneralCheckpointIO(CheckpointIO):
             shard, current_size = shard_pair
             shard_file = get_shard_filename(states_name, idx)
             total_size = total_size + current_size
-            for param_id in shard.keys():
-                index_file.append_weight_map(str(param_id), shard_file)
+            for key in shard.keys():
+                index_file.append_weight_map(key, shard_file)

             checkpoint_file_path = os.path.join(checkpoint, shard_file)
             save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
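The updated loop records every state key in the index file under the shard it lands in. A stripped-down, dependency-free illustration of the same bookkeeping (the file-name pattern mimics get_shard_filename but is an assumption):

```python
def build_weight_map(shards, states_name: str = "pytorch_optim.bin"):
    """Toy version of the loop above: map every state key to the shard
    file it was written to and accumulate the total checkpoint size."""
    weight_map, total_size = {}, 0
    for idx, (shard, current_size) in enumerate(shards):
        shard_file = states_name.replace(".bin", f"-{idx + 1:05d}.bin")
        total_size += current_size
        for key in shard:
            weight_map[key] = shard_file
    return weight_map, total_size
```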
@@ -177,7 +182,6 @@ class GeneralCheckpointIO(CheckpointIO):
             total_size = total_size + shard_pair[1]
             for key in shard.keys():
                 index_file.append_weight_map(key, shard_file)

             checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
             save_state_dict(shard, checkpoint_file_path, use_safetensors)
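For context, a hedged usage sketch of saving sharded checkpoints through GeneralCheckpointIO; the import path and keyword names are assumptions and may differ from this revision's actual signatures:

```python
import torch
from colossalai.checkpoint_io import GeneralCheckpointIO  # import path assumed

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters())

ckpt_io = GeneralCheckpointIO()
# Keyword names below (shard, size_per_shard) are assumptions about the
# public CheckpointIO interface; check the actual method signatures.
ckpt_io.save_model(model, "model_ckpt", shard=True, size_per_shard=1024)
ckpt_io.save_optimizer(optimizer, "optim_ckpt", shard=True, size_per_shard=1024)
```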