diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 194cc165e..091a5b274 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -363,7 +363,18 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 self.master_params[p].trans_state(TensorState.HOLD)
 
+    def state_dict(self):
+        optim_state_dict = super().state_dict()
+        scaler_state_dict = self.grad_scaler.state_dict()
+        optim_state_dict['scaler'] = scaler_state_dict
+        return optim_state_dict
+
     def load_state_dict(self, *args, **kwargs):
+        if 'scaler' not in args[0]:
+            self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0])
+        else:
+            scaler_state_dict = args[0].pop('scaler')
+            self.grad_scaler.load_state_dict(scaler_state_dict)
         super().load_state_dict(*args, **kwargs)
         for group in self.optim.param_groups:
             for p in group['params']:
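
The patch makes the gradient scaler's state round-trip through the optimizer's checkpoint: state_dict() attaches it under a 'scaler' key, and load_state_dict() pops that key and restores the scaler before delegating the rest to the wrapped optimizer. A minimal usage sketch follows; it is not part of the diff and assumes an already-initialized ShardedOptimizerV2 instance named `optimizer` and a hypothetical checkpoint path.

    # Minimal sketch (assumptions: `optimizer` is a ShardedOptimizerV2,
    # 'checkpoint.pt' is a hypothetical path).
    import torch

    # Saving: state_dict() now bundles the gradient scaler's state under
    # the 'scaler' key alongside the inner optimizer's state.
    torch.save(optimizer.state_dict(), 'checkpoint.pt')

    # Loading: load_state_dict() pops 'scaler', restores the scaler, then
    # forwards the remaining dict to the wrapped optimizer. A checkpoint
    # saved before this change lacks the key, so only a warning is logged
    # on rank 0 and the scaler keeps its freshly-initialized state.
    optimizer.load_state_dict(torch.load('checkpoint.pt'))

Popping 'scaler' before calling super().load_state_dict() matters: the inner optimizer does not recognize the extra key, so passing it through unchanged could raise an error or be silently dropped.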