[hotfix] fix implement error in diffusers

This commit is contained in:
1SAA
2023-01-06 18:37:18 +08:00
parent 48d33b1b17
commit 33f3023e19
2 changed files with 41 additions and 21 deletions

View File

@@ -141,7 +141,25 @@ def _is_grad_tensor(obj) -> bool:
return False
def _has_grad_tensor(obj) -> bool:
if isinstance(obj, tuple) or isinstance(obj, list):
for x in obj:
if _has_grad_tensor(x):
return True
return False
elif isinstance(obj, dict):
for x in obj.values():
if _has_grad_tensor(x):
return True
return False
else:
return _is_grad_tensor(obj)
def _get_grad_args(*args):
# if there is no grad tensors, do nothing
if not _has_grad_tensor(args):
return args, None
# returns the identical args if there is a grad tensor
for obj in args:
if _is_grad_tensor(obj):