refactored grad scaler (#338)

This commit is contained in:
Frank Lee
2022-03-09 11:52:43 +08:00
parent 6a3188167c
commit 3d5d64bd10
4 changed files with 135 additions and 0 deletions

View File

@@ -0,0 +1,16 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from .base_grad_scaler import BaseGradScaler
__all__ = ['ConstantGradScaler']
class ConstantGradScaler(BaseGradScaler):
def __init__(self, initial_scale: int, verbose: bool):
super().__init__(initial_scale, verbose)
self.log(f"Constant Gradient Scaler is initialized with scale {self.scale}", ranks=[0])
def update(self, overflow: bool) -> None:
# do nothing to maintain the current scale value
pass