mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-26 12:14:02 +00:00
[hotfix] fix memory leak in zero (#781)
This commit is contained in:
@@ -12,7 +12,7 @@ __all__ = ['BaseGradScaler']
|
||||
|
||||
class BaseGradScaler(ABC):
|
||||
|
||||
def __init__(self, initial_scale: int, verbose: bool):
|
||||
def __init__(self, initial_scale: float, verbose: bool):
|
||||
assert initial_scale > 0
|
||||
self._scale = torch.cuda.FloatTensor([initial_scale])
|
||||
self._verbose = verbose
|
||||
@@ -31,6 +31,7 @@ class BaseGradScaler(ABC):
|
||||
def state_dict(self) -> Dict:
|
||||
state_dict = dict()
|
||||
state_dict['scale'] = self.scale
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict: Dict) -> None:
|
||||
self._scale = state_dict['scale']
|
||||
|
Reference in New Issue
Block a user