Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 20:10:17 +00:00)
[checkpointio] sharded optimizer checkpoint for DDP plugin (#4002)
@@ -139,6 +139,12 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
         state_size = 0
         isDTensor = False
         for state_tensor in state.values():
+
+            # When state_tensor is None (e.g., an SGD optimizer with momentum set to 0),
+            # the calculation of tensor size should be skipped to avoid errors.
+            if state_tensor is None:
+                continue
+
             # If the states are stored as DTensors, mark isDTensor as true.
             if type(state_tensor) == DTensor:
                 isDTensor = True
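The guard above matters because an optimizer state entry can legitimately hold None: as the comment notes, SGD with momentum set to 0 keeps a None momentum buffer, and asking None for its size would raise. Below is a minimal, hypothetical sketch of the size bookkeeping this hunk protects; calc_state_size and the plain-tensor branch are illustrative assumptions, not the library's implementation.

import torch

def calc_state_size(state: dict) -> int:
    """Sum the byte size of the tensors in one parameter's optimizer state."""
    state_size = 0
    for state_tensor in state.values():
        # SGD with momentum=0 can leave 'momentum_buffer' as None; None has no
        # numel()/element_size(), so it must be skipped rather than measured.
        if state_tensor is None:
            continue
        if torch.is_tensor(state_tensor):
            state_size += state_tensor.numel() * state_tensor.element_size()
    return state_size

print(calc_state_size({"momentum_buffer": None}))         # 0
print(calc_state_size({"exp_avg": torch.zeros(10, 10)}))  # 400 (100 fp32 elements)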
@@ -271,7 +277,7 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
     return id_map
 
 
-def load_states_into_optimizer(optimzier: Optimizer, state_dict: dict, id_map: dict):
+def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict):
     r"""Copies states from `state_dict` into an Optimizer object.
 
     Args:
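For context, the id_map returned by load_param_groups_into_optimizer and consumed by load_states_into_optimizer ties the parameter ids recorded in the checkpoint back to the live parameters of the running optimizer, in the same spirit as the mapping torch.optim.Optimizer.load_state_dict builds internally. A hypothetical sketch of that idea follows; build_id_map and saved_param_groups are illustrative names, not part of the diff.

from itertools import chain
from torch.optim import Optimizer

def build_id_map(optimizer: Optimizer, saved_param_groups: list) -> dict:
    # Zip the parameter ids stored in the checkpoint's param groups with the
    # parameters of the current optimizer, group by group and in order.
    saved_ids = chain.from_iterable(g["params"] for g in saved_param_groups)
    live_params = chain.from_iterable(g["params"] for g in optimizer.param_groups)
    return {saved_id: param for saved_id, param in zip(saved_ids, live_params)}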
@@ -311,10 +317,16 @@ def load_states_into_optimizer(optimzier: Optimizer, state_dict: dict, id_map: d
         else:
             new_states[k] = v
 
-    optimzier.state.update(new_states)
+    optimizer.state.update(new_states)
 
 
+def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
+    r"""Do the cleaning up work after state_dict has been loaded into optimizer
+
+    Args:
+        optimizer(Optimizer): An optimizer object whose state has just been loaded.
+    """
+
+    # Do the cleaning up as in src code of Pytorch.
+    optimizer._hook_for_profile()  # To support multiprocessing pickle/unpickle.
+    optimizer.defaults.setdefault('differentiable', False)
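Taken together, the diff suggests a three-step restore path: rebuild the param groups, merge each state shard into optimizer.state, then run the epilogue added in this commit. A hypothetical usage sketch under assumed details follows; the import path colossalai.checkpoint_io.utils, the 'param_groups.bin' file name, and the shard-index layout are illustrative, not taken from the diff.

import json
import os

import torch

# Assumed import location for the helpers shown in the diff.
from colossalai.checkpoint_io.utils import (load_param_groups_into_optimizer, load_states_into_optimizer,
                                            sharded_optimizer_loading_epilogue)

def load_sharded_optimizer(optimizer, checkpoint_dir: str) -> None:
    # 1. Restore param groups first; this yields the id_map that links
    #    checkpointed parameter ids to the optimizer's live parameters.
    id_map = load_param_groups_into_optimizer(optimizer, os.path.join(checkpoint_dir, 'param_groups.bin'))

    # 2. Load every state shard and merge it into optimizer.state.
    with open(os.path.join(checkpoint_dir, 'optimizer.index.json')) as f:
        shard_files = json.load(f)['shards']
    for shard_file in shard_files:
        state_dict = torch.load(os.path.join(checkpoint_dir, shard_file), map_location='cpu')
        load_states_into_optimizer(optimizer, state_dict, id_map)

    # 3. Finish with the epilogue introduced in this commit: it re-registers the
    #    profiling hook and fills in the 'differentiable' default.
    sharded_optimizer_loading_epilogue(optimizer)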