diff --git a/colossalai/engine/paramhooks/__init__.py b/colossalai/engine/paramhooks/__init__.py
index 6bafd7b91..7e4239937 100644
--- a/colossalai/engine/paramhooks/__init__.py
+++ b/colossalai/engine/paramhooks/__init__.py
@@ -1,2 +1,3 @@
 from ._param_hookmgr import BaseParamHookMgr
-__all__ = ["BaseParamHookMgr"]
\ No newline at end of file
+
+__all__ = ["BaseParamHookMgr"]
diff --git a/colossalai/engine/paramhooks/_param_hookmgr.py b/colossalai/engine/paramhooks/_param_hookmgr.py
index 596a965eb..a1b995ccd 100644
--- a/colossalai/engine/paramhooks/_param_hookmgr.py
+++ b/colossalai/engine/paramhooks/_param_hookmgr.py
@@ -2,7 +2,9 @@ from typing import Callable, List
 import torch
 import functools
 
+
 class BaseParamHookMgr(object):
+
     def __init__(self, param_list: List[torch.nn.Parameter]) -> None:
         r"""
         register backward hook on every parameters of module
@@ -10,17 +12,17 @@ class BaseParamHookMgr(object):
         self._param_list = param_list
         self._hook_list = []
 
-    def register_backward_hooks(self, hook_call : Callable) -> None:
+    def register_backward_hooks(self, hook_call: Callable) -> None:
         r"""
-        The hook_call will be called every time a gradient with respect to the a param in self.param_list
-        is computed.
+        The hook_call will be called every time a gradient with respect to a param in self.param_list
+        is computed.
         The hook should have the following signature:
         ```
         hook(param, grad) -> Tensor or None
         ```
         """
         if not torch.is_grad_enabled():
-            return # don't register grad hooks if grad isn't enabled
+            return  # don't register grad hooks if grad isn't enabled
         for p in self._param_list:
             if p.requires_grad and not hasattr(p, '_base_param_hook'):
                 handle = p.register_hook(functools.partial(hook_call, p))
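
For context, a minimal usage sketch of the hook manager this diff reformats. It relies only on what the patch shows: `BaseParamHookMgr` is exported from `colossalai.engine.paramhooks`, and `register_backward_hooks` binds each parameter as the first argument of the hook via `functools.partial`, so the hook receives `(param, grad)`. The `nn.Linear` model and the `scale_grad` hook below are illustrative assumptions, not part of the patch.

```python
import torch
import torch.nn as nn

from colossalai.engine.paramhooks import BaseParamHookMgr

# Illustrative model; any nn.Module with trainable parameters works.
model = nn.Linear(4, 2)


# Hypothetical hook following the documented signature:
#   hook(param, grad) -> Tensor or None
# Here it simply scales every parameter gradient by 0.5.
def scale_grad(param: nn.Parameter, grad: torch.Tensor) -> torch.Tensor:
    return grad * 0.5


mgr = BaseParamHookMgr(list(model.parameters()))
mgr.register_backward_hooks(scale_grad)

loss = model(torch.randn(3, 4)).sum()
loss.backward()  # scale_grad fires once per parameter gradient
```

Because `torch.Tensor.register_hook` passes only the gradient, the partial application of the parameter is what gives the hook its two-argument `(param, grad)` form described in the docstring.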