Revert "[dreambooth] fixing the incompatibity in requirements.txt (#3190) (#3378)" (#3481)

This commit is contained in:
NatalieC323
2023-04-06 20:22:52 +08:00
committed by GitHub
parent 891b8e7fac
commit fb8fae6f29
14 changed files with 98 additions and 124 deletions

View File

@@ -6,10 +6,11 @@ except:
import torch.nn.functional as F
from contextlib import contextmanager
from torch.nn import Identity
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.util import instantiate_from_config
from ldm.modules.ema import LitEma
@@ -31,7 +32,7 @@ class AutoencoderKL(pl.LightningModule):
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = Identity(**lossconfig.get("params", dict()))
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)

View File

@@ -9,10 +9,9 @@ from copy import deepcopy
from einops import rearrange
from glob import glob
from natsort import natsorted
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.lr_scheduler import LambdaLinearScheduler
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
from ldm.util import log_txt_as_img, default, ismap
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
__models__ = {
'class_label': EncoderUNetModel,
@@ -87,7 +86,7 @@ class NoisyLatentImageClassifier(pl.LightningModule):
print(f"Unexpected Keys: {unexpected}")
def load_diffusion(self):
model = LatentDiffusion(**self.diffusion_config.get('params',dict()))
model = instantiate_from_config(self.diffusion_config)
self.diffusion_model = model.eval()
self.diffusion_model.train = disabled_train
for param in self.diffusion_model.parameters():
@@ -222,7 +221,7 @@ class NoisyLatentImageClassifier(pl.LightningModule):
optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
if self.use_scheduler:
scheduler = LambdaLinearScheduler(**self.scheduler_config.get('params',dict()))
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [

View File

@@ -22,7 +22,6 @@ from contextlib import contextmanager, nullcontext
from functools import partial
from einops import rearrange, repeat
from ldm.lr_scheduler import LambdaLinearScheduler
from ldm.models.autoencoder import *
from ldm.models.autoencoder import AutoencoderKL, IdentityFirstStage
from ldm.models.diffusion.ddim import *
@@ -30,10 +29,9 @@ from ldm.models.diffusion.ddim import DDIMSampler
from ldm.modules.diffusionmodules.model import *
from ldm.modules.diffusionmodules.model import Decoder, Encoder, Model
from ldm.modules.diffusionmodules.openaimodel import *
from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d, UNetModel
from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d
from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule, noise_like
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution, normal_kl
from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from ldm.modules.ema import LitEma
from ldm.modules.encoders.modules import *
from ldm.util import count_params, default, exists, instantiate_from_config, isimage, ismap, log_txt_as_img, mean_flat
@@ -41,7 +39,6 @@ from omegaconf import ListConfig
from torch.optim.lr_scheduler import LambdaLR
from torchvision.utils import make_grid
from tqdm import tqdm
from ldm.modules.midas.api import MiDaSInference
__conditioning_keys__ = {'concat': 'c_concat', 'crossattn': 'c_crossattn', 'adm': 'y'}
@@ -693,7 +690,7 @@ class LatentDiffusion(DDPM):
self.make_cond_schedule()
def instantiate_first_stage(self, config):
model = AutoencoderKL(**config.get("params", dict()))
model = instantiate_from_config(config)
self.first_stage_model = model.eval()
self.first_stage_model.train = disabled_train
for param in self.first_stage_model.parameters():
@@ -709,7 +706,7 @@ class LatentDiffusion(DDPM):
self.cond_stage_model = None
# self.be_unconditional = True
else:
model = FrozenOpenCLIPEmbedder(**config.get("params", dict()))
model = instantiate_from_config(config)
self.cond_stage_model = model.eval()
self.cond_stage_model.train = disabled_train
for param in self.cond_stage_model.parameters():
@@ -717,7 +714,7 @@ class LatentDiffusion(DDPM):
else:
assert config != '__is_first_stage__'
assert config != '__is_unconditional__'
model = FrozenOpenCLIPEmbedder(**config.get("params", dict()))
model = instantiate_from_config(config)
self.cond_stage_model = model
def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
@@ -1482,7 +1479,8 @@ class LatentDiffusion(DDPM):
# opt = torch.optim.AdamW(params, lr=lr)
if self.use_scheduler:
scheduler = LambdaLinearScheduler(**self.scheduler_config.get("params", dict()))
assert 'target' in self.scheduler_config
scheduler = instantiate_from_config(self.scheduler_config)
rank_zero_info("Setting up LambdaLR scheduler...")
scheduler = [{'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1}]
@@ -1504,7 +1502,7 @@ class DiffusionWrapper(pl.LightningModule):
def __init__(self, diff_model_config, conditioning_key):
super().__init__()
self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
self.diffusion_model = UNetModel(**diff_model_config.get("params", dict()))
self.diffusion_model = instantiate_from_config(diff_model_config)
self.conditioning_key = conditioning_key
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
@@ -1553,7 +1551,7 @@ class LatentUpscaleDiffusion(LatentDiffusion):
self.noise_level_key = noise_level_key
def instantiate_low_stage(self, config):
model = ImageConcatWithNoiseAugmentation(**config.get("params", dict()))
model = instantiate_from_config(config)
self.low_scale_model = model.eval()
self.low_scale_model.train = disabled_train
for param in self.low_scale_model.parameters():
@@ -1935,7 +1933,7 @@ class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
super().__init__(concat_keys=concat_keys, *args, **kwargs)
self.depth_model = MiDaSInference(**depth_stage_config.get("params", dict()))
self.depth_model = instantiate_from_config(depth_stage_config)
self.depth_stage_key = concat_keys[0]
@torch.no_grad()
@@ -2008,7 +2006,7 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
self.low_scale_key = low_scale_key
def instantiate_low_stage(self, config):
model = ImageConcatWithNoiseAugmentation(**config.get("params", dict()))
model = instantiate_from_config(config)
self.low_scale_model = model.eval()
self.low_scale_model.train = disabled_train
for param in self.low_scale_model.parameters():