[tensor] refactor param op hook (#1097)

* refactor param op hook

* add docstr

* fix bug
This commit is contained in:
ver217
2022-06-13 16:11:53 +08:00
committed by GitHub
parent 1e9f9c227f
commit 895c1c5ee7
4 changed files with 76 additions and 31 deletions

View File

@@ -3,7 +3,7 @@ from colossalai.tensor.const import TensorType
import torch
from colossalai.tensor import TensorSpec, distspec
from copy import copy
from colossalai.tensor.param_op_hook import _ParamOpHookWrapper, PreFwdPostBwd, PostFwdPreBwd
from colossalai.tensor.param_op_hook import ParamOpHookManager
from typing import Optional
@@ -48,17 +48,17 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
@classmethod
def __torch_function__(cls, func, types, args=..., kwargs=None):
if len(_ParamOpHookWrapper.hooks) > 0:
if ParamOpHookManager.has_hook():
if not func.__name__.startswith('__'):
params = list(filter(lambda arg: isinstance(arg, ColoParameter), args))
if kwargs is not None:
params.extend(list(filter(lambda arg: isinstance(arg, ColoParameter), kwargs.values())))
if len(params) > 0:
with torch._C.DisableTorchFunction():
args = PreFwdPostBwd.apply(params, *args)
args = ParamOpHookManager.pre_op(params, *args)
ret = super().__torch_function__(func, types, args, kwargs)
with torch._C.DisableTorchFunction():
ret = PostFwdPreBwd.apply(params, ret)
ret = ParamOpHookManager.post_op(params, ret)
return ret
return super().__torch_function__(func, types, args, kwargs)