diff --git a/colossalai/zero/zero_optimizer.py b/colossalai/zero/zero_optimizer.py
index d93a6301d..245f69008 100644
--- a/colossalai/zero/zero_optimizer.py
+++ b/colossalai/zero/zero_optimizer.py
@@ -142,6 +142,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
     def clip_grad_norm(self, model: torch.nn.Module, max_norm: float):
         if self.optim_state == OptimState.SCALED:
             self._unscale_grads()
+        # TODO(ver217): fix zero clip grad norm
         return super().clip_grad_norm(model, max_norm)
 
     def backward(self, loss: torch.Tensor):
@@ -150,6 +151,11 @@ class ZeroOptimizer(ColossalaiOptimizer):
         self.module.backward(loss)
 
     def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
+        # This function is called on every pipeline stage except the last one.
+        # It receives an already scaled grad from the previous rank,
+        # so the grad must not be scaled again here,
+        # but it still has to be unscaled before the optimizer step.
+        self.optim_state = OptimState.SCALED
         self.module.backward_by_grad(tensor, grad)
 
     def _maybe_move_fp32_params(self):
@@ -184,7 +190,18 @@ class ZeroOptimizer(ColossalaiOptimizer):
             if isinstance(val, torch.Tensor):
                 self.chunk_manager.add_extern_static_tensor(val)
 
+    def state_dict(self):
+        optim_state_dict = super().state_dict()
+        scaler_state_dict = self.grad_scaler.state_dict()
+        optim_state_dict['scaler'] = scaler_state_dict
+        return optim_state_dict
+
     def load_state_dict(self, *args, **kwargs):
+        if 'scaler' not in args[0]:
+            self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0])
+        else:
+            scaler_state_dict = args[0].pop('scaler')
+            self.grad_scaler.load_state_dict(scaler_state_dict)
         super().load_state_dict(*args, **kwargs)
         for group in self.optim.param_groups:
             for p in group['params']:
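
A minimal usage sketch of the checkpoint round-trip these two methods enable (not part of the patch itself). It assumes an already constructed ZeroOptimizer instance named `optimizer` and a hypothetical file name: state_dict() now embeds the grad scaler state under the 'scaler' key, and load_state_dict() pops that key (logging a warning if it is missing) before delegating the remaining dict to the base class.

    import torch

    ckpt_path = 'zero_optimizer.pt'  # hypothetical path, for illustration only

    # Saving: the returned dict holds the usual optimizer state plus 'scaler'.
    state = optimizer.state_dict()
    assert 'scaler' in state
    torch.save(state, ckpt_path)

    # Loading: 'scaler' is restored into the grad scaler; the rest is passed
    # on to the underlying optimizer via super().load_state_dict().
    optimizer.load_state_dict(torch.load(ckpt_path))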