Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-07-05 19:48:23 +00:00
Fix/format colossalai/engine/paramhooks/ (#350)
This commit is contained in:
parent e83970e3dc
commit 54ee8d1254
colossalai/engine/paramhooks/__init__.py
@@ -1,2 +1,3 @@
 from ._param_hookmgr import BaseParamHookMgr
+
 __all__ = ["BaseParamHookMgr"]
colossalai/engine/paramhooks/_param_hookmgr.py
@@ -2,7 +2,9 @@ from typing import Callable, List
 import torch
 import functools
 
+
 class BaseParamHookMgr(object):
+
     def __init__(self, param_list: List[torch.nn.Parameter]) -> None:
         r"""
         register backward hook on every parameters of module
@@ -10,17 +12,17 @@ class BaseParamHookMgr(object):
         self._param_list = param_list
         self._hook_list = []
 
-    def register_backward_hooks(self, hook_call : Callable) -> None:
+    def register_backward_hooks(self, hook_call: Callable) -> None:
         r"""
         The hook_call will be called every time a gradient with respect to the a param in self.param_list
         is computed.
         The hook should have the following signature:
         ```
         hook(param, grad) -> Tensor or None
         ```
         """
         if not torch.is_grad_enabled():
             return  # don't register grad hooks if grad isn't enabled
         for p in self._param_list:
             if p.requires_grad and not hasattr(p, '_base_param_hook'):
                 handle = p.register_hook(functools.partial(hook_call, p))
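For context, here is a minimal usage sketch of the API this commit reformats (not part of the commit itself). It assumes `BaseParamHookMgr` is importable from `colossalai.engine.paramhooks`, as the `__init__.py` above suggests; the `Linear` model and `scale_grad` hook are illustrative stand-ins. Because each hook is registered via `functools.partial(hook_call, p)` on `torch.Tensor.register_hook`, the callable receives `(param, grad)` and may return a tensor to replace the gradient, or `None` to leave it unchanged.

```python
import torch

from colossalai.engine.paramhooks import BaseParamHookMgr

model = torch.nn.Linear(4, 2)   # illustrative module, not from the repo

def scale_grad(param, grad):
    # Matches the documented signature hook(param, grad) -> Tensor or None.
    # Returning a tensor replaces this parameter's gradient.
    return grad * 0.5

mgr = BaseParamHookMgr(list(model.parameters()))
mgr.register_backward_hooks(scale_grad)   # skips params that don't require grad

loss = model(torch.randn(3, 4)).sum()
loss.backward()   # scale_grad fires once each for weight and bias
```

Note the `hasattr(p, '_base_param_hook')` guard in the diff: it suggests repeated calls to `register_backward_hooks` will not stack hooks on a parameter that is already hooked.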