[hotfix] fix implementation error in diffusers

1SAA 2023-01-06 18:37:18 +08:00
parent 48d33b1b17
commit 33f3023e19
2 changed files with 41 additions and 21 deletions

Changed file 1 of 2

@@ -141,7 +141,25 @@ def _is_grad_tensor(obj) -> bool:
     return False


+def _has_grad_tensor(obj) -> bool:
+    if isinstance(obj, tuple) or isinstance(obj, list):
+        for x in obj:
+            if _has_grad_tensor(x):
+                return True
+        return False
+    elif isinstance(obj, dict):
+        for x in obj.values():
+            if _has_grad_tensor(x):
+                return True
+        return False
+    else:
+        return _is_grad_tensor(obj)
+
+
 def _get_grad_args(*args):
+    # if there are no grad tensors, do nothing
+    if not _has_grad_tensor(args):
+        return args, None
     # returns the identical args if there is a grad tensor
     for obj in args:
         if _is_grad_tensor(obj):
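The new _has_grad_tensor helper recursively walks tuples, lists, and dicts so that _get_grad_args can return early when no argument carries gradient state. A minimal standalone sketch of the same idea, assuming _is_grad_tensor (defined above the hunk, body not shown) simply checks for a tensor with requires_grad=True:

import torch

def _is_grad_tensor(obj) -> bool:
    # assumed behaviour of the existing helper; the hunk does not show its body
    return isinstance(obj, torch.Tensor) and obj.requires_grad

def _has_grad_tensor(obj) -> bool:
    # recurse into containers, otherwise fall back to the per-object check
    if isinstance(obj, (tuple, list)):
        return any(_has_grad_tensor(x) for x in obj)
    if isinstance(obj, dict):
        return any(_has_grad_tensor(x) for x in obj.values())
    return _is_grad_tensor(obj)

x = torch.randn(2, requires_grad=True)
print(_has_grad_tensor({"a": [x, 1]}))      # True: a grad tensor is nested in the dict
print(_has_grad_tensor((torch.zeros(2),)))  # False: plain tensor without requires_grad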

Changed file 2 of 2

@@ -7,27 +7,22 @@
 #
 # thanks!

-import os
 import math
+import os
+import numpy as np
 import torch
 import torch.nn as nn
-import numpy as np
 from einops import repeat
 from ldm.util import instantiate_from_config


 def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
     if schedule == "linear":
-        betas = (
-            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
-        )
+        betas = (torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64)**2)
     elif schedule == "cosine":
-        timesteps = (
-            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
-        )
+        timesteps = (torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s)
         alphas = timesteps / (1 + cosine_s) * np.pi / 2
         alphas = torch.cos(alphas).pow(2)
         alphas = alphas / alphas[0]
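In the linear branch above, the square roots of the betas are evenly spaced between sqrt(linear_start) and sqrt(linear_end); a small standalone sketch using the default values from the signature:

import torch

n_timestep, linear_start, linear_end = 1000, 1e-4, 2e-2
betas = torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64)**2
print(betas[0].item(), betas[-1].item())  # ~1e-4 ... ~2e-2, growing quadratically in the index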
@@ -37,7 +32,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
     elif schedule == "sqrt_linear":
         betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
     elif schedule == "sqrt":
-        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)**0.5
     else:
         raise ValueError(f"schedule '{schedule}' unknown.")
     return betas.numpy()
@@ -48,7 +43,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
         c = num_ddpm_timesteps // num_ddim_timesteps
         ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
     elif ddim_discr_method == 'quad':
-        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps))**2).astype(int)
     else:
         raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
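The 'quad' line above only changes spacing; the branch still selects DDIM steps whose indices grow quadratically, so sampling is denser near t=0. A quick sketch with hypothetical values (num_ddpm_timesteps=1000, num_ddim_timesteps=50, neither taken from the diff):

import numpy as np

num_ddpm_timesteps, num_ddim_timesteps = 1000, 50
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps))**2).astype(int)
print(ddim_timesteps)  # tiny gaps near 0, last entry close to 800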
@@ -110,21 +105,26 @@ def checkpoint(func, inputs, params, flag):
     :param flag: if False, disable gradient checkpointing.
     """
     if flag:
-        args = tuple(inputs) + tuple(params)
-        return CheckpointFunction.apply(func, len(inputs), *args)
+        from torch.utils.checkpoint import checkpoint as torch_checkpoint
+        return torch_checkpoint(func, *inputs)
+        # args = tuple(inputs) + tuple(params)
+        # return CheckpointFunction.apply(func, len(inputs), *args)
     else:
         return func(*inputs)


 class CheckpointFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, run_function, length, *args):
         ctx.run_function = run_function
         ctx.input_tensors = list(args[:length])
         ctx.input_params = list(args[length:])
-        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
-                                   "dtype": torch.get_autocast_gpu_dtype(),
-                                   "cache_enabled": torch.is_autocast_cache_enabled()}
+        ctx.gpu_autocast_kwargs = {
+            "enabled": torch.is_autocast_enabled(),
+            "dtype": torch.get_autocast_gpu_dtype(),
+            "cache_enabled": torch.is_autocast_cache_enabled()
+        }
         with torch.no_grad():
             output_tensors = ctx.run_function(*ctx.input_tensors)
         return output_tensors
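The checkpoint() wrapper above now delegates to PyTorch's built-in torch.utils.checkpoint.checkpoint, keeping the old CheckpointFunction call only as a comment. A minimal usage sketch of the built-in API with a made-up run_fn (not the ldm wrapper itself):

import torch
from torch.utils.checkpoint import checkpoint as torch_checkpoint

def run_fn(x):
    # activation-heavy block whose intermediates we prefer to recompute
    return torch.nn.functional.gelu(x) * x

x = torch.randn(8, 16, requires_grad=True)
y = torch_checkpoint(run_fn, x)  # forward pass without storing intermediates
y.sum().backward()               # forward is replayed here to compute gradients
print(x.grad.shape)              # torch.Size([8, 16])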
@@ -162,9 +162,8 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
     """
     if not repeat_only:
         half = dim // 2
-        freqs = torch.exp(
-            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
-        ).to(device=timesteps.device)
+        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) /
+                          half).to(device=timesteps.device)
         args = timesteps[:, None].float() * freqs[None]
         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
         if dim % 2:
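The freqs expression is only reflowed; the function still builds the standard sinusoidal timestep embedding. A standalone sketch with example inputs (dim=320 is an arbitrary choice, and the .to(device=...) call is dropped for brevity):

import math
import torch

timesteps, dim, max_period = torch.tensor([0, 10, 999]), 320, 10000
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
print(embedding.shape)  # torch.Size([3, 320]): one (cos, sin) vector per timestep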
@@ -211,14 +210,17 @@ def normalization(channels):
 # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
 class SiLU(nn.Module):
     def forward(self, x):
         return x * torch.sigmoid(x)


 class GroupNorm32(nn.GroupNorm):
     def forward(self, x):
         return super().forward(x.float()).type(x.dtype)


 def conv_nd(dims, *args, **kwargs):
     """
     Create a 1D, 2D, or 3D convolution module.
@@ -268,4 +270,4 @@ class HybridConditioner(nn.Module):
 def noise_like(shape, device, repeat=False):
     repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
     noise = lambda: torch.randn(shape, device=device)
     return repeat_noise() if repeat else noise()
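For reference, with repeat=True noise_like broadcasts a single noise sample across the batch dimension; a small sketch with an arbitrary shape:

import torch

shape, device = (4, 3, 8, 8), "cpu"
repeated = torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
print(repeated.shape)                         # torch.Size([4, 3, 8, 8])
print(torch.equal(repeated[0], repeated[1]))  # True: every batch element shares the same noise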