[hotfix] fix sharded optim step and clip_grad_norm (#1226)

This commit is contained in:
ver217
2022-07-08 13:34:48 +08:00
committed by GitHub
parent f071b500b6
commit a45ddf2d5f
2 changed files with 10 additions and 4 deletions

View File

@@ -169,21 +169,27 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self.model.backward(loss)
def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
# This function is called except the last stage of pipeline parallel
# It receives the scaled grad from the previous rank
# No need to scale the grad again
# Need to unscale when optimizing
self.optim_state = OptimState.SCALED
self.model.backward_by_grad(tensor, grad)
def clip_grad_norm(self, model: nn.Module, max_norm: float):
if self.optim_state == OptimState.SCALED:
self._prepare_grads()
self._unscale_grads()
return super().clip_grad_norm(model, max_norm)
def step(self, *args, **kwargs):
self._prepare_grads()
self._maybe_move_fp32_shards()
# unscale grads if scaled
if self.optim_state == OptimState.SCALED:
self._prepare_grads()
self._unscale_grads()
self._maybe_move_fp32_shards()
found_inf = self._check_overflow()
self.grad_scaler.update(found_inf)