[hotfix] fix memory leak in zero (#781)

Author: HELSON
Date:   2022-04-18 13:57:03 +08:00 (committed by GitHub)
Parent: 4b01da24cd
Commit: 4c4388c46e
6 changed files with 32 additions and 36 deletions

@@ -12,7 +12,7 @@ __all__ = ['BaseGradScaler']
 class BaseGradScaler(ABC):
-    def __init__(self, initial_scale: int, verbose: bool):
+    def __init__(self, initial_scale: float, verbose: bool):
         assert initial_scale > 0
         self._scale = torch.cuda.FloatTensor([initial_scale])
         self._verbose = verbose
@@ -31,6 +31,7 @@ class BaseGradScaler(ABC):
     def state_dict(self) -> Dict:
         state_dict = dict()
         state_dict['scale'] = self.scale
         return state_dict

     def load_state_dict(self, state_dict: Dict) -> None:
         self._scale = state_dict['scale']
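
For context, the hunks above touch a gradient scaler whose scale lives on the GPU as a one-element CUDA tensor and is round-tripped through state_dict/load_state_dict. Below is a minimal sketch reconstructed only from the lines visible in this diff; the abstract update() stub, the DummyGradScaler subclass, and the usage lines at the bottom are illustrative assumptions, not part of the commit.

    # sketch of the scaler shown in the diff; requires a CUDA-capable torch build
    from abc import ABC, abstractmethod
    from typing import Dict

    import torch
    from torch import Tensor


    class BaseGradScaler(ABC):

        def __init__(self, initial_scale: float, verbose: bool):
            # after the fix, initial_scale is annotated as float (e.g. 2.0 ** 16)
            assert initial_scale > 0
            self._scale = torch.cuda.FloatTensor([initial_scale])
            self._verbose = verbose

        @property
        def scale(self) -> Tensor:
            return self._scale

        def state_dict(self) -> Dict:
            # the scale tensor is the only entry serialized here
            state_dict = dict()
            state_dict['scale'] = self.scale
            return state_dict

        def load_state_dict(self, state_dict: Dict) -> None:
            self._scale = state_dict['scale']

        @abstractmethod
        def update(self, overflow: bool) -> None:
            # assumed hook: concrete scalers adjust the scale after each step
            ...


    class DummyGradScaler(BaseGradScaler):
        # hypothetical subclass used only to exercise the base API
        def update(self, overflow: bool) -> None:
            pass


    if __name__ == '__main__':
        scaler = DummyGradScaler(initial_scale=2.0 ** 16, verbose=False)
        ckpt = scaler.state_dict()    # {'scale': tensor([65536.], device='cuda:0')}
        scaler.load_state_dict(ckpt)  # restores the scale tensor from the checkpoint

Note that with this scheme the checkpointed 'scale' entry is itself a CUDA tensor, so saving and reloading it keeps the value on the GPU rather than copying it back as a Python number.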