From ce470ba37ed91b74c13ab5a81aaa64e70285e125 Mon Sep 17 00:00:00 2001
From: ver217
Date: Thu, 21 Jul 2022 15:21:21 +0800
Subject: [PATCH] [checkpoint] sharded optim save/load grad scaler (#1350)

---
 colossalai/zero/sharded_optim/sharded_optim_v2.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 194cc165e..091a5b274 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -363,7 +363,18 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 self.master_params[p].trans_state(TensorState.HOLD)
 
+    def state_dict(self):
+        optim_state_dict = super().state_dict()
+        scaler_state_dict = self.grad_scaler.state_dict()
+        optim_state_dict['scaler'] = scaler_state_dict
+        return optim_state_dict
+
     def load_state_dict(self, *args, **kwargs):
+        if 'scaler' not in args[0]:
+            self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0])
+        else:
+            scaler_state_dict = args[0].pop('scaler')
+            self.grad_scaler.load_state_dict(scaler_state_dict)
         super().load_state_dict(*args, **kwargs)
         for group in self.optim.param_groups:
             for p in group['params']:
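
For context, a minimal usage sketch of the extended `state_dict`/`load_state_dict` pair from a training script. The `optimizer` instance and the checkpoint path below are assumptions for illustration, not part of the patch; `optimizer` is assumed to already be a `ShardedOptimizerV2` built the usual way around a torch optimizer.

```python
import torch

# Hypothetical checkpoint path, illustrative only.
CKPT_PATH = 'sharded_optim_ckpt.pt'

# Saving: after this patch, state_dict() also carries the grad scaler state
# under the 'scaler' key, so a plain torch.save round-trips the loss-scale state.
state = optimizer.state_dict()
assert 'scaler' in state  # key added by the patch
torch.save(state, CKPT_PATH)

# Loading: load_state_dict() pops the 'scaler' entry, restores the grad scaler,
# and only then delegates the remaining dict to the parent optimizer.
optimizer.load_state_dict(torch.load(CKPT_PATH))
```

Nesting the scaler state under a single `'scaler'` key keeps the checkpoint one dict, and popping it before calling `super().load_state_dict` means older checkpoints written without the key still load, producing only a rank-0 warning instead of an error.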