mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-31 16:40:41 +00:00
fix zero optim backward_by_grad and save/load (#1353)
This commit is contained in:
@@ -142,6 +142,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||
def clip_grad_norm(self, model: torch.nn.Module, max_norm: float):
|
||||
if self.optim_state == OptimState.SCALED:
|
||||
self._unscale_grads()
|
||||
# TODO(ver217): fix zero clip grad norm
|
||||
return super().clip_grad_norm(model, max_norm)
|
||||
|
||||
def backward(self, loss: torch.Tensor):
|
||||
@@ -150,6 +151,11 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||
self.module.backward(loss)
|
||||
|
||||
def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
|
||||
# This function is called except the last stage of pipeline parallel
|
||||
# It receives the scaled grad from the previous rank
|
||||
# No need to scale the grad again
|
||||
# Need to unscale when optimizing
|
||||
self.optim_state = OptimState.SCALED
|
||||
self.module.backward_by_grad(tensor, grad)
|
||||
|
||||
def _maybe_move_fp32_params(self):
|
||||
@@ -184,7 +190,18 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||
if isinstance(val, torch.Tensor):
|
||||
self.chunk_manager.add_extern_static_tensor(val)
|
||||
|
||||
def state_dict(self):
|
||||
optim_state_dict = super().state_dict()
|
||||
scaler_state_dict = self.grad_scaler.state_dict()
|
||||
optim_state_dict['scaler'] = scaler_state_dict
|
||||
return optim_state_dict
|
||||
|
||||
def load_state_dict(self, *args, **kwargs):
|
||||
if 'scaler' not in args[0]:
|
||||
self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0])
|
||||
else:
|
||||
scaler_state_dict = args[0].pop('scaler')
|
||||
self.grad_scaler.load_state_dict(scaler_state_dict)
|
||||
super().load_state_dict(*args, **kwargs)
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
|
Reference in New Issue
Block a user