[tensor] refactor param op hook (#1097)

* refactor param op hook

* add docstr

* fix bug
This commit is contained in:
ver217
2022-06-13 16:11:53 +08:00
committed by GitHub
parent 1e9f9c227f
commit 895c1c5ee7
4 changed files with 76 additions and 31 deletions

View File

@@ -1,10 +1,15 @@
import torch
from contextlib import contextmanager
from abc import ABC, abstractmethod
from typing import List, Tuple
from typing import List, Tuple, Any
class ParamOpHook(ABC):
"""Hook which is triggered by each operation when operands contain ColoParameter.
To customize it, you must inherit this abstract class, and implement ``pre_forward``,
``post_forward``, ``pre_backward`` and ``post_backward``. These four methods take a list
of ColoParameter.
"""
@abstractmethod
def pre_forward(self, params: List[torch.Tensor]) -> None:
@@ -23,25 +28,78 @@ class ParamOpHook(ABC):
pass
class _ParamOpHookWrapper:
class ParamOpHookManager:
"""Manage your param op hooks. It only has static methods.
The only static method you should call is ``use_hooks(*hooks)``.
"""
hooks: Tuple[ParamOpHook, ...] = tuple()
@staticmethod
@contextmanager
def use_hooks(*hooks: ParamOpHook):
"""Change the param op hooks you use. Nested calling is allowed.
Example::
>>> with ParamOpHookManager.use_hooks(*hooks):
>>> do_something()
>>> with ParamOpHookManager.use_hooks():
>>> // clear hooks
>>> do_something()
"""
try:
old_param_op_hooks = ParamOpHookManager.hooks
ParamOpHookManager.hooks = hooks
yield
finally:
ParamOpHookManager.hooks = old_param_op_hooks
@staticmethod
def _trigger_pre_forward(params: List[torch.Tensor]) -> None:
for hook in ParamOpHookManager.hooks:
hook.pre_forward(params)
@staticmethod
def _trigger_post_forward(params: List[torch.Tensor]) -> None:
for hook in ParamOpHookManager.hooks:
hook.post_forward(params)
@staticmethod
def _trigger_pre_backward(params: List[torch.Tensor]) -> None:
for hook in ParamOpHookManager.hooks:
hook.pre_backward(params)
@staticmethod
def _trigger_post_backward(params: List[torch.Tensor]) -> None:
for hook in ParamOpHookManager.hooks:
hook.post_backward(params)
@staticmethod
def pre_op(params: List[torch.Tensor], *args: Any) -> Any:
ParamOpHookManager._trigger_pre_forward(params)
return PreFwdPostBwd.apply(params, *args)
@staticmethod
def post_op(params: List[torch.Tensor], args: Any) -> Any:
ParamOpHookManager._trigger_post_forward(params)
return PostFwdPreBwd.apply(params, args)
@staticmethod
def has_hook() -> bool:
return len(ParamOpHookManager.hooks) > 0
class PreFwdPostBwd(torch.autograd.Function):
@staticmethod
def forward(ctx, params, *args):
ctx.params = params
for hook in _ParamOpHookWrapper.hooks:
hook.pre_forward(ctx.params)
if len(args) == 1:
return args[0]
return args
@staticmethod
def backward(ctx, *grads):
for hook in _ParamOpHookWrapper.hooks:
hook.post_backward(ctx.params)
ParamOpHookManager._trigger_post_backward(ctx.params)
return (None,) + grads
@@ -50,22 +108,9 @@ class PostFwdPreBwd(torch.autograd.Function):
@staticmethod
def forward(ctx, params, args):
ctx.params = params
for hook in _ParamOpHookWrapper.hooks:
hook.post_forward(params)
return args
@staticmethod
def backward(ctx, *grads):
for hook in _ParamOpHookWrapper.hooks:
hook.pre_backward(ctx.params)
ParamOpHookManager._trigger_pre_backward(ctx.params)
return (None,) + grads
@contextmanager
def use_param_op_hooks(*hooks: ParamOpHook):
try:
old_param_op_hooks = _ParamOpHookWrapper.hooks
_ParamOpHookWrapper.hooks = hooks
yield
finally:
_ParamOpHookWrapper.hooks = old_param_op_hooks