mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[Gemini] ParamOpHook -> ColoParamOpHook (#2080)
This commit is contained in:
@@ -8,7 +8,7 @@ from colossalai.tensor.colo_tensor import ColoTensor
|
||||
from colossalai.tensor.tensor_spec import ColoTensorSpec
|
||||
|
||||
|
||||
class ParamOpHook(ABC):
|
||||
class ColoParamOpHook(ABC):
|
||||
"""Hook which is triggered by each operation when operands contain ColoParameter.
|
||||
To customize it, you must inherit this abstract class, and implement ``pre_forward``,
|
||||
``post_forward``, ``pre_backward`` and ``post_backward``. These four methods take a list
|
||||
@@ -32,68 +32,68 @@ class ParamOpHook(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class ParamOpHookManager:
|
||||
class ColoParamOpHookManager:
|
||||
"""Manage your param op hooks. It only has static methods.
|
||||
The only static method you should call is ``use_hooks(*hooks)``.
|
||||
"""
|
||||
hooks: Tuple[ParamOpHook, ...] = tuple()
|
||||
hooks: Tuple[ColoParamOpHook, ...] = tuple()
|
||||
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def use_hooks(*hooks: ParamOpHook):
|
||||
def use_hooks(*hooks: ColoParamOpHook):
|
||||
"""Change the param op hooks you use. Nested calling is allowed.
|
||||
|
||||
Example:
|
||||
>>> with ParamOpHookManager.use_hooks(*hooks):
|
||||
>>> with ColoParamOpHookManager.use_hooks(*hooks):
|
||||
>>> do_something()
|
||||
>>> with ParamOpHookManager.use_hooks():
|
||||
>>> with ColoParamOpHookManager.use_hooks():
|
||||
>>> // clear hooks
|
||||
>>> do_something()
|
||||
"""
|
||||
try:
|
||||
old_param_op_hooks = ParamOpHookManager.hooks
|
||||
ParamOpHookManager.hooks = hooks
|
||||
old_param_op_hooks = ColoParamOpHookManager.hooks
|
||||
ColoParamOpHookManager.hooks = hooks
|
||||
yield
|
||||
finally:
|
||||
ParamOpHookManager.hooks = old_param_op_hooks
|
||||
ColoParamOpHookManager.hooks = old_param_op_hooks
|
||||
|
||||
@staticmethod
|
||||
def _trigger_pre_forward(params: List[torch.Tensor]) -> None:
|
||||
for hook in ParamOpHookManager.hooks:
|
||||
for hook in ColoParamOpHookManager.hooks:
|
||||
hook.pre_forward(params)
|
||||
|
||||
@staticmethod
|
||||
def _trigger_post_forward(params: List[torch.Tensor]) -> None:
|
||||
for hook in ParamOpHookManager.hooks:
|
||||
for hook in ColoParamOpHookManager.hooks:
|
||||
hook.post_forward(params)
|
||||
|
||||
@staticmethod
|
||||
def _trigger_pre_backward(params: List[torch.Tensor]) -> None:
|
||||
for hook in ParamOpHookManager.hooks:
|
||||
for hook in ColoParamOpHookManager.hooks:
|
||||
hook.pre_backward(params)
|
||||
|
||||
@staticmethod
|
||||
def _trigger_post_backward(params: List[torch.Tensor]) -> None:
|
||||
for hook in ParamOpHookManager.hooks:
|
||||
for hook in ColoParamOpHookManager.hooks:
|
||||
hook.post_backward(params)
|
||||
|
||||
@staticmethod
|
||||
def pre_op(params: List[torch.Tensor], *args: Any) -> list:
|
||||
ParamOpHookManager._trigger_pre_forward(params)
|
||||
ColoParamOpHookManager._trigger_pre_forward(params)
|
||||
args_info = _get_colo_tensors_info(*args)
|
||||
rets = PreFwdPostBwd.apply(params, *args)
|
||||
return _update_colo_tensors(args_info, *rets)
|
||||
|
||||
@staticmethod
|
||||
def post_op(params: List[torch.Tensor], arg: Any) -> Any:
|
||||
ParamOpHookManager._trigger_post_forward(params)
|
||||
ColoParamOpHookManager._trigger_post_forward(params)
|
||||
arg_info = _get_colo_tensors_info(arg)
|
||||
ret = PostFwdPreBwd.apply(params, arg)
|
||||
return _unpack_args(_update_colo_tensors(arg_info, ret))
|
||||
|
||||
@staticmethod
|
||||
def has_hook() -> bool:
|
||||
return len(ParamOpHookManager.hooks) > 0
|
||||
return len(ColoParamOpHookManager.hooks) > 0
|
||||
|
||||
|
||||
class PreFwdPostBwd(torch.autograd.Function):
|
||||
@@ -105,7 +105,7 @@ class PreFwdPostBwd(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grads):
|
||||
ParamOpHookManager._trigger_post_backward(ctx.params)
|
||||
ColoParamOpHookManager._trigger_post_backward(ctx.params)
|
||||
return (None,) + grads
|
||||
|
||||
|
||||
@@ -118,7 +118,7 @@ class PostFwdPreBwd(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grads):
|
||||
ParamOpHookManager._trigger_pre_backward(ctx.params)
|
||||
ColoParamOpHookManager._trigger_pre_backward(ctx.params)
|
||||
return (None,) + grads
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user