[hotfix] fix param op hook (#1131)

* fix param op hook

* update zero tp test

* fix bugs
This commit is contained in:
ver217
2022-06-17 16:12:05 +08:00
committed by GitHub
parent a1a7899cae
commit 789cad301b
3 changed files with 74 additions and 20 deletions

View File

@@ -11,17 +11,11 @@ def filter_args(func, *args):
return [arg for arg in args if func(arg)]
def unpack_args(*args):
if len(args) == 1:
return args[0]
return args
def replace_args(args, kwargs, new_args):
args = new_args[:len(args)]
for k, v in zip(kwargs.keys(), new_args[len(args):]):
kwargs[k] = v
return unpack_args(args), kwargs
return tuple(args), kwargs
class ColoParameter(ColoTensor, torch.nn.Parameter):