support stable diffusion v2

This commit is contained in:
Fazzie
2022-12-12 17:35:23 +08:00
parent e99edfcb51
commit cea4292ae5
51 changed files with 5597 additions and 3302 deletions

@@ -122,7 +122,9 @@ class CheckpointFunction(torch.autograd.Function):
         ctx.run_function = run_function
         ctx.input_tensors = list(args[:length])
         ctx.input_params = list(args[length:])
-
+        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
+                                   "dtype": torch.get_autocast_gpu_dtype(),
+                                   "cache_enabled": torch.is_autocast_cache_enabled()}
         with torch.no_grad():
             output_tensors = ctx.run_function(*ctx.input_tensors)
         return output_tensors
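Gradient checkpointing re-runs the forward function during backward, outside the original autocast region, so this hunk snapshots the ambient GPU autocast state at forward time for later replay. A minimal sketch of just the capture step (the helper name is ours, not from this commit):

```python
import torch

def snapshot_gpu_autocast_state():
    # Record whether GPU autocast is active, its compute dtype, and the
    # weight-cast cache flag, so the same context can be re-entered later.
    return {"enabled": torch.is_autocast_enabled(),
            "dtype": torch.get_autocast_gpu_dtype(),
            "cache_enabled": torch.is_autocast_cache_enabled()}
```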
@@ -130,7 +132,8 @@ class CheckpointFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, *output_grads):
         ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
-        with torch.enable_grad():
+        with torch.enable_grad(), \
+                torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
             # Fixes a bug where the first op in run_function modifies the
             # Tensor storage in place, which is not allowed for detach()'d
             # Tensors.
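Re-entering autocast with the saved kwargs makes the recomputation run at the same precision as the original forward pass. A hedged end-to-end sketch of the pattern these two hunks implement, as a standalone toy rather than the file's actual class:

```python
import torch

class CheckpointSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, fn, x):
        ctx.fn = fn
        ctx.save_for_backward(x)
        # Same capture as the hunk above: remember the autocast state.
        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                                   "dtype": torch.get_autocast_gpu_dtype(),
                                   "cache_enabled": torch.is_autocast_cache_enabled()}
        with torch.no_grad():
            return fn(x)

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        # Recompute under the captured autocast state so mixed-precision
        # behavior matches the original forward pass.
        with torch.enable_grad(), \
                torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            y = ctx.fn(x)
        (grad_in,) = torch.autograd.grad(y, (x,), grad_out)
        return None, grad_in

out = CheckpointSketch.apply(torch.sin, torch.randn(3, requires_grad=True))
out.sum().backward()
```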
@@ -148,7 +151,7 @@ class CheckpointFunction(torch.autograd.Function):
         return (None, None) + input_grads


-def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, use_fp16=True):
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
     """
     Create sinusoidal timestep embeddings.
     :param timesteps: a 1-D Tensor of N indices, one per batch element.
@@ -168,10 +171,7 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, use_fp16=True):
             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
     else:
         embedding = repeat(timesteps, 'b -> b d', d=dim)
-    if use_fp16:
-        return embedding.half()
-    else:
-        return embedding
+    return embedding


 def zero_module(module):
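With `use_fp16` gone, the embedding is returned at full precision and any fp16 cast is left to autocast at the call site. For reference, a self-contained sketch of the sinusoidal embedding this function computes (simplified; the `repeat_only` branch is omitted):

```python
import math
import torch

def sinusoidal_timestep_embedding(timesteps, dim, max_period=10000):
    # Half of the channels carry cos, half sin, at frequencies spaced
    # geometrically from 1 down to 1/max_period.
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    ).to(timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:  # pad odd dims with a zero column, as in the diff context above
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

emb = sinusoidal_timestep_embedding(torch.arange(4), 128)
print(emb.shape)  # torch.Size([4, 128])
```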
@@ -199,16 +199,14 @@ def mean_flat(tensor):
     return tensor.mean(dim=list(range(1, len(tensor.shape))))


-def normalization(channels, precision=16):
+def normalization(channels):
     """
     Make a standard normalization layer.
     :param channels: number of input channels.
     :return: an nn.Module for normalization.
     """
-    if precision == 16:
-        return GroupNorm16(16, channels)
-    else:
-        return GroupNorm32(32, channels)
+    return nn.GroupNorm(16, channels)
+    # return GroupNorm32(32, channels)


 # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
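The precision-switching wrapper is gone: `normalization` now returns a plain `nn.GroupNorm` with 16 groups (note that `num_channels` must stay divisible by the group count). A quick usage sketch under that assumption:

```python
import torch
import torch.nn as nn

channels = 64  # example value; must be divisible by the 16 groups
norm = nn.GroupNorm(16, channels)
x = torch.randn(2, channels, 8, 8)
print(norm(x).shape)  # torch.Size([2, 64, 8, 8])
```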
@@ -216,9 +214,6 @@ class SiLU(nn.Module):
     def forward(self, x):
         return x * torch.sigmoid(x)

-class GroupNorm16(nn.GroupNorm):
-    def forward(self, x):
-        return super().forward(x.half()).type(x.dtype)

 class GroupNorm32(nn.GroupNorm):
     def forward(self, x):
         return super().forward(x.float()).type(x.dtype)
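Dropping `GroupNorm16` leans on autocast's dtype policy; to our understanding, CUDA autocast runs `group_norm` in float32, so the hand-rolled half-precision subclass is redundant once autocast manages precision. A small check of that assumption:

```python
import torch
import torch.nn as nn

# Assumption being checked: under CUDA autocast, group_norm is on the
# float32 cast list, so a plain nn.GroupNorm needs no manual dtype casts.
if torch.cuda.is_available():
    norm = nn.GroupNorm(16, 64).cuda()
    x = torch.randn(2, 64, 8, 8, device="cuda")
    with torch.cuda.amp.autocast():
        y = norm(x)
    print(y.dtype)  # torch.float32: autocast upcasts normalization ops
else:
    print("CUDA not available; autocast's dtype policy applies on GPU")
```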