Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 11:32:10 +00:00)
[Gemini] ParamOpHook -> ColoParamOpHook (#2080)
@@ -4,7 +4,7 @@ import torch
 
 from colossalai.tensor.colo_tensor import ColoTensor
 from colossalai.tensor.const import TensorType
-from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.tensor.tensor_spec import ColoTensorSpec
 
@@ -58,18 +58,18 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
 
     @classmethod
     def __torch_function__(cls, func, types, args=..., kwargs=None):
-        if ParamOpHookManager.has_hook():
+        if ColoParamOpHookManager.has_hook():
             if not func.__name__.startswith('__'):
                 if kwargs is None:
                     kwargs = {}
                 params = filter_args(lambda arg: isinstance(arg, ColoParameter), *args, *kwargs.values())
                 if len(params) > 0:
                     with torch._C.DisableTorchFunction():
-                        new_args = ParamOpHookManager.pre_op(params, *args, *kwargs.values())
+                        new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
                         args, kwargs = replace_args(args, kwargs, new_args)
                     ret = super().__torch_function__(func, types, args, kwargs)
                     with torch._C.DisableTorchFunction():
-                        ret = ParamOpHookManager.post_op(params, ret)
+                        ret = ColoParamOpHookManager.post_op(params, ret)
                     return ret
         return super().__torch_function__(func, types, args, kwargs)
 
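For context, the renamed ColoParamOpHookManager brackets every torch op that touches a ColoParameter with pre_op/post_op callbacks, dispatched through __torch_function__. Below is a self-contained sketch of that dispatch pattern; ToyHookManager and ToyParam are hypothetical stand-ins for ColoParamOpHookManager and ColoParameter, not the colossalai API, and the hook bodies just count calls.

# Self-contained sketch of the pre_op/post_op dispatch pattern above.
# ToyHookManager and ToyParam are hypothetical stand-ins, not the
# colossalai API.
import torch


class ToyHookManager:
    """Counts how often ops on hooked parameters enter and leave."""
    enabled = True
    pre_calls = 0
    post_calls = 0

    @classmethod
    def has_hook(cls):
        return cls.enabled

    @classmethod
    def pre_op(cls, params, *args):
        # A real hook could swap in gathered/materialized tensors here;
        # this sketch passes the arguments through unchanged.
        cls.pre_calls += 1
        return list(args)

    @classmethod
    def post_op(cls, params, ret):
        cls.post_calls += 1
        return ret


class ToyParam(torch.Tensor):
    """Tensor subclass that routes ops through the hook manager."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if ToyHookManager.has_hook() and not func.__name__.startswith('__'):
            params = [a for a in args if isinstance(a, ToyParam)]
            if params:
                # Disable dispatch while running the hooks so their bodies
                # do not recursively re-enter __torch_function__.
                with torch._C.DisableTorchFunction():
                    new_args = ToyHookManager.pre_op(params, *args)
                ret = super().__torch_function__(func, types, tuple(new_args), kwargs)
                with torch._C.DisableTorchFunction():
                    ret = ToyHookManager.post_op(params, ret)
                return ret
        return super().__torch_function__(func, types, args, kwargs)


p = torch.ones(2, 2).as_subclass(ToyParam)
out = torch.add(p, 1)   # dispatches through ToyParam.__torch_function__
print(ToyHookManager.pre_calls, ToyHookManager.post_calls)   # 1 1

Running the sketch once advances both counters, mirroring how the real manager sees each op on a hooked parameter before and after it executes.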