Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-04 02:57:20 +00:00)
[checkpoint] sharded optim save/load grad scaler (#1350)
parent 05fae1fd56
commit ce470ba37e
@@ -363,7 +363,18 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 self.master_params[p].trans_state(TensorState.HOLD)
 
+    def state_dict(self):
+        optim_state_dict = super().state_dict()
+        scaler_state_dict = self.grad_scaler.state_dict()
+        optim_state_dict['scaler'] = scaler_state_dict
+        return optim_state_dict
+
     def load_state_dict(self, *args, **kwargs):
+        if 'scaler' not in args[0]:
+            self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0])
+        else:
+            scaler_state_dict = args[0].pop('scaler')
+            self.grad_scaler.load_state_dict(scaler_state_dict)
         super().load_state_dict(*args, **kwargs)
         for group in self.optim.param_groups:
             for p in group['params']:
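The new state_dict() embeds the grad scaler state under a 'scaler' key, and load_state_dict() pops that key before delegating to the base optimizer, so loss-scale state survives a checkpoint round-trip. Below is a minimal usage sketch of that pattern; it is not code from the commit: plain torch.optim.SGD and torch.cuda.amp.GradScaler stand in for ShardedOptimizerV2 and ColossalAI's grad scaler.

import torch

# Stand-ins for the sketch: a plain optimizer and a disabled (CPU-safe) AMP scaler.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=False)

# Save: embed the scaler state inside the optimizer state dict, as the new
# state_dict() above does with optim_state_dict['scaler'].
state = optimizer.state_dict()
state['scaler'] = scaler.state_dict()

# Load: pop 'scaler' before handing the dict to the base optimizer, as the new
# load_state_dict() above does; warn if a checkpoint predates the change.
if 'scaler' not in state:
    print('Missing scaler when loading optimizer state dict')
else:
    scaler.load_state_dict(state.pop('scaler'))
optimizer.load_state_dict(state)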