[checkpointio] optimize zero optim checkpoint io (#4591)

* [zero] update checkpoint io to save memory

* [checkpointio] add device map to save memory
This commit is contained in:
Hongxin Liu
2023-09-04 11:26:45 +08:00
committed by GitHub
parent cfa607080f
commit 63ecafb1fb
4 changed files with 43 additions and 22 deletions

View File

@@ -78,8 +78,6 @@ class GeneralCheckpointIO(CheckpointIO):
for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
load_states_into_optimizer(optimizer, state_dict, id_map)
del state_dict
gc.collect()
sharded_optimizer_loading_epilogue(optimizer)