[tensor] refactor param op hook (#1097)

* refactor param op hook
* add docstr
* fix bug
@@ -3,7 +3,7 @@ from colossalai.tensor.const import TensorType
 import torch
 from colossalai.tensor import TensorSpec, distspec
 from copy import copy
-from colossalai.tensor.param_op_hook import _ParamOpHookWrapper, PreFwdPostBwd, PostFwdPreBwd
+from colossalai.tensor.param_op_hook import ParamOpHookManager
 from typing import Optional
 
 
@@ -48,17 +48,17 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
 
     @classmethod
     def __torch_function__(cls, func, types, args=..., kwargs=None):
-        if len(_ParamOpHookWrapper.hooks) > 0:
+        if ParamOpHookManager.has_hook():
             if not func.__name__.startswith('__'):
                 params = list(filter(lambda arg: isinstance(arg, ColoParameter), args))
                 if kwargs is not None:
                     params.extend(list(filter(lambda arg: isinstance(arg, ColoParameter), kwargs.values())))
                 if len(params) > 0:
                     with torch._C.DisableTorchFunction():
-                        args = PreFwdPostBwd.apply(params, *args)
+                        args = ParamOpHookManager.pre_op(params, *args)
                     ret = super().__torch_function__(func, types, args, kwargs)
                     with torch._C.DisableTorchFunction():
-                        ret = PostFwdPreBwd.apply(params, ret)
+                        ret = ParamOpHookManager.post_op(params, ret)
                     return ret
         return super().__torch_function__(func, types, args, kwargs)
 
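For context, a minimal sketch of how the refactored entry points in this diff might be driven from user code. Only ParamOpHookManager.has_hook(), pre_op(), and post_op() are visible in this commit; the ParamOpHook base class, its pre_forward/post_forward/pre_backward/post_backward methods, and the use_hooks context manager below are assumptions about the refactored colossalai.tensor.param_op_hook module, not something this diff confirms.

```python
# Hedged sketch: ParamOpHook and use_hooks are assumed APIs of the refactored
# module; only has_hook/pre_op/post_op appear in the diff above.
from colossalai.tensor.param_op_hook import ParamOpHook, ParamOpHookManager


class CountingHook(ParamOpHook):
    """Illustrative hook that counts parameters entering an op."""

    def __init__(self):
        super().__init__()
        self.fwd_count = 0

    def pre_forward(self, params):
        # Reached via ParamOpHookManager.pre_op before the wrapped op runs.
        self.fwd_count += len(params)

    def post_forward(self, params):
        pass

    def pre_backward(self, params):
        pass

    def post_backward(self, params):
        pass


# Assumed usage: while the context manager is active, has_hook() returns True,
# so ColoParameter.__torch_function__ routes ops through pre_op/post_op.
hook = CountingHook()
with ParamOpHookManager.use_hooks(hook):
    ...  # run ops on ColoParameter tensors here
```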