[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00 (committed via GitHub)
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

(diff of a changed file, path not preserved)

@@ -17,6 +17,7 @@ from ldm.modules.attention import MemoryEfficientCrossAttention
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
except:
XFORMERS_IS_AVAILBLE = False
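This module-level flag is read again in make_attn further down: when the import succeeds, requests for "vanilla" attention are upgraded to the memory-efficient "vanilla-xformers" implementation; when it fails, the plain AttnBlock path is used.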
@@ -39,7 +40,7 @@ def get_timestep_embedding(timesteps, embedding_dim):
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
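For orientation, the hunk above belongs to the sinusoidal timestep embedding. A minimal self-contained sketch of the computation, assuming the usual 10000 frequency base (the library version differs only in how it derives the frequency step):

import math
import torch

def sinusoidal_timestep_embedding(timesteps, embedding_dim):
    # geometric ladder of frequencies, as in the surrounding get_timestep_embedding
    half_dim = embedding_dim // 2
    freqs = torch.exp(-math.log(10000) * torch.arange(half_dim, dtype=torch.float32) / half_dim)
    args = timesteps.float()[:, None] * freqs.to(timesteps.device)[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad to an odd target width, as in the hunk above
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb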
@@ -54,7 +55,6 @@ def Normalize(in_channels, num_groups=32):
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
@@ -69,7 +69,6 @@ class Upsample(nn.Module):
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
@@ -88,7 +87,6 @@ class Downsample(nn.Module):
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
super().__init__()
self.in_channels = in_channels
@@ -133,7 +131,6 @@ class ResnetBlock(nn.Module):
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
@@ -154,16 +151,16 @@ class AttnBlock(nn.Module):
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b, c, h * w) # b,c,hw
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c)**(-0.5))
q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b, c, h * w) # b,c,hw
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
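The reshapes and bmm calls above implement plain single-head attention over the h*w spatial positions. As a side-by-side reading aid (not taken from this file), the same computation with PyTorch 2.x's fused kernel:

import torch.nn.functional as F

def spatial_self_attention(q, k, v):
    # q, k, v: (b, c, h, w) projections, as produced by AttnBlock above
    b, c, h, w = q.shape
    q, k, v = (t.reshape(b, c, h * w).transpose(1, 2) for t in (q, k, v))  # (b, hw, c)
    out = F.scaled_dot_product_attention(q, k, v)  # softmax(q k^T / sqrt(c)) v
    return out.transpose(1, 2).reshape(b, c, h, w)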
@@ -173,9 +170,9 @@ class AttnBlock(nn.Module):
class MemoryEfficientAttnBlock(nn.Module):
"""
Uses xformers efficient implementation,
see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
Note: this is a single-head self-attention operation
Uses xformers efficient implementation,
see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
Note: this is a single-head self-attention operation
"""
#
@@ -199,34 +196,41 @@ class MemoryEfficientAttnBlock(nn.Module):
# compute attention
B, C, H, W = q.shape
q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
q, k, v = map(
lambda t: t.unsqueeze(3).reshape(B, t.shape[1], 1, C).permute(0, 2, 1, 3).reshape(B * 1, t.shape[1], C).
contiguous(),
lambda t: t.unsqueeze(3)
.reshape(B, t.shape[1], 1, C)
.permute(0, 2, 1, 3)
.reshape(B * 1, t.shape[1], C)
.contiguous(),
(q, k, v),
)
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
out = (out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C))
out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
out = out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C)
out = rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
out = self.proj_out(out)
return x + out
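For reference, a minimal usage sketch of the xformers call this block relies on; it assumes xformers is installed and a CUDA device is available, with tensors shaped (batch, sequence, channels) as after the rearrange above:

import torch
import xformers.ops as xops

q = torch.randn(1, 64 * 64, 512, device="cuda", dtype=torch.float16)
k = torch.randn(1, 64 * 64, 512, device="cuda", dtype=torch.float16)
v = torch.randn(1, 64 * 64, 512, device="cuda", dtype=torch.float16)
# attn_bias=None as in the block above; xformers selects a suitable kernel automatically
out = xops.memory_efficient_attention(q, k, v, attn_bias=None)  # -> (1, 4096, 512)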
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
def forward(self, x, context=None, mask=None):
b, c, h, w = x.shape
x = rearrange(x, 'b c h w -> b (h w) c')
x = rearrange(x, "b c h w -> b (h w) c")
out = super().forward(x, context=context, mask=mask)
out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
return x + out
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear",
"none"], f'attn_type {attn_type} unknown'
assert attn_type in [
"vanilla",
"vanilla-xformers",
"memory-efficient-cross-attn",
"linear",
"none",
], f"attn_type {attn_type} unknown"
if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
attn_type = "vanilla-xformers"
if attn_type == "vanilla":
@@ -245,21 +249,22 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
class Model(nn.Module):
def __init__(self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
use_timestep=True,
use_linear_attn=False,
attn_type="vanilla"):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
use_timestep=True,
use_linear_attn=False,
attn_type="vanilla",
):
super().__init__()
if use_linear_attn:
attn_type = "linear"
@@ -274,10 +279,12 @@ class Model(nn.Module):
if self.use_timestep:
# timestep embedding
self.temb = nn.Module()
self.temb.dense = nn.ModuleList([
torch.nn.Linear(self.ch, self.temb_ch),
torch.nn.Linear(self.temb_ch, self.temb_ch),
])
self.temb.dense = nn.ModuleList(
[
torch.nn.Linear(self.ch, self.temb_ch),
torch.nn.Linear(self.temb_ch, self.temb_ch),
]
)
# downsampling
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
@@ -292,10 +299,10 @@ class Model(nn.Module):
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
ResnetBlock(
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
@@ -309,15 +316,13 @@ class Model(nn.Module):
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.block_1 = ResnetBlock(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.block_2 = ResnetBlock(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
# upsampling
self.up = nn.ModuleList()
@@ -330,10 +335,13 @@ class Model(nn.Module):
if i_block == self.num_res_blocks:
skip_in = ch * in_ch_mult[i_level]
block.append(
ResnetBlock(in_channels=block_in + skip_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
ResnetBlock(
in_channels=block_in + skip_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
@@ -343,14 +351,14 @@ class Model(nn.Module):
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, x, t=None, context=None):
#assert x.shape[2] == x.shape[3] == self.resolution
# assert x.shape[2] == x.shape[3] == self.resolution
if context is not None:
# assume aligned context, cat along channel axis
x = torch.cat((x, context), dim=1)
@@ -401,23 +409,24 @@ class Model(nn.Module):
class Encoder(nn.Module):
def __init__(self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
use_linear_attn=False,
attn_type="vanilla",
**ignore_kwargs):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
use_linear_attn=False,
attn_type="vanilla",
**ignore_kwargs,
):
super().__init__()
if use_linear_attn:
attn_type = "linear"
@@ -442,10 +451,10 @@ class Encoder(nn.Module):
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
ResnetBlock(
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
@@ -459,23 +468,19 @@ class Encoder(nn.Module):
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.block_1 = ResnetBlock(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.block_2 = ResnetBlock(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
2 * z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
padding=1)
self.conv_out = torch.nn.Conv2d(
block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x):
# timestep embedding
@@ -506,24 +511,25 @@ class Encoder(nn.Module):
class Decoder(nn.Module):
def __init__(self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
tanh_out=False,
use_linear_attn=False,
attn_type="vanilla",
**ignorekwargs):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
tanh_out=False,
use_linear_attn=False,
attn_type="vanilla",
**ignorekwargs,
):
super().__init__()
if use_linear_attn:
attn_type = "linear"
@@ -537,9 +543,9 @@ class Decoder(nn.Module):
self.tanh_out = tanh_out
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,) + tuple(ch_mult)
(1,) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2**(self.num_resolutions - 1)
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
rank_zero_info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
@@ -548,15 +554,13 @@ class Decoder(nn.Module):
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.block_1 = ResnetBlock(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.block_2 = ResnetBlock(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
# upsampling
self.up = nn.ModuleList()
@@ -566,10 +570,10 @@ class Decoder(nn.Module):
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
ResnetBlock(
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
@@ -579,14 +583,14 @@ class Decoder(nn.Module):
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z):
#assert z.shape[1:] == self.z_shape[1:]
# assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
@@ -622,17 +626,18 @@ class Decoder(nn.Module):
class SimpleDecoder(nn.Module):
def __init__(self, in_channels, out_channels, *args, **kwargs):
super().__init__()
self.model = nn.ModuleList([
nn.Conv2d(in_channels, in_channels, 1),
ResnetBlock(in_channels=in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0),
ResnetBlock(in_channels=2 * in_channels, out_channels=4 * in_channels, temb_channels=0, dropout=0.0),
ResnetBlock(in_channels=4 * in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0),
nn.Conv2d(2 * in_channels, in_channels, 1),
Upsample(in_channels, with_conv=True)
])
self.model = nn.ModuleList(
[
nn.Conv2d(in_channels, in_channels, 1),
ResnetBlock(in_channels=in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0),
ResnetBlock(in_channels=2 * in_channels, out_channels=4 * in_channels, temb_channels=0, dropout=0.0),
ResnetBlock(in_channels=4 * in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0),
nn.Conv2d(2 * in_channels, in_channels, 1),
Upsample(in_channels, with_conv=True),
]
)
# end
self.norm_out = Normalize(in_channels)
self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
@@ -651,7 +656,6 @@ class SimpleDecoder(nn.Module):
class UpsampleDecoder(nn.Module):
def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch_mult=(2, 2), dropout=0.0):
super().__init__()
# upsampling
@@ -659,7 +663,7 @@ class UpsampleDecoder(nn.Module):
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
block_in = in_channels
curr_res = resolution // 2**(self.num_resolutions - 1)
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.res_blocks = nn.ModuleList()
self.upsample_blocks = nn.ModuleList()
for i_level in range(self.num_resolutions):
@@ -667,10 +671,10 @@ class UpsampleDecoder(nn.Module):
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
res_block.append(
ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
ResnetBlock(
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
)
)
block_in = block_out
self.res_blocks.append(nn.ModuleList(res_block))
if i_level != self.num_resolutions - 1:
@@ -696,21 +700,24 @@ class UpsampleDecoder(nn.Module):
class LatentRescaler(nn.Module):
def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
super().__init__()
# residual block, interpolate, residual block
self.factor = factor
self.conv_in = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)
self.res_block1 = nn.ModuleList([
ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0)
for _ in range(depth)
])
self.res_block1 = nn.ModuleList(
[
ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0)
for _ in range(depth)
]
)
self.attn = AttnBlock(mid_channels)
self.res_block2 = nn.ModuleList([
ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0)
for _ in range(depth)
])
self.res_block2 = nn.ModuleList(
[
ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0)
for _ in range(depth)
]
)
self.conv_out = nn.Conv2d(
mid_channels,
@@ -722,9 +729,9 @@ class LatentRescaler(nn.Module):
x = self.conv_in(x)
for block in self.res_block1:
x = block(x, None)
x = torch.nn.functional.interpolate(x,
size=(int(round(x.shape[2] * self.factor)),
int(round(x.shape[3] * self.factor))))
x = torch.nn.functional.interpolate(
x, size=(int(round(x.shape[2] * self.factor)), int(round(x.shape[3] * self.factor)))
)
x = self.attn(x)
for block in self.res_block2:
x = block(x, None)
@@ -733,37 +740,42 @@ class LatentRescaler(nn.Module):
class MergedRescaleEncoder(nn.Module):
def __init__(self,
in_channels,
ch,
resolution,
out_ch,
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
ch_mult=(1, 2, 4, 8),
rescale_factor=1.0,
rescale_module_depth=1):
def __init__(
self,
in_channels,
ch,
resolution,
out_ch,
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
ch_mult=(1, 2, 4, 8),
rescale_factor=1.0,
rescale_module_depth=1,
):
super().__init__()
intermediate_chn = ch * ch_mult[-1]
self.encoder = Encoder(in_channels=in_channels,
num_res_blocks=num_res_blocks,
ch=ch,
ch_mult=ch_mult,
z_channels=intermediate_chn,
double_z=False,
resolution=resolution,
attn_resolutions=attn_resolutions,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
out_ch=None)
self.rescaler = LatentRescaler(factor=rescale_factor,
in_channels=intermediate_chn,
mid_channels=intermediate_chn,
out_channels=out_ch,
depth=rescale_module_depth)
self.encoder = Encoder(
in_channels=in_channels,
num_res_blocks=num_res_blocks,
ch=ch,
ch_mult=ch_mult,
z_channels=intermediate_chn,
double_z=False,
resolution=resolution,
attn_resolutions=attn_resolutions,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
out_ch=None,
)
self.rescaler = LatentRescaler(
factor=rescale_factor,
in_channels=intermediate_chn,
mid_channels=intermediate_chn,
out_channels=out_ch,
depth=rescale_module_depth,
)
def forward(self, x):
x = self.encoder(x)
@@ -772,36 +784,41 @@ class MergedRescaleEncoder(nn.Module):
class MergedRescaleDecoder(nn.Module):
def __init__(self,
z_channels,
out_ch,
resolution,
num_res_blocks,
attn_resolutions,
ch,
ch_mult=(1, 2, 4, 8),
dropout=0.0,
resamp_with_conv=True,
rescale_factor=1.0,
rescale_module_depth=1):
def __init__(
self,
z_channels,
out_ch,
resolution,
num_res_blocks,
attn_resolutions,
ch,
ch_mult=(1, 2, 4, 8),
dropout=0.0,
resamp_with_conv=True,
rescale_factor=1.0,
rescale_module_depth=1,
):
super().__init__()
tmp_chn = z_channels * ch_mult[-1]
self.decoder = Decoder(out_ch=out_ch,
z_channels=tmp_chn,
attn_resolutions=attn_resolutions,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
in_channels=None,
num_res_blocks=num_res_blocks,
ch_mult=ch_mult,
resolution=resolution,
ch=ch)
self.rescaler = LatentRescaler(factor=rescale_factor,
in_channels=z_channels,
mid_channels=tmp_chn,
out_channels=tmp_chn,
depth=rescale_module_depth)
self.decoder = Decoder(
out_ch=out_ch,
z_channels=tmp_chn,
attn_resolutions=attn_resolutions,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
in_channels=None,
num_res_blocks=num_res_blocks,
ch_mult=ch_mult,
resolution=resolution,
ch=ch,
)
self.rescaler = LatentRescaler(
factor=rescale_factor,
in_channels=z_channels,
mid_channels=tmp_chn,
out_channels=tmp_chn,
depth=rescale_module_depth,
)
def forward(self, x):
x = self.rescaler(x)
@@ -810,27 +827,27 @@ class MergedRescaleDecoder(nn.Module):
class Upsampler(nn.Module):
def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
super().__init__()
assert out_size >= in_size
num_blocks = int(np.log2(out_size // in_size)) + 1
factor_up = 1. + (out_size % in_size)
factor_up = 1.0 + (out_size % in_size)
rank_zero_info(
f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
)
self.rescaler = LatentRescaler(factor=factor_up,
in_channels=in_channels,
mid_channels=2 * in_channels,
out_channels=in_channels)
self.decoder = Decoder(out_ch=out_channels,
resolution=out_size,
z_channels=in_channels,
num_res_blocks=2,
attn_resolutions=[],
in_channels=None,
ch=in_channels,
ch_mult=[ch_mult for _ in range(num_blocks)])
self.rescaler = LatentRescaler(
factor=factor_up, in_channels=in_channels, mid_channels=2 * in_channels, out_channels=in_channels
)
self.decoder = Decoder(
out_ch=out_channels,
resolution=out_size,
z_channels=in_channels,
num_res_blocks=2,
attn_resolutions=[],
in_channels=None,
ch=in_channels,
ch_mult=[ch_mult for _ in range(num_blocks)],
)
def forward(self, x):
x = self.rescaler(x)
@@ -839,14 +856,14 @@ class Upsampler(nn.Module):
class Resize(nn.Module):
def __init__(self, in_channels=None, learned=False, mode="bilinear"):
super().__init__()
self.with_conv = learned
self.mode = mode
if self.with_conv:
rank_zero_info(
f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode"
)
raise NotImplementedError()
assert in_channels is not None
# no asymmetric padding in torch conv, must do it ourselves

(diff of a changed file, path not preserved)

@@ -1,21 +1,20 @@
from abc import abstractmethod
import math
from abc import abstractmethod
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from ldm.modules.attention import SpatialTransformer
from ldm.modules.diffusionmodules.util import (
avg_pool_nd,
checkpoint,
conv_nd,
linear,
avg_pool_nd,
zero_module,
normalization,
timestep_embedding,
zero_module,
)
from ldm.modules.attention import SpatialTransformer
from ldm.util import exists
@@ -23,6 +22,7 @@ from ldm.util import exists
def convert_module_to_f16(x):
pass
def convert_module_to_f32(x):
pass
@@ -41,7 +41,7 @@ class AttentionPool2d(nn.Module):
output_dim: int = None,
):
super().__init__()
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels
@@ -108,25 +108,25 @@ class Upsample(nn.Module):
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
)
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class TransposedUpsample(nn.Module):
'Learned 2x upsampling without padding'
"Learned 2x upsampling without padding"
def __init__(self, channels, out_channels=None, ks=5):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
self.up = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=ks, stride=2)
def forward(self,x):
def forward(self, x):
return self.up(x)
@@ -139,7 +139,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -147,9 +147,7 @@ class Downsample(nn.Module):
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
)
self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
@@ -225,17 +223,13 @@ class ResBlock(TimestepBlock):
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
),
zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1
)
self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
@@ -246,10 +240,7 @@ class ResBlock(TimestepBlock):
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
return checkpoint(
self._forward, (x, emb), self.parameters(), self.use_checkpoint
)
return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint)
def _forward(self, x, emb):
if self.updown:
@@ -311,8 +302,10 @@ class AttentionBlock(nn.Module):
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
#return pt_checkpoint(self._forward, x) # pytorch
return checkpoint(
self._forward, (x,), self.parameters(), True
) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
# return pt_checkpoint(self._forward, x) # pytorch
def _forward(self, x):
b, c, *spatial = x.shape
@@ -339,7 +332,7 @@ def count_flops_attn(model, _x, y):
# We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes
# the combination of the value vectors.
matmul_ops = 2 * b * (num_spatial ** 2) * c
matmul_ops = 2 * b * (num_spatial**2) * c
model.total_ops += th.DoubleTensor([matmul_ops])
@@ -363,9 +356,7 @@ class QKVAttentionLegacy(nn.Module):
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = th.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
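A note on the scaling above: multiplying q and k each by 1 / sqrt(sqrt(ch)) applies the standard 1 / sqrt(ch) attention scale while keeping the intermediate products small, which is what the inline comment means by being more stable in float16 than dividing after the einsum.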
@@ -460,10 +451,10 @@ class UNetModel(nn.Module):
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
disable_self_attentions=None,
num_attention_blocks=None,
@@ -472,11 +463,16 @@ class UNetModel(nn.Module):
):
super().__init__()
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
assert (
context_dim is not None
), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
assert (
use_spatial_transformer
), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig:
context_dim = list(context_dim)
@@ -484,10 +480,10 @@ class UNetModel(nn.Module):
num_heads_upsample = num_heads
if num_heads == -1:
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
assert num_head_channels != -1, "Either num_heads or num_head_channels has to be set"
if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
assert num_heads != -1, "Either num_heads or num_head_channels has to be set"
self.image_size = image_size
self.in_channels = in_channels
@@ -497,19 +493,25 @@ class UNetModel(nn.Module):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
if len(num_res_blocks) != len(channel_mult):
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult")
raise ValueError(
"provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult"
)
self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f"attention will still not be set.")
assert all(
map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))
)
print(
f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f"attention will still not be set."
)
self.attention_resolutions = attention_resolutions
self.dropout = dropout
@@ -540,11 +542,7 @@ class UNetModel(nn.Module):
raise ValueError()
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
[TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
)
self._feature_size = model_channels
input_block_chans = [model_channels]
@@ -571,7 +569,7 @@ class UNetModel(nn.Module):
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
@@ -586,10 +584,17 @@ class UNetModel(nn.Module):
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
)
if not use_spatial_transformer
else SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disabled_sa,
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -610,9 +615,7 @@ class UNetModel(nn.Module):
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
)
)
ch = out_ch
@@ -626,7 +629,7 @@ class UNetModel(nn.Module):
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
@@ -643,11 +646,18 @@ class UNetModel(nn.Module):
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
),
)
if not use_spatial_transformer
else SpatialTransformer( # always uses a self-attn
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disable_middle_self_attn,
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
),
ResBlock(
ch,
time_embed_dim,
@@ -682,7 +692,7 @@ class UNetModel(nn.Module):
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
@@ -697,10 +707,17 @@ class UNetModel(nn.Module):
num_heads=num_heads_upsample,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
)
if not use_spatial_transformer
else SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disabled_sa,
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
)
)
if level and i == self.num_res_blocks[level]:
@@ -730,10 +747,10 @@ class UNetModel(nn.Module):
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
normalization(ch),
conv_nd(dims, model_channels, n_embed, 1),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
normalization(ch),
conv_nd(dims, model_channels, n_embed, 1),
# nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
def convert_to_fp16(self):
"""
@@ -751,7 +768,7 @@ class UNetModel(nn.Module):
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.

(diff of a changed file, path not preserved)

@@ -1,8 +1,8 @@
import torch
import torch.nn as nn
import numpy as np
from functools import partial
import numpy as np
import torch
import torch.nn as nn
from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
from ldm.util import default
@@ -14,37 +14,41 @@ class AbstractLowScaleModel(nn.Module):
if noise_schedule_config is not None:
self.register_schedule(**noise_schedule_config)
def register_schedule(self, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
cosine_s=cosine_s)
alphas = 1. - betas
def register_schedule(
self, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
betas = make_beta_schedule(
beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s
)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
timesteps, = betas.shape
(timesteps,) = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
assert alphas_cumprod.shape[0] == self.num_timesteps, "alphas have to be defined for each timestep"
to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer('betas', to_torch(betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
self.register_buffer("betas", to_torch(betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)))
self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)))
self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)))
self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)))
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
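In standard DDPM notation, q_sample above samples x_t directly from q(x_t | x_0) as x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, using the two square-root buffers registered in register_schedule.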
def forward(self, x):
return x, None
@@ -76,6 +80,3 @@ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
assert isinstance(noise_level, torch.Tensor)
z = self.q_sample(x, noise_level)
return z, noise_level

(diff of a changed file, path not preserved)

@@ -8,7 +8,6 @@
# thanks!
import math
import os
import numpy as np
import torch
@@ -19,10 +18,10 @@ from ldm.util import instantiate_from_config
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if schedule == "linear":
betas = (torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64)**2)
betas = torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2
elif schedule == "cosine":
timesteps = (torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s)
timesteps = torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2)
alphas = alphas / alphas[0]
@@ -32,18 +31,18 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
elif schedule == "sqrt_linear":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
elif schedule == "sqrt":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)**0.5
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
else:
raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy()
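Of the branches above, "linear" spaces sqrt(beta_t) evenly between sqrt(linear_start) and sqrt(linear_end) and then squares the result, while "cosine" follows the improved-DDPM cosine schedule, deriving betas from a squared-cosine alpha-bar curve.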
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
if ddim_discr_method == 'uniform':
if ddim_discr_method == "uniform":
c = num_ddpm_timesteps // num_ddim_timesteps
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
elif ddim_discr_method == 'quad':
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps))**2).astype(int)
elif ddim_discr_method == "quad":
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2).astype(int)
else:
raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
@@ -51,7 +50,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
# add one to get the final alpha values right (the ones from first scale to data during sampling)
steps_out = ddim_timesteps + 1
if verbose:
print(f'Selected timesteps for ddim sampler: {steps_out}')
print(f"Selected timesteps for ddim sampler: {steps_out}")
return steps_out
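As a worked example, with num_ddpm_timesteps=1000 and num_ddim_timesteps=50 the "uniform" branch gives c=20, ddim_timesteps = [0, 20, ..., 980], and the returned steps_out = [1, 21, ..., 981].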
@@ -63,9 +62,11 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
# according the the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
if verbose:
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
print(f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
print(f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}")
print(
f"For the chosen value of eta, which is {eta}, "
f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
)
return sigmas, alphas, alphas_prev
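With eta=0 every sigma_t is zero and sampling becomes the deterministic DDIM procedure; eta=1 recovers the DDPM posterior variance, as described in the referenced https://arxiv.org/abs/2010.02502.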
@@ -106,6 +107,7 @@ def checkpoint(func, inputs, params, flag):
"""
if flag:
from torch.utils.checkpoint import checkpoint as torch_checkpoint
return torch_checkpoint(func, *inputs)
# args = tuple(inputs) + tuple(params)
# return CheckpointFunction.apply(func, len(inputs), *args)
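A minimal usage sketch of the torch.utils.checkpoint path this hunk switches to (illustration only; the use_reentrant argument exists in recent PyTorch releases and can be dropped on older ones):

import torch
from torch.utils.checkpoint import checkpoint as torch_checkpoint

layer = torch.nn.Sequential(torch.nn.Linear(16, 64), torch.nn.GELU(), torch.nn.Linear(64, 16))
x = torch.randn(4, 16, requires_grad=True)
# forward activations of `layer` are recomputed during backward instead of being stored
y = torch_checkpoint(layer, x, use_reentrant=False)
y.sum().backward()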
@@ -114,7 +116,6 @@ def checkpoint(func, inputs, params, flag):
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
@@ -123,7 +124,7 @@ class CheckpointFunction(torch.autograd.Function):
ctx.gpu_autocast_kwargs = {
"enabled": torch.is_autocast_enabled(),
"dtype": torch.get_autocast_gpu_dtype(),
"cache_enabled": torch.is_autocast_cache_enabled()
"cache_enabled": torch.is_autocast_cache_enabled(),
}
with torch.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors)
@@ -132,8 +133,7 @@ class CheckpointFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
with torch.enable_grad(), \
torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
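Restoring the saved autocast settings here ensures the recomputed forward pass inside backward() runs under the same mixed-precision configuration as the original forward, so the recomputed activations match the discarded ones.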
@@ -162,14 +162,15 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
"""
if not repeat_only:
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) /
half).to(device=timesteps.device)
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
device=timesteps.device
)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else:
embedding = repeat(timesteps, 'b -> b d', d=dim)
embedding = repeat(timesteps, "b -> b d", d=dim)
return embedding
@@ -210,13 +211,11 @@ def normalization(channels):
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
@@ -255,7 +254,6 @@ def avg_pool_nd(dims, *args, **kwargs):
class HybridConditioner(nn.Module):
def __init__(self, c_concat_config, c_crossattn_config):
super().__init__()
self.concat_conditioner = instantiate_from_config(c_concat_config)
@@ -264,7 +262,7 @@ class HybridConditioner(nn.Module):
def forward(self, c_concat, c_crossattn):
c_concat = self.concat_conditioner(c_concat)
c_crossattn = self.crossattn_conditioner(c_crossattn)
return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
def noise_like(shape, device, repeat=False):