mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 04:24:47 +00:00
support stable diffusion v2
This commit is contained in:
@@ -4,9 +4,22 @@ import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
from typing import Optional, Any
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.modules.attention import LinearAttention
|
||||
try:
|
||||
from lightning.pytorch.utilities import rank_zero_info
|
||||
except:
|
||||
from pytorch_lightning.utilities import rank_zero_info
|
||||
|
||||
from ldm.modules.attention import MemoryEfficientCrossAttention
|
||||
|
||||
try:
|
||||
import xformers
|
||||
import xformers.ops
|
||||
XFORMERS_IS_AVAILBLE = True
|
||||
except:
|
||||
XFORMERS_IS_AVAILBLE = False
|
||||
print("No module 'xformers'. Proceeding without it.")
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
@@ -141,12 +154,6 @@ class ResnetBlock(nn.Module):
|
||||
return x+h
|
||||
|
||||
|
||||
class LinAttnBlock(LinearAttention):
|
||||
"""to match AttnBlock usage"""
|
||||
def __init__(self, in_channels):
|
||||
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
@@ -174,7 +181,6 @@ class AttnBlock(nn.Module):
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
@@ -201,21 +207,100 @@ class AttnBlock(nn.Module):
|
||||
|
||||
return x+h_
|
||||
|
||||
class MemoryEfficientAttnBlock(nn.Module):
|
||||
"""
|
||||
Uses xformers efficient implementation,
|
||||
see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
||||
Note: this is a single-head self-attention operation
|
||||
"""
|
||||
#
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
def make_attn(in_channels, attn_type="vanilla"):
|
||||
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
|
||||
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.attention_op: Optional[Any] = None
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
B, C, H, W = q.shape
|
||||
q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(B, t.shape[1], 1, C)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(B * 1, t.shape[1], C)
|
||||
.contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
|
||||
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(B, 1, out.shape[1], C)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(B, out.shape[1], C)
|
||||
)
|
||||
out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
|
||||
out = self.proj_out(out)
|
||||
return x+out
|
||||
|
||||
|
||||
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
|
||||
def forward(self, x, context=None, mask=None):
|
||||
b, c, h, w = x.shape
|
||||
x = rearrange(x, 'b c h w -> b (h w) c')
|
||||
out = super().forward(x, context=context, mask=mask)
|
||||
out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
|
||||
return x + out
|
||||
|
||||
|
||||
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
||||
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
|
||||
if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
|
||||
attn_type = "vanilla-xformers"
|
||||
rank_zero_info(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||
if attn_type == "vanilla":
|
||||
assert attn_kwargs is None
|
||||
return AttnBlock(in_channels)
|
||||
elif attn_type == "vanilla-xformers":
|
||||
rank_zero_info(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
|
||||
return MemoryEfficientAttnBlock(in_channels)
|
||||
elif type == "memory-efficient-cross-attn":
|
||||
attn_kwargs["query_dim"] = in_channels
|
||||
return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
|
||||
elif attn_type == "none":
|
||||
return nn.Identity(in_channels)
|
||||
else:
|
||||
return LinAttnBlock(in_channels)
|
||||
raise NotImplementedError()
|
||||
|
||||
class temb_module(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
pass
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
@@ -233,8 +318,7 @@ class Model(nn.Module):
|
||||
self.use_timestep = use_timestep
|
||||
if self.use_timestep:
|
||||
# timestep embedding
|
||||
# self.temb = nn.Module()
|
||||
self.temb = temb_module()
|
||||
self.temb = nn.Module()
|
||||
self.temb.dense = nn.ModuleList([
|
||||
torch.nn.Linear(self.ch,
|
||||
self.temb_ch),
|
||||
@@ -265,8 +349,7 @@ class Model(nn.Module):
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
# down = nn.Module()
|
||||
down = Down_module()
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions-1:
|
||||
@@ -275,8 +358,7 @@ class Model(nn.Module):
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
# self.mid = nn.Module()
|
||||
self.mid = Mid_module()
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
@@ -304,8 +386,7 @@ class Model(nn.Module):
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
# up = nn.Module()
|
||||
up = Up_module()
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
@@ -372,21 +453,6 @@ class Model(nn.Module):
|
||||
def get_last_layer(self):
|
||||
return self.conv_out.weight
|
||||
|
||||
class Down_module(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
pass
|
||||
|
||||
class Up_module(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
pass
|
||||
|
||||
class Mid_module(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
pass
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
@@ -426,8 +492,7 @@ class Encoder(nn.Module):
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
# down = nn.Module()
|
||||
down = Down_module()
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions-1:
|
||||
@@ -436,8 +501,7 @@ class Encoder(nn.Module):
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
# self.mid = nn.Module()
|
||||
self.mid = Mid_module()
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
@@ -505,7 +569,7 @@ class Decoder(nn.Module):
|
||||
block_in = ch*ch_mult[self.num_resolutions-1]
|
||||
curr_res = resolution // 2**(self.num_resolutions-1)
|
||||
self.z_shape = (1,z_channels,curr_res,curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(
|
||||
rank_zero_info("Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
@@ -516,8 +580,7 @@ class Decoder(nn.Module):
|
||||
padding=1)
|
||||
|
||||
# middle
|
||||
# self.mid = nn.Module()
|
||||
self.mid = Mid_module()
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
@@ -542,8 +605,7 @@ class Decoder(nn.Module):
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
# up = nn.Module()
|
||||
up = Up_module()
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
@@ -758,7 +820,7 @@ class Upsampler(nn.Module):
|
||||
assert out_size >= in_size
|
||||
num_blocks = int(np.log2(out_size//in_size))+1
|
||||
factor_up = 1.+ (out_size % in_size)
|
||||
print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
|
||||
rank_zero_info(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
|
||||
self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
|
||||
out_channels=in_channels)
|
||||
self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
|
||||
@@ -777,7 +839,7 @@ class Resize(nn.Module):
|
||||
self.with_conv = learned
|
||||
self.mode = mode
|
||||
if self.with_conv:
|
||||
print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
|
||||
rank_zero_info(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
|
||||
raise NotImplementedError()
|
||||
assert in_channels is not None
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
@@ -793,70 +855,3 @@ class Resize(nn.Module):
|
||||
else:
|
||||
x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
|
||||
return x
|
||||
|
||||
class FirstStagePostProcessor(nn.Module):
|
||||
|
||||
def __init__(self, ch_mult:list, in_channels,
|
||||
pretrained_model:nn.Module=None,
|
||||
reshape=False,
|
||||
n_channels=None,
|
||||
dropout=0.,
|
||||
pretrained_config=None):
|
||||
super().__init__()
|
||||
if pretrained_config is None:
|
||||
assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
|
||||
self.pretrained_model = pretrained_model
|
||||
else:
|
||||
assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
|
||||
self.instantiate_pretrained(pretrained_config)
|
||||
|
||||
self.do_reshape = reshape
|
||||
|
||||
if n_channels is None:
|
||||
n_channels = self.pretrained_model.encoder.ch
|
||||
|
||||
self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
|
||||
self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
|
||||
stride=1,padding=1)
|
||||
|
||||
blocks = []
|
||||
downs = []
|
||||
ch_in = n_channels
|
||||
for m in ch_mult:
|
||||
blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
|
||||
ch_in = m * n_channels
|
||||
downs.append(Downsample(ch_in, with_conv=False))
|
||||
|
||||
self.model = nn.ModuleList(blocks)
|
||||
self.downsampler = nn.ModuleList(downs)
|
||||
|
||||
|
||||
def instantiate_pretrained(self, config):
|
||||
model = instantiate_from_config(config)
|
||||
self.pretrained_model = model.eval()
|
||||
# self.pretrained_model.train = False
|
||||
for param in self.pretrained_model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def encode_with_pretrained(self,x):
|
||||
c = self.pretrained_model.encode(x)
|
||||
if isinstance(c, DiagonalGaussianDistribution):
|
||||
c = c.mode()
|
||||
return c
|
||||
|
||||
def forward(self,x):
|
||||
z_fs = self.encode_with_pretrained(x)
|
||||
z = self.proj_norm(z_fs)
|
||||
z = self.proj(z)
|
||||
z = nonlinearity(z)
|
||||
|
||||
for submodel, downmodel in zip(self.model,self.downsampler):
|
||||
z = submodel(z,temb=None)
|
||||
z = downmodel(z)
|
||||
|
||||
if self.do_reshape:
|
||||
z = rearrange(z,'b c h w -> b (h w) c')
|
||||
return z
|
||||
|
||||
|
@@ -1,16 +1,13 @@
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
import math
|
||||
from typing import Iterable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils import checkpoint
|
||||
|
||||
from ldm.modules.diffusionmodules.util import (
|
||||
checkpoint,
|
||||
conv_nd,
|
||||
linear,
|
||||
avg_pool_nd,
|
||||
@@ -19,13 +16,11 @@ from ldm.modules.diffusionmodules.util import (
|
||||
timestep_embedding,
|
||||
)
|
||||
from ldm.modules.attention import SpatialTransformer
|
||||
from ldm.util import exists
|
||||
|
||||
|
||||
# dummy replace
|
||||
def convert_module_to_f16(x):
|
||||
# for n,p in x.named_parameter():
|
||||
# print(f"convert module {n} to_f16")
|
||||
# p.data = p.data.half()
|
||||
pass
|
||||
|
||||
def convert_module_to_f32(x):
|
||||
@@ -251,10 +246,9 @@ class ResBlock(TimestepBlock):
|
||||
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
if self.use_checkpoint:
|
||||
return checkpoint(self._forward, x, emb)
|
||||
else:
|
||||
return self._forward(x, emb)
|
||||
return checkpoint(
|
||||
self._forward, (x, emb), self.parameters(), self.use_checkpoint
|
||||
)
|
||||
|
||||
|
||||
def _forward(self, x, emb):
|
||||
@@ -317,11 +311,8 @@ class AttentionBlock(nn.Module):
|
||||
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_checkpoint:
|
||||
return checkpoint(self._forward, x) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
|
||||
return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
|
||||
#return pt_checkpoint(self._forward, x) # pytorch
|
||||
else:
|
||||
return self._forward(x)
|
||||
|
||||
def _forward(self, x):
|
||||
b, c, *spatial = x.shape
|
||||
@@ -474,7 +465,10 @@ class UNetModel(nn.Module):
|
||||
context_dim=None, # custom transformer support
|
||||
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
||||
legacy=True,
|
||||
from_pretrained: str=None
|
||||
disable_self_attentions=None,
|
||||
num_attention_blocks=None,
|
||||
disable_middle_self_attn=False,
|
||||
use_linear_in_transformer=False,
|
||||
):
|
||||
super().__init__()
|
||||
if use_spatial_transformer:
|
||||
@@ -499,7 +493,24 @@ class UNetModel(nn.Module):
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
if isinstance(num_res_blocks, int):
|
||||
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
||||
else:
|
||||
if len(num_res_blocks) != len(channel_mult):
|
||||
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
||||
"as a list/tuple (per-level) with the same length as channel_mult")
|
||||
self.num_res_blocks = num_res_blocks
|
||||
if disable_self_attentions is not None:
|
||||
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
||||
assert len(disable_self_attentions) == len(channel_mult)
|
||||
if num_attention_blocks is not None:
|
||||
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
||||
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
||||
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
||||
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
||||
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
||||
f"attention will still not be set.")
|
||||
|
||||
self.attention_resolutions = attention_resolutions
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
@@ -520,7 +531,13 @@ class UNetModel(nn.Module):
|
||||
)
|
||||
|
||||
if self.num_classes is not None:
|
||||
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
||||
if isinstance(self.num_classes, int):
|
||||
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
||||
elif self.num_classes == "continuous":
|
||||
print("setting up linear c_adm embedding layer")
|
||||
self.label_emb = nn.Linear(1, time_embed_dim)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
@@ -534,7 +551,7 @@ class UNetModel(nn.Module):
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for _ in range(num_res_blocks):
|
||||
for nr in range(self.num_res_blocks[level]):
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch,
|
||||
@@ -556,17 +573,25 @@ class UNetModel(nn.Module):
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, use_checkpoint=use_checkpoint,
|
||||
if exists(disable_self_attentions):
|
||||
disabled_sa = disable_self_attentions[level]
|
||||
else:
|
||||
disabled_sa = False
|
||||
|
||||
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint
|
||||
)
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
@@ -618,8 +643,10 @@ class UNetModel(nn.Module):
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
|
||||
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
@@ -634,7 +661,7 @@ class UNetModel(nn.Module):
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
for level, mult in list(enumerate(channel_mult))[::-1]:
|
||||
for i in range(num_res_blocks + 1):
|
||||
for i in range(self.num_res_blocks[level] + 1):
|
||||
ich = input_block_chans.pop()
|
||||
layers = [
|
||||
ResBlock(
|
||||
@@ -657,18 +684,26 @@ class UNetModel(nn.Module):
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads_upsample,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
|
||||
if exists(disable_self_attentions):
|
||||
disabled_sa = disable_self_attentions[level]
|
||||
else:
|
||||
disabled_sa = False
|
||||
|
||||
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads_upsample,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint
|
||||
)
|
||||
)
|
||||
)
|
||||
if level and i == num_res_blocks:
|
||||
if level and i == self.num_res_blocks[level]:
|
||||
out_ch = ch
|
||||
layers.append(
|
||||
ResBlock(
|
||||
@@ -699,188 +734,6 @@ class UNetModel(nn.Module):
|
||||
conv_nd(dims, model_channels, n_embed, 1),
|
||||
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
||||
)
|
||||
# if use_fp16:
|
||||
# self.convert_to_fp16()
|
||||
from diffusers.modeling_utils import load_state_dict
|
||||
if from_pretrained is not None:
|
||||
state_dict = load_state_dict(from_pretrained)
|
||||
self._load_pretrained_model(state_dict)
|
||||
|
||||
def _input_blocks_mapping(self, input_dict):
|
||||
res_dict = {}
|
||||
for key_, value_ in input_dict.items():
|
||||
id_0 = int(key_[13])
|
||||
if "resnets" in key_:
|
||||
id_1 = int(key_[23])
|
||||
target_id = 3 * id_0 + 1 + id_1
|
||||
post_fix = key_[25:].replace('time_emb_proj', 'emb_layers.1')\
|
||||
.replace('norm1', 'in_layers.0')\
|
||||
.replace('norm2', 'out_layers.0')\
|
||||
.replace('conv1', 'in_layers.2')\
|
||||
.replace('conv2', 'out_layers.3')\
|
||||
.replace('conv_shortcut', 'skip_connection')
|
||||
res_dict["input_blocks." + str(target_id) + '.0.' + post_fix] = value_
|
||||
elif "attentions" in key_:
|
||||
id_1 = int(key_[26])
|
||||
target_id = 3 * id_0 + 1 + id_1
|
||||
post_fix = key_[28:]
|
||||
res_dict["input_blocks." + str(target_id) + '.1.' + post_fix] = value_
|
||||
elif "downsamplers" in key_:
|
||||
post_fix = key_[35:]
|
||||
target_id = 3 * (id_0 + 1)
|
||||
res_dict["input_blocks." + str(target_id) + '.0.op.' + post_fix] = value_
|
||||
return res_dict
|
||||
|
||||
|
||||
def _mid_blocks_mapping(self, mid_dict):
|
||||
res_dict = {}
|
||||
for key_, value_ in mid_dict.items():
|
||||
if "resnets" in key_:
|
||||
temp_key_ =key_.replace('time_emb_proj', 'emb_layers.1') \
|
||||
.replace('norm1', 'in_layers.0') \
|
||||
.replace('norm2', 'out_layers.0') \
|
||||
.replace('conv1', 'in_layers.2') \
|
||||
.replace('conv2', 'out_layers.3') \
|
||||
.replace('conv_shortcut', 'skip_connection')\
|
||||
.replace('middle_block.resnets.0', 'middle_block.0')\
|
||||
.replace('middle_block.resnets.1', 'middle_block.2')
|
||||
res_dict[temp_key_] = value_
|
||||
elif "attentions" in key_:
|
||||
res_dict[key_.replace('attentions.0', '1')] = value_
|
||||
return res_dict
|
||||
|
||||
def _other_blocks_mapping(self, other_dict):
|
||||
res_dict = {}
|
||||
for key_, value_ in other_dict.items():
|
||||
tmp_key = key_.replace('conv_in', 'input_blocks.0.0')\
|
||||
.replace('time_embedding.linear_1', 'time_embed.0')\
|
||||
.replace('time_embedding.linear_2', 'time_embed.2')\
|
||||
.replace('conv_norm_out', 'out.0')\
|
||||
.replace('conv_out', 'out.2')
|
||||
res_dict[tmp_key] = value_
|
||||
return res_dict
|
||||
|
||||
|
||||
def _output_blocks_mapping(self, output_dict):
|
||||
res_dict = {}
|
||||
for key_, value_ in output_dict.items():
|
||||
id_0 = int(key_[14])
|
||||
if "resnets" in key_:
|
||||
id_1 = int(key_[24])
|
||||
target_id = 3 * id_0 + id_1
|
||||
post_fix = key_[26:].replace('time_emb_proj', 'emb_layers.1') \
|
||||
.replace('norm1', 'in_layers.0') \
|
||||
.replace('norm2', 'out_layers.0') \
|
||||
.replace('conv1', 'in_layers.2') \
|
||||
.replace('conv2', 'out_layers.3') \
|
||||
.replace('conv_shortcut', 'skip_connection')
|
||||
res_dict["output_blocks." + str(target_id) + '.0.' + post_fix] = value_
|
||||
elif "attentions" in key_:
|
||||
id_1 = int(key_[27])
|
||||
target_id = 3 * id_0 + id_1
|
||||
post_fix = key_[29:]
|
||||
res_dict["output_blocks." + str(target_id) + '.1.' + post_fix] = value_
|
||||
elif "upsamplers" in key_:
|
||||
post_fix = key_[34:]
|
||||
target_id = 3 * (id_0 + 1) - 1
|
||||
mid_str = '.2.conv.' if target_id != 2 else '.1.conv.'
|
||||
res_dict["output_blocks." + str(target_id) + mid_str + post_fix] = value_
|
||||
return res_dict
|
||||
|
||||
def _state_key_mapping(self, state_dict: dict):
|
||||
import re
|
||||
res_dict = {}
|
||||
input_dict = {}
|
||||
mid_dict = {}
|
||||
output_dict = {}
|
||||
other_dict = {}
|
||||
for key_, value_ in state_dict.items():
|
||||
if "down_blocks" in key_:
|
||||
input_dict[key_.replace('down_blocks', 'input_blocks')] = value_
|
||||
elif "up_blocks" in key_:
|
||||
output_dict[key_.replace('up_blocks', 'output_blocks')] = value_
|
||||
elif "mid_block" in key_:
|
||||
mid_dict[key_.replace('mid_block', 'middle_block')] = value_
|
||||
else:
|
||||
other_dict[key_] = value_
|
||||
|
||||
input_dict = self._input_blocks_mapping(input_dict)
|
||||
output_dict = self._output_blocks_mapping(output_dict)
|
||||
mid_dict = self._mid_blocks_mapping(mid_dict)
|
||||
other_dict = self._other_blocks_mapping(other_dict)
|
||||
# key_list = state_dict.keys()
|
||||
# key_str = " ".join(key_list)
|
||||
|
||||
# for key_, val_ in state_dict.items():
|
||||
# key_ = key_.replace("down_blocks", "input_blocks")\
|
||||
# .replace("up_blocks", 'output_blocks')
|
||||
# res_dict[key_] = val_
|
||||
res_dict.update(input_dict)
|
||||
res_dict.update(output_dict)
|
||||
res_dict.update(mid_dict)
|
||||
res_dict.update(other_dict)
|
||||
|
||||
return res_dict
|
||||
|
||||
def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False):
|
||||
state_dict = self._state_key_mapping(state_dict)
|
||||
model_state_dict = self.state_dict()
|
||||
loaded_keys = [k for k in state_dict.keys()]
|
||||
expected_keys = list(model_state_dict.keys())
|
||||
original_loaded_keys = loaded_keys
|
||||
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
||||
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
||||
|
||||
def _find_mismatched_keys(
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
loaded_keys,
|
||||
ignore_mismatched_sizes,
|
||||
):
|
||||
mismatched_keys = []
|
||||
if ignore_mismatched_sizes:
|
||||
for checkpoint_key in loaded_keys:
|
||||
model_key = checkpoint_key
|
||||
|
||||
if (
|
||||
model_key in model_state_dict
|
||||
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
||||
):
|
||||
mismatched_keys.append(
|
||||
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
||||
)
|
||||
del state_dict[checkpoint_key]
|
||||
return mismatched_keys
|
||||
if state_dict is not None:
|
||||
# Whole checkpoint
|
||||
mismatched_keys = _find_mismatched_keys(
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
original_loaded_keys,
|
||||
ignore_mismatched_sizes,
|
||||
)
|
||||
error_msgs = self._load_state_dict_into_model(state_dict)
|
||||
return missing_keys, unexpected_keys, mismatched_keys, error_msgs
|
||||
|
||||
def _load_state_dict_into_model(self, state_dict):
|
||||
# Convert old format to new format if needed from a PyTorch state_dict
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
state_dict = state_dict.copy()
|
||||
error_msgs = []
|
||||
|
||||
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
||||
# so we need to apply the function recursively.
|
||||
def load(module: torch.nn.Module, prefix=""):
|
||||
args = (state_dict, prefix, {}, True, [], [], error_msgs)
|
||||
module._load_from_state_dict(*args)
|
||||
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, prefix + name + ".")
|
||||
|
||||
load(self)
|
||||
|
||||
return error_msgs
|
||||
|
||||
def convert_to_fp16(self):
|
||||
"""
|
||||
@@ -912,10 +765,11 @@ class UNetModel(nn.Module):
|
||||
), "must specify y if and only if the model is class-conditional"
|
||||
hs = []
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
||||
t_emb = t_emb.type(self.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape == (x.shape[0],)
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
h = x.type(self.dtype)
|
||||
@@ -926,227 +780,8 @@ class UNetModel(nn.Module):
|
||||
for module in self.output_blocks:
|
||||
h = th.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, emb, context)
|
||||
h = h.type(self.dtype)
|
||||
h = h.type(x.dtype)
|
||||
if self.predict_codebook_ids:
|
||||
return self.id_predictor(h)
|
||||
else:
|
||||
return self.out(h)
|
||||
|
||||
|
||||
class EncoderUNetModel(nn.Module):
|
||||
"""
|
||||
The half UNet model with attention and timestep embedding.
|
||||
For usage, see UNet.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size,
|
||||
in_channels,
|
||||
model_channels,
|
||||
out_channels,
|
||||
num_res_blocks,
|
||||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
num_heads=1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
use_new_attention_order=False,
|
||||
pool="adaptive",
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if num_heads_upsample == -1:
|
||||
num_heads_upsample = num_heads
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attention_resolutions = attention_resolutions
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.dtype = th.float16 if use_fp16 else th.float32
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
||||
)
|
||||
]
|
||||
)
|
||||
self._feature_size = model_channels
|
||||
input_block_chans = [model_channels]
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for _ in range(num_res_blocks):
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=mult * model_channels,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=num_head_channels,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(channel_mult) - 1:
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch
|
||||
)
|
||||
)
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=num_head_channels,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
)
|
||||
self._feature_size += ch
|
||||
self.pool = pool
|
||||
if pool == "adaptive":
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
nn.AdaptiveAvgPool2d((1, 1)),
|
||||
zero_module(conv_nd(dims, ch, out_channels, 1)),
|
||||
nn.Flatten(),
|
||||
)
|
||||
elif pool == "attention":
|
||||
assert num_head_channels != -1
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
AttentionPool2d(
|
||||
(image_size // ds), ch, num_head_channels, out_channels
|
||||
),
|
||||
)
|
||||
elif pool == "spatial":
|
||||
self.out = nn.Sequential(
|
||||
nn.Linear(self._feature_size, 2048),
|
||||
nn.ReLU(),
|
||||
nn.Linear(2048, self.out_channels),
|
||||
)
|
||||
elif pool == "spatial_v2":
|
||||
self.out = nn.Sequential(
|
||||
nn.Linear(self._feature_size, 2048),
|
||||
normalization(2048),
|
||||
nn.SiLU(),
|
||||
nn.Linear(2048, self.out_channels),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unexpected {pool} pooling")
|
||||
|
||||
def convert_to_fp16(self):
|
||||
"""
|
||||
Convert the torso of the model to float16.
|
||||
"""
|
||||
self.input_blocks.apply(convert_module_to_f16)
|
||||
self.middle_block.apply(convert_module_to_f16)
|
||||
|
||||
def convert_to_fp32(self):
|
||||
"""
|
||||
Convert the torso of the model to float32.
|
||||
"""
|
||||
self.input_blocks.apply(convert_module_to_f32)
|
||||
self.middle_block.apply(convert_module_to_f32)
|
||||
|
||||
def forward(self, x, timesteps):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
:param timesteps: a 1-D batch of timesteps.
|
||||
:return: an [N x K] Tensor of outputs.
|
||||
"""
|
||||
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
|
||||
results = []
|
||||
h = x.type(self.dtype)
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb)
|
||||
if self.pool.startswith("spatial"):
|
||||
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
||||
h = self.middle_block(h, emb)
|
||||
if self.pool.startswith("spatial"):
|
||||
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
||||
h = th.cat(results, axis=-1)
|
||||
return self.out(h)
|
||||
else:
|
||||
h = h.type(self.dtype)
|
||||
return self.out(h)
|
||||
|
||||
|
@@ -0,0 +1,81 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
|
||||
from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
|
||||
from ldm.util import default
|
||||
|
||||
|
||||
class AbstractLowScaleModel(nn.Module):
|
||||
# for concatenating a downsampled image to the latent representation
|
||||
def __init__(self, noise_schedule_config=None):
|
||||
super(AbstractLowScaleModel, self).__init__()
|
||||
if noise_schedule_config is not None:
|
||||
self.register_schedule(**noise_schedule_config)
|
||||
|
||||
def register_schedule(self, beta_schedule="linear", timesteps=1000,
|
||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
|
||||
cosine_s=cosine_s)
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
||||
|
||||
timesteps, = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
self.linear_start = linear_start
|
||||
self.linear_end = linear_end
|
||||
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
|
||||
|
||||
to_torch = partial(torch.tensor, dtype=torch.float32)
|
||||
|
||||
self.register_buffer('betas', to_torch(betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
||||
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
|
||||
|
||||
def forward(self, x):
|
||||
return x, None
|
||||
|
||||
def decode(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class SimpleImageConcat(AbstractLowScaleModel):
|
||||
# no noise level conditioning
|
||||
def __init__(self):
|
||||
super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
|
||||
self.max_noise_level = 0
|
||||
|
||||
def forward(self, x):
|
||||
# fix to constant noise level
|
||||
return x, torch.zeros(x.shape[0], device=x.device).long()
|
||||
|
||||
|
||||
class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
|
||||
def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
|
||||
super().__init__(noise_schedule_config=noise_schedule_config)
|
||||
self.max_noise_level = max_noise_level
|
||||
|
||||
def forward(self, x, noise_level=None):
|
||||
if noise_level is None:
|
||||
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
|
||||
else:
|
||||
assert isinstance(noise_level, torch.Tensor)
|
||||
z = self.q_sample(x, noise_level)
|
||||
return z, noise_level
|
||||
|
||||
|
||||
|
@@ -122,7 +122,9 @@ class CheckpointFunction(torch.autograd.Function):
|
||||
ctx.run_function = run_function
|
||||
ctx.input_tensors = list(args[:length])
|
||||
ctx.input_params = list(args[length:])
|
||||
|
||||
ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
|
||||
"dtype": torch.get_autocast_gpu_dtype(),
|
||||
"cache_enabled": torch.is_autocast_cache_enabled()}
|
||||
with torch.no_grad():
|
||||
output_tensors = ctx.run_function(*ctx.input_tensors)
|
||||
return output_tensors
|
||||
@@ -130,7 +132,8 @@ class CheckpointFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def backward(ctx, *output_grads):
|
||||
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
|
||||
with torch.enable_grad():
|
||||
with torch.enable_grad(), \
|
||||
torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
|
||||
# Fixes a bug where the first op in run_function modifies the
|
||||
# Tensor storage in place, which is not allowed for detach()'d
|
||||
# Tensors.
|
||||
@@ -148,7 +151,7 @@ class CheckpointFunction(torch.autograd.Function):
|
||||
return (None, None) + input_grads
|
||||
|
||||
|
||||
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, use_fp16=True):
|
||||
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
@@ -168,10 +171,7 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, use_
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
else:
|
||||
embedding = repeat(timesteps, 'b -> b d', d=dim)
|
||||
if use_fp16:
|
||||
return embedding.half()
|
||||
else:
|
||||
return embedding
|
||||
return embedding
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
@@ -199,16 +199,14 @@ def mean_flat(tensor):
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||
|
||||
|
||||
def normalization(channels, precision=16):
|
||||
def normalization(channels):
|
||||
"""
|
||||
Make a standard normalization layer.
|
||||
:param channels: number of input channels.
|
||||
:return: an nn.Module for normalization.
|
||||
"""
|
||||
if precision == 16:
|
||||
return GroupNorm16(16, channels)
|
||||
else:
|
||||
return GroupNorm32(32, channels)
|
||||
return nn.GroupNorm(16, channels)
|
||||
# return GroupNorm32(32, channels)
|
||||
|
||||
|
||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||
@@ -216,9 +214,6 @@ class SiLU(nn.Module):
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
class GroupNorm16(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
return super().forward(x.half()).type(x.dtype)
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
|
Reference in New Issue
Block a user