[checkpointio] sharded optimizer checkpoint for DDP plugin (#4002)

Author: Baizhou Zhang
Date: 2023-06-16 14:14:05 +08:00
Committed by: GitHub
Parent: 725af3eeeb
Commit: 822c3d4d66
6 changed files with 79 additions and 34 deletions


@@ -139,6 +139,12 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
         state_size = 0
         isDTensor = False
         for state_tensor in state.values():
+
+            # When state_tensor is None (e.g., a SGD optimizer with momentum set to 0),
+            # The calculation of tensor size should be skipped to avoid error.
+            if state_tensor is None:
+                continue
+
             # If the states are stored as DTensors, mark isDTensor as true.
             if type(state_tensor) == DTensor:
                 isDTensor = True
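
The new guard matters because an optimizer state entry can legitimately hold None instead of a tensor, in which case the size calculation that follows would fail. A minimal, self-contained sketch of the same idea; the state_size_in_bytes helper and the example state dicts below are illustrative only, not the function used in this file:

import torch

def state_size_in_bytes(state: dict) -> int:
    """Sum byte sizes of the tensor entries in one parameter's optimizer state,
    skipping None values and non-tensor scalars such as step counters."""
    total = 0
    for state_tensor in state.values():
        if state_tensor is None:
            # e.g. SGD with momentum=0 can leave 'momentum_buffer' as None
            continue
        if not torch.is_tensor(state_tensor):
            continue
        total += state_tensor.numel() * state_tensor.element_size()
    return total

# A state as SGD(momentum=0) may leave it: no crash, contributes 0 bytes.
print(state_size_in_bytes({"momentum_buffer": None}))
# A state with a real buffer: 16 float32 values -> 64 bytes.
print(state_size_in_bytes({"momentum_buffer": torch.zeros(16)}))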
@@ -271,7 +277,7 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
     return id_map
 
 
-def load_states_into_optimizer(optimzier: Optimizer, state_dict: dict, id_map: dict):
+def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict):
     r"""Copies states from `state_dict` into an Optimizer object.
 
     Args:
@@ -311,10 +317,16 @@ def load_states_into_optimizer(optimzier: Optimizer, state_dict: dict, id_map: d
         else:
             new_states[k] = v
 
-    optimzier.state.update(new_states)
+    optimizer.state.update(new_states)
+
+
+def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
+    r"""Do the cleaning up work after state_dict has been loaded into optimizer
+
+    Args:
+        optimizer(Optimizer): An optimizer object whose state has just been loaded.
+    """
+
+    # Do the cleaning up as in src code of Pytorch.
+    optimizer._hook_for_profile()    # To support multiprocessing pickle/unpickle.
+    optimizer.defaults.setdefault('differentiable', False)
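
For context, here is a sketch of how a checkpoint loader might compose these utilities when restoring a sharded optimizer checkpoint. The restore_sharded_optimizer wrapper, the shard file list, and the torch.load loop are assumptions for illustration; only the three utility functions and their signatures come from this diff, and the import path colossalai.checkpoint_io.utils is assumed:

import torch
from torch.optim import Optimizer
from colossalai.checkpoint_io.utils import (
    load_param_groups_into_optimizer,
    load_states_into_optimizer,
    sharded_optimizer_loading_epilogue,
)

def restore_sharded_optimizer(optimizer: Optimizer, param_group_path: str, shard_paths: list):
    # 1. Rebuild param groups and obtain the mapping from saved param ids
    #    to the live parameters of this optimizer.
    id_map = load_param_groups_into_optimizer(optimizer, param_group_path)

    # 2. Load each state shard and merge its states into optimizer.state.
    for shard_path in shard_paths:
        state_shard = torch.load(shard_path, map_location="cpu")
        load_states_into_optimizer(optimizer, state_shard, id_map)

    # 3. Final cleanup, mirroring what PyTorch does after load_state_dict.
    sharded_optimizer_loading_epilogue(optimizer)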