Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-27 15:57:16 +00:00)
fix colo parameter torch function (#1117)
commit f99f56dff4
parent e1620ddac2
@@ -7,6 +7,23 @@ from colossalai.tensor.param_op_hook import ParamOpHookManager
 from typing import Optional
 
 
+def filter_args(func, *args):
+    return [arg for arg in args if func(arg)]
+
+
+def unpack_args(*args):
+    if len(args) == 1:
+        return args[0]
+    return args
+
+
+def replace_args(args, kwargs, new_args):
+    args = new_args[:len(args)]
+    for k, v in zip(kwargs.keys(), new_args[len(args):]):
+        kwargs[k] = v
+    return unpack_args(args), kwargs
+
+
 class ColoParameter(ColoTensor, torch.nn.Parameter):
     r"""A kind of ColoTensor to be considered as a module parameter.
 
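Taken together, the three new helpers flatten positional and keyword argument values into one sequence for the parameter-op hooks and then map the (possibly replaced) values back onto args and kwargs. Below is a minimal standalone sketch of that round trip, using hypothetical integer values and copying the helpers verbatim from the hunk above so the snippet runs on its own; it is illustrative only, not part of the commit.

# Sketch only: hypothetical values; helpers copied from the hunk above.
def filter_args(func, *args):
    return [arg for arg in args if func(arg)]

def unpack_args(*args):
    if len(args) == 1:
        return args[0]
    return args

def replace_args(args, kwargs, new_args):
    args = new_args[:len(args)]
    for k, v in zip(kwargs.keys(), new_args[len(args):]):
        kwargs[k] = v
    return unpack_args(args), kwargs

args, kwargs = (1, 2), {'scale': 3}
flat = [*args, *kwargs.values()]                # what a hook would receive: [1, 2, 3]
rewritten = [v * 10 for v in flat]              # pretend the hook replaced every value
args, kwargs = replace_args(args, kwargs, rewritten)
print(args, kwargs)                             # [10, 20] {'scale': 30}
print(filter_args(lambda x: x % 2 == 1, *flat)) # [1, 3] -- only values passing the predicate survive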
@@ -50,12 +67,13 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
     def __torch_function__(cls, func, types, args=..., kwargs=None):
         if ParamOpHookManager.has_hook():
             if not func.__name__.startswith('__'):
-                params = list(filter(lambda arg: isinstance(arg, ColoParameter), args))
-                if kwargs is not None:
-                    params.extend(list(filter(lambda arg: isinstance(arg, ColoParameter), kwargs.values())))
+                if kwargs is None:
+                    kwargs = {}
+                params = filter_args(lambda arg: isinstance(arg, ColoParameter), *args, *kwargs.values())
                 if len(params) > 0:
                     with torch._C.DisableTorchFunction():
-                        args = ParamOpHookManager.pre_op(params, *args)
+                        new_args = ParamOpHookManager.pre_op(params, *args, *kwargs.values())
+                    args, kwargs = replace_args(args, kwargs, new_args)
                     ret = super().__torch_function__(func, types, args, kwargs)
                     with torch._C.DisableTorchFunction():
                         ret = ParamOpHookManager.post_op(params, ret)
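This second hunk is where the fix lives: as the diff reads, the old code collected ColoParameter values from kwargs into params but ran ParamOpHookManager.pre_op over the positional arguments only, so keyword-argument values never went through the hook and were never written back. The new code flattens *args and *kwargs.values() through pre_op and uses replace_args to map the returned values back onto both. The sketch below mimics that control flow; fake_pre_op is a hypothetical stand-in for ParamOpHookManager.pre_op and only assumes, as the diff itself does, that pre_op returns the values in the order they were passed in. It reuses replace_args from the sketch after the first hunk.

# Rough sketch of the fixed control flow; fake_pre_op is hypothetical.
def fake_pre_op(params, *values):
    return [v for v in values]          # identity "rewrite" for illustration

def call_with_hook(func, args, kwargs=None):
    if kwargs is None:
        kwargs = {}
    # Flatten positional and keyword values through the hook ...
    new_args = fake_pre_op([], *args, *kwargs.values())
    # ... then write them back onto args and kwargs before dispatching.
    args, kwargs = replace_args(args, kwargs, new_args)
    return func(*args, **kwargs)

print(call_with_hook(lambda a, b, scale=1: (a + b) * scale, (1, 2), {'scale': 3}))  # 9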