diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py
index 8e3e5f5d0..02e7bc45e 100644
--- a/colossalai/tensor/colo_parameter.py
+++ b/colossalai/tensor/colo_parameter.py
@@ -7,6 +7,23 @@ from colossalai.tensor.param_op_hook import ParamOpHookManager
 from typing import Optional
 
 
+def filter_args(func, *args):
+    return [arg for arg in args if func(arg)]
+
+
+def unpack_args(*args):
+    if len(args) == 1:
+        return args[0]
+    return args
+
+
+def replace_args(args, kwargs, new_args):
+    args = new_args[:len(args)]
+    for k, v in zip(kwargs.keys(), new_args[len(args):]):
+        kwargs[k] = v
+    return unpack_args(args), kwargs
+
+
 class ColoParameter(ColoTensor, torch.nn.Parameter):
     r"""A kind of ColoTensor to be considered as a module parameter.
 
@@ -50,12 +67,13 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
     def __torch_function__(cls, func, types, args=..., kwargs=None):
         if ParamOpHookManager.has_hook():
             if not func.__name__.startswith('__'):
-                params = list(filter(lambda arg: isinstance(arg, ColoParameter), args))
-                if kwargs is not None:
-                    params.extend(list(filter(lambda arg: isinstance(arg, ColoParameter), kwargs.values())))
+                if kwargs is None:
+                    kwargs = {}
+                params = filter_args(lambda arg: isinstance(arg, ColoParameter), *args, *kwargs.values())
                 if len(params) > 0:
                     with torch._C.DisableTorchFunction():
-                        args = ParamOpHookManager.pre_op(params, *args)
+                        new_args = ParamOpHookManager.pre_op(params, *args, *kwargs.values())
+                        args, kwargs = replace_args(args, kwargs, new_args)
                     ret = super().__torch_function__(func, types, args, kwargs)
                     with torch._C.DisableTorchFunction():
                         ret = ParamOpHookManager.post_op(params, ret)
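
Note on the change: the patch flattens positional args and kwargs.values() into one sequence before handing them to the pre-op hook, then splits the (possibly rewritten) flat sequence back apart with replace_args. Below is a minimal, self-contained sketch of that round trip, outside the patch itself; fake_pre_op is a hypothetical stand-in for ParamOpHookManager.pre_op, used only so the example runs without ColossalAI.

    # Helpers copied from the patch, plus a stand-in hook for illustration.
    def filter_args(func, *args):
        # Keep only the arguments that satisfy the predicate.
        return [arg for arg in args if func(arg)]


    def unpack_args(*args):
        # Collapse a single-element pack back to the bare value.
        if len(args) == 1:
            return args[0]
        return args


    def replace_args(args, kwargs, new_args):
        # Split the flat, possibly-rewritten argument list back into
        # positional args and keyword args, preserving kwargs' key order.
        args = new_args[:len(args)]
        for k, v in zip(kwargs.keys(), new_args[len(args):]):
            kwargs[k] = v
        return unpack_args(args), kwargs


    def fake_pre_op(params, *flat_args):
        # Hypothetical hook: simply echoes the flattened arguments unchanged.
        return flat_args


    if __name__ == '__main__':
        args = (1, 2)
        kwargs = {'alpha': 3}
        params = filter_args(lambda a: isinstance(a, int), *args, *kwargs.values())
        new_args = fake_pre_op(params, *args, *kwargs.values())
        args, kwargs = replace_args(args, kwargs, new_args)
        print(args, kwargs)  # (1, 2) {'alpha': 3}

Design-wise, flattening means the hook only ever sees a flat tuple, and replace_args relies on dict insertion order (guaranteed since Python 3.7) to put the rewritten values back under their original keys.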