Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 03:52:01 +00:00)
[tensor] refactor param op hook (#1097)
* refactor param op hook
* add docstr
* fix bug
@@ -1,10 +1,15 @@
 import torch
 from contextlib import contextmanager
 from abc import ABC, abstractmethod
-from typing import List, Tuple
+from typing import List, Tuple, Any
 
 
 class ParamOpHook(ABC):
+    """Hook which is triggered by each operation when operands contain ColoParameter.
+    To customize it, you must inherit this abstract class, and implement ``pre_forward``,
+    ``post_forward``, ``pre_backward`` and ``post_backward``. These four methods take a list
+    of ColoParameter.
+    """
 
     @abstractmethod
     def pre_forward(self, params: List[torch.Tensor]) -> None:
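Note: the new docstring names the four methods a subclass must implement. A minimal sketch of such a hook follows; the class name PrintParamHook is hypothetical, the import path is assumed from the file this commit edits (colossalai/tensor/param_op_hook.py), and plain torch.Tensor stands in for ColoParameter:

import torch
from typing import List

from colossalai.tensor.param_op_hook import ParamOpHook  # assumed path


class PrintParamHook(ParamOpHook):
    # hypothetical toy hook: logs how many parameters each op touches

    def pre_forward(self, params: List[torch.Tensor]) -> None:
        print(f'pre_forward on {len(params)} param(s)')

    def post_forward(self, params: List[torch.Tensor]) -> None:
        print(f'post_forward on {len(params)} param(s)')

    def pre_backward(self, params: List[torch.Tensor]) -> None:
        print(f'pre_backward on {len(params)} param(s)')

    def post_backward(self, params: List[torch.Tensor]) -> None:
        print(f'post_backward on {len(params)} param(s)')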
@@ -23,25 +28,78 @@ class ParamOpHook(ABC):
         pass
 
 
-class _ParamOpHookWrapper:
+class ParamOpHookManager:
+    """Manage your param op hooks. It only has static methods.
+    The only static method you should call is ``use_hooks(*hooks)``.
+    """
     hooks: Tuple[ParamOpHook, ...] = tuple()
 
+    @staticmethod
+    @contextmanager
+    def use_hooks(*hooks: ParamOpHook):
+        """Change the param op hooks you use. Nested calling is allowed.
+
+        Example::
+            >>> with ParamOpHookManager.use_hooks(*hooks):
+            >>>     do_something()
+            >>>     with ParamOpHookManager.use_hooks():
+            >>>         # clear hooks
+            >>>         do_something()
+        """
+        try:
+            old_param_op_hooks = ParamOpHookManager.hooks
+            ParamOpHookManager.hooks = hooks
+            yield
+        finally:
+            ParamOpHookManager.hooks = old_param_op_hooks
+
+    @staticmethod
+    def _trigger_pre_forward(params: List[torch.Tensor]) -> None:
+        for hook in ParamOpHookManager.hooks:
+            hook.pre_forward(params)
+
+    @staticmethod
+    def _trigger_post_forward(params: List[torch.Tensor]) -> None:
+        for hook in ParamOpHookManager.hooks:
+            hook.post_forward(params)
+
+    @staticmethod
+    def _trigger_pre_backward(params: List[torch.Tensor]) -> None:
+        for hook in ParamOpHookManager.hooks:
+            hook.pre_backward(params)
+
+    @staticmethod
+    def _trigger_post_backward(params: List[torch.Tensor]) -> None:
+        for hook in ParamOpHookManager.hooks:
+            hook.post_backward(params)
+
+    @staticmethod
+    def pre_op(params: List[torch.Tensor], *args: Any) -> Any:
+        ParamOpHookManager._trigger_pre_forward(params)
+        return PreFwdPostBwd.apply(params, *args)
+
+    @staticmethod
+    def post_op(params: List[torch.Tensor], args: Any) -> Any:
+        ParamOpHookManager._trigger_post_forward(params)
+        return PostFwdPreBwd.apply(params, args)
+
+    @staticmethod
+    def has_hook() -> bool:
+        return len(ParamOpHookManager.hooks) > 0
 
 
 class PreFwdPostBwd(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, params, *args):
         ctx.params = params
-        for hook in _ParamOpHookWrapper.hooks:
-            hook.pre_forward(ctx.params)
         if len(args) == 1:
             return args[0]
         return args
 
     @staticmethod
     def backward(ctx, *grads):
-        for hook in _ParamOpHookWrapper.hooks:
-            hook.post_backward(ctx.params)
+        ParamOpHookManager._trigger_post_backward(ctx.params)
         return (None,) + grads
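Note: a usage sketch for the new manager, reusing the hypothetical PrintParamHook from above. It exercises only what this hunk shows: use_hooks swaps the hook tuple for the duration of the block, an empty nested call clears hooks, and the previous tuple is restored on exit:

from colossalai.tensor.param_op_hook import ParamOpHookManager  # assumed path

hook = PrintParamHook()  # hypothetical hook sketched earlier

assert not ParamOpHookManager.has_hook()
with ParamOpHookManager.use_hooks(hook):
    assert ParamOpHookManager.has_hook()
    # ops whose operands contain ColoParameter would trigger `hook` here
    with ParamOpHookManager.use_hooks():
        # nested no-arg call temporarily clears all hooks
        assert not ParamOpHookManager.has_hook()
    # leaving the inner block restores (hook,)
    assert ParamOpHookManager.has_hook()
assert not ParamOpHookManager.has_hook()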
@@ -50,22 +108,9 @@ class PostFwdPreBwd(torch.autograd.Function):
     @staticmethod
     def forward(ctx, params, args):
         ctx.params = params
-        for hook in _ParamOpHookWrapper.hooks:
-            hook.post_forward(params)
         return args
 
     @staticmethod
     def backward(ctx, *grads):
-        for hook in _ParamOpHookWrapper.hooks:
-            hook.pre_backward(ctx.params)
+        ParamOpHookManager._trigger_pre_backward(ctx.params)
         return (None,) + grads
-
-
-@contextmanager
-def use_param_op_hooks(*hooks: ParamOpHook):
-    try:
-        old_param_op_hooks = _ParamOpHookWrapper.hooks
-        _ParamOpHookWrapper.hooks = hooks
-        yield
-    finally:
-        _ParamOpHookWrapper.hooks = old_param_op_hooks
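Note: the net effect of this hunk and the previous one is that the forward-side hooks now fire eagerly in pre_op/post_op instead of inside Function.forward, leaving the two autograd nodes to fire only the backward-side hooks. The hypothetical helper run_with_hooks below (not part of this commit) sketches how a caller might sandwich a parameter op with the new API, assuming op takes a single tensor:

import torch

def run_with_hooks(op, params, arg: torch.Tensor) -> torch.Tensor:
    # hypothetical dispatcher-side helper for illustration only
    if not ParamOpHookManager.has_hook():
        return op(arg)
    # fires pre_forward now; the PreFwdPostBwd node inserted here will fire
    # post_backward when gradients flow back through it
    arg = ParamOpHookManager.pre_op(params, arg)
    out = op(arg)
    # fires post_forward now; the PostFwdPreBwd node inserted here will fire
    # pre_backward on the backward pass
    return ParamOpHookManager.post_op(params, out)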