Fix/format colossalai/engine/paramhooks/(#350)

This commit is contained in:
Xu Kai 2022-03-09 17:28:17 +08:00 committed by Frank Lee
parent e83970e3dc
commit 54ee8d1254
2 changed files with 8 additions and 5 deletions

View File

@ -1,2 +1,3 @@
from ._param_hookmgr import BaseParamHookMgr from ._param_hookmgr import BaseParamHookMgr
__all__ = ["BaseParamHookMgr"]
__all__ = ["BaseParamHookMgr"]

View File

@ -2,7 +2,9 @@ from typing import Callable, List
import torch import torch
import functools import functools
class BaseParamHookMgr(object): class BaseParamHookMgr(object):
def __init__(self, param_list: List[torch.nn.Parameter]) -> None: def __init__(self, param_list: List[torch.nn.Parameter]) -> None:
r""" r"""
register backward hook on every parameters of module register backward hook on every parameters of module
@ -10,17 +12,17 @@ class BaseParamHookMgr(object):
self._param_list = param_list self._param_list = param_list
self._hook_list = [] self._hook_list = []
def register_backward_hooks(self, hook_call : Callable) -> None: def register_backward_hooks(self, hook_call: Callable) -> None:
r""" r"""
The hook_call will be called every time a gradient with respect to the a param in self.param_list The hook_call will be called every time a gradient with respect to the a param in self.param_list
is computed. is computed.
The hook should have the following signature: The hook should have the following signature:
``` ```
hook(param, grad) -> Tensor or None hook(param, grad) -> Tensor or None
``` ```
""" """
if not torch.is_grad_enabled(): if not torch.is_grad_enabled():
return # don't register grad hooks if grad isn't enabled return # don't register grad hooks if grad isn't enabled
for p in self._param_list: for p in self._param_list:
if p.requires_grad and not hasattr(p, '_base_param_hook'): if p.requires_grad and not hasattr(p, '_base_param_hook'):
handle = p.register_hook(functools.partial(hook_call, p)) handle = p.register_hook(functools.partial(hook_call, p))