Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 19:40:28 +00:00
Commit: [checkpointio] Unsharded Optimizer Checkpoint for Gemini Plugin (#4141)

* [checkpointio] unsharded optimizer checkpoint for Gemini plugin
* [checkpointio] unsharded optimizer checkpoint for Gemini using all_gather
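The Gemini-side gathering itself is not in the hunks below (they cover the general checkpoint IO). As a rough illustration only, the all_gather approach named in the commit message can be sketched as follows; gather_full_state and local_shard are hypothetical names, not the plugin's API:

import torch
import torch.distributed as dist

def gather_full_state(local_shard: torch.Tensor) -> torch.Tensor:
    # Illustrative only: reassemble a 1-D optimizer state tensor that is
    # evenly sharded across ranks, so one rank can save it unsharded.
    world_size = dist.get_world_size()
    buckets = [torch.empty_like(local_shard) for _ in range(world_size)]
    dist.all_gather(buckets, local_shard)  # every rank receives every shard
    return torch.cat(buckets)

# Rank 0 could then write one unsharded file:
#   if dist.get_rank() == 0:
#       torch.save(full_state_dict, checkpoint_path)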
@@ -28,6 +28,7 @@ from .utils import (
     shard_model_checkpoint,
     shard_optimizer_checkpoint,
     sharded_optimizer_loading_epilogue,
+    unwrap_optimizer,
 )
 
 __all__ = ['GeneralCheckpointIO']
@@ -59,7 +60,7 @@ class GeneralCheckpointIO(CheckpointIO):
 
         # If optimizer is wrapped, unwrap it.
         if isinstance(optimizer, OptimizerWrapper):
-            optimizer = optimizer.optim
+            optimizer = unwrap_optimizer(optimizer)
 
         # Read checkpoint index file.
         ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
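The unwrap_optimizer helper used here is newly imported in the first hunk. Inferring from the line it replaces, a plausible sketch (an assumption, not the exact ColossalAI implementation) is:

def unwrap_optimizer(wrapper):
    # Sketch: peel an OptimizerWrapper back to the underlying torch optimizer.
    # The hasattr check guards against nested wrapping; the real helper
    # lives in colossalai.checkpoint_io.utils.
    inner = wrapper.optim
    if hasattr(inner, "optim"):
        inner = inner.optim
    return inner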
@@ -96,6 +97,11 @@ class GeneralCheckpointIO(CheckpointIO):
         - A group file (pytorch_optim_group.bin) recording information of param_groups
         - Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
         """
+
+        # If optimizer is wrapped, unwrap it.
+        if isinstance(optimizer, OptimizerWrapper):
+            optimizer = unwrap_optimizer(optimizer)
+
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
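The layout that docstring describes can be pictured with a small, hypothetical reader; the index file name and its JSON schema here are assumptions for illustration:

import json
import os

def peek_optimizer_index(checkpoint_dir: str) -> None:
    # Assumed schema: an index JSON whose "weight_map" ties each param id
    # to the pytorch_optim-000XX.bin shard that stores its state tensors.
    index_path = os.path.join(checkpoint_dir, "pytorch_optim.bin.index.json")
    with open(index_path) as f:
        index = json.load(f)
    for param_id, shard_file in index["weight_map"].items():
        print(f"param {param_id} -> {shard_file}")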
@@ -121,9 +127,8 @@ class GeneralCheckpointIO(CheckpointIO):
             shard, current_size = shard_pair
             shard_file = get_shard_filename(states_name, idx)
             total_size = total_size + current_size
-            for param_id in shard.keys():
-                index_file.append_weight_map(str(param_id), shard_file)
-
+            for key in shard.keys():
+                index_file.append_weight_map(key, shard_file)
             checkpoint_file_path = os.path.join(checkpoint, shard_file)
             save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
 
@@ -177,7 +182,6 @@ class GeneralCheckpointIO(CheckpointIO):
             total_size = total_size + shard_pair[1]
             for key in shard.keys():
                 index_file.append_weight_map(key, shard_file)
-
             checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
             save_state_dict(shard, checkpoint_file_path, use_safetensors)
 