[Gemini] ParamOpHook -> ColoParamOpHook (#2080)

This commit is contained in:
Jiarui Fang
2022-12-05 17:11:06 +08:00
committed by GitHub
parent 4f21c9e8d9
commit b3b89865e2
7 changed files with 37 additions and 36 deletions

View File

@@ -4,7 +4,7 @@ import torch
from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor.const import TensorType
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.tensor.tensor_spec import ColoTensorSpec
@@ -58,18 +58,18 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
@classmethod
def __torch_function__(cls, func, types, args=..., kwargs=None):
if ParamOpHookManager.has_hook():
if ColoParamOpHookManager.has_hook():
if not func.__name__.startswith('__'):
if kwargs is None:
kwargs = {}
params = filter_args(lambda arg: isinstance(arg, ColoParameter), *args, *kwargs.values())
if len(params) > 0:
with torch._C.DisableTorchFunction():
new_args = ParamOpHookManager.pre_op(params, *args, *kwargs.values())
new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
args, kwargs = replace_args(args, kwargs, new_args)
ret = super().__torch_function__(func, types, args, kwargs)
with torch._C.DisableTorchFunction():
ret = ParamOpHookManager.post_op(params, ret)
ret = ColoParamOpHookManager.post_op(params, ret)
return ret
return super().__torch_function__(func, types, args, kwargs)