From 33f3023e19f0edfc997788d2da04d02f33671901 Mon Sep 17 00:00:00 2001
From: 1SAA
Date: Fri, 6 Jan 2023 18:37:18 +0800
Subject: [PATCH] [hotfix] fix implement error in diffusers

---
 colossalai/tensor/param_op_hook.py       | 18 ++++++++
 .../ldm/modules/diffusionmodules/util.py | 44 ++++++++++---------
 2 files changed, 41 insertions(+), 21 deletions(-)

diff --git a/colossalai/tensor/param_op_hook.py b/colossalai/tensor/param_op_hook.py
index 7c73bc220..ed705da0e 100644
--- a/colossalai/tensor/param_op_hook.py
+++ b/colossalai/tensor/param_op_hook.py
@@ -141,7 +141,25 @@ def _is_grad_tensor(obj) -> bool:
     return False
 
 
+def _has_grad_tensor(obj) -> bool:
+    if isinstance(obj, tuple) or isinstance(obj, list):
+        for x in obj:
+            if _has_grad_tensor(x):
+                return True
+        return False
+    elif isinstance(obj, dict):
+        for x in obj.values():
+            if _has_grad_tensor(x):
+                return True
+        return False
+    else:
+        return _is_grad_tensor(obj)
+
+
 def _get_grad_args(*args):
+    # if there is no grad tensors, do nothing
+    if not _has_grad_tensor(args):
+        return args, None
     # returns the identical args if there is a grad tensor
     for obj in args:
         if _is_grad_tensor(obj):
diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/util.py b/examples/images/diffusion/ldm/modules/diffusionmodules/util.py
index e0621032d..36b4a171b 100644
--- a/examples/images/diffusion/ldm/modules/diffusionmodules/util.py
+++ b/examples/images/diffusion/ldm/modules/diffusionmodules/util.py
@@ -7,27 +7,22 @@
 #
 # thanks!
 
-
-import os
 import math
+import os
+
+import numpy as np
 import torch
 import torch.nn as nn
-import numpy as np
 from einops import repeat
-
 from ldm.util import instantiate_from_config
 
 
 def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
     if schedule == "linear":
-        betas = (
-                torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
-        )
+        betas = (torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64)**2)
 
     elif schedule == "cosine":
-        timesteps = (
-                torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
-        )
+        timesteps = (torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s)
         alphas = timesteps / (1 + cosine_s) * np.pi / 2
         alphas = torch.cos(alphas).pow(2)
         alphas = alphas / alphas[0]
@@ -37,7 +32,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
     elif schedule == "sqrt_linear":
         betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
     elif schedule == "sqrt":
-        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)**0.5
     else:
         raise ValueError(f"schedule '{schedule}' unknown.")
     return betas.numpy()
@@ -48,7 +43,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
         c = num_ddpm_timesteps // num_ddim_timesteps
         ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
     elif ddim_discr_method == 'quad':
-        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps))**2).astype(int)
     else:
         raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
 
@@ -110,21 +105,26 @@ def checkpoint(func, inputs, params, flag):
     :param flag: if False, disable gradient checkpointing.
""" if flag: - args = tuple(inputs) + tuple(params) - return CheckpointFunction.apply(func, len(inputs), *args) + from torch.utils.checkpoint import checkpoint as torch_checkpoint + return torch_checkpoint(func, *inputs) + # args = tuple(inputs) + tuple(params) + # return CheckpointFunction.apply(func, len(inputs), *args) else: return func(*inputs) class CheckpointFunction(torch.autograd.Function): + @staticmethod def forward(ctx, run_function, length, *args): ctx.run_function = run_function ctx.input_tensors = list(args[:length]) ctx.input_params = list(args[length:]) - ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), - "dtype": torch.get_autocast_gpu_dtype(), - "cache_enabled": torch.is_autocast_cache_enabled()} + ctx.gpu_autocast_kwargs = { + "enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled() + } with torch.no_grad(): output_tensors = ctx.run_function(*ctx.input_tensors) return output_tensors @@ -162,9 +162,8 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): """ if not repeat_only: half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half - ).to(device=timesteps.device) + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / + half).to(device=timesteps.device) args = timesteps[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: @@ -211,14 +210,17 @@ def normalization(channels): # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. class SiLU(nn.Module): + def forward(self, x): return x * torch.sigmoid(x) class GroupNorm32(nn.GroupNorm): + def forward(self, x): return super().forward(x.float()).type(x.dtype) + def conv_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D convolution module. @@ -268,4 +270,4 @@ class HybridConditioner(nn.Module): def noise_like(shape, device, repeat=False): repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) noise = lambda: torch.randn(shape, device=device) - return repeat_noise() if repeat else noise() \ No newline at end of file + return repeat_noise() if repeat else noise()