mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 18:19:58 +00:00
[hotfix] fix implement error in diffusers
This commit is contained in:
@@ -141,7 +141,25 @@ def _is_grad_tensor(obj) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _has_grad_tensor(obj) -> bool:
|
||||
if isinstance(obj, tuple) or isinstance(obj, list):
|
||||
for x in obj:
|
||||
if _has_grad_tensor(x):
|
||||
return True
|
||||
return False
|
||||
elif isinstance(obj, dict):
|
||||
for x in obj.values():
|
||||
if _has_grad_tensor(x):
|
||||
return True
|
||||
return False
|
||||
else:
|
||||
return _is_grad_tensor(obj)
|
||||
|
||||
|
||||
def _get_grad_args(*args):
|
||||
# if there is no grad tensors, do nothing
|
||||
if not _has_grad_tensor(args):
|
||||
return args, None
|
||||
# returns the identical args if there is a grad tensor
|
||||
for obj in args:
|
||||
if _is_grad_tensor(obj):
|
||||
|
Reference in New Issue
Block a user