Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-16 14:41:53 +00:00
@@ -6,10 +6,11 @@ except:
 import torch.nn.functional as F
 from contextlib import contextmanager
-from torch.nn import Identity
 
 from ldm.modules.diffusionmodules.model import Encoder, Decoder
 from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+
+from ldm.util import instantiate_from_config
 from ldm.modules.ema import LitEma
 
 
@@ -31,7 +32,7 @@ class AutoencoderKL(pl.LightningModule):
         self.image_key = image_key
         self.encoder = Encoder(**ddconfig)
         self.decoder = Decoder(**ddconfig)
-        self.loss = Identity(**lossconfig.get("params", dict()))
+        self.loss = instantiate_from_config(lossconfig)
         assert ddconfig["double_z"]
         self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
         self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
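
Note on the pattern above (and in every hunk below): each `+` line routes construction through `instantiate_from_config`, which resolves a dotted `target` path from the config and calls it with `params`, instead of hard-coding one class such as `Identity`. For reference, the helper in ldm/util.py of upstream latent-diffusion looks essentially like this sketch (reproduced from the upstream source; this tree's copy may differ slightly):

import importlib


def get_obj_from_str(string, reload=False):
    # resolve "pkg.module.Class" to the class object, importing the module on demand
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    # `config` is a mapping with a dotted `target` path and optional `params` kwargs;
    # the magic strings below are the same sentinels asserted on in ddpm.py further down
    if "target" not in config:
        if config in ("__is_first_stage__", "__is_unconditional__"):
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
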
@@ -9,10 +9,9 @@ from copy import deepcopy
 from einops import rearrange
 from glob import glob
 from natsort import natsorted
-from ldm.models.diffusion.ddpm import LatentDiffusion
-from ldm.lr_scheduler import LambdaLinearScheduler
+
 from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
-from ldm.util import log_txt_as_img, default, ismap
+from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
 
 __models__ = {
     'class_label': EncoderUNetModel,
@@ -87,7 +86,7 @@ class NoisyLatentImageClassifier(pl.LightningModule):
         print(f"Unexpected Keys: {unexpected}")
 
     def load_diffusion(self):
-        model = LatentDiffusion(**self.diffusion_config.get('params',dict()))
+        model = instantiate_from_config(self.diffusion_config)
        self.diffusion_model = model.eval()
        self.diffusion_model.train = disabled_train
        for param in self.diffusion_model.parameters():
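
The configs these calls consume are plain `target`/`params` mappings. A minimal, hypothetical usage example of the pattern (the `torch.nn.Identity` target is chosen only for illustration; any importable dotted path works):

from omegaconf import OmegaConf

from ldm.util import instantiate_from_config  # assumes the example's ldm package is on the path

# hypothetical config block, equivalent to the YAML the training configs use
cfg = OmegaConf.create({
    "target": "torch.nn.Identity",
    "params": {},
})
model = instantiate_from_config(cfg)  # builds torch.nn.Identity()
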
@@ -222,7 +221,7 @@ class NoisyLatentImageClassifier(pl.LightningModule):
         optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
 
         if self.use_scheduler:
-            scheduler = LambdaLinearScheduler(**self.scheduler_config.get('params',dict()))
+            scheduler = instantiate_from_config(self.scheduler_config)
 
             print("Setting up LambdaLR scheduler...")
             scheduler = [
@@ -22,7 +22,6 @@ from contextlib import contextmanager, nullcontext
 from functools import partial
 
 from einops import rearrange, repeat
-from ldm.lr_scheduler import LambdaLinearScheduler
 from ldm.models.autoencoder import *
 from ldm.models.autoencoder import AutoencoderKL, IdentityFirstStage
 from ldm.models.diffusion.ddim import *
@@ -30,10 +29,9 @@ from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.modules.diffusionmodules.model import *
 from ldm.modules.diffusionmodules.model import Decoder, Encoder, Model
 from ldm.modules.diffusionmodules.openaimodel import *
-from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d, UNetModel
+from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d
 from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule, noise_like
 from ldm.modules.distributions.distributions import DiagonalGaussianDistribution, normal_kl
-from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
 from ldm.modules.ema import LitEma
 from ldm.modules.encoders.modules import *
 from ldm.util import count_params, default, exists, instantiate_from_config, isimage, ismap, log_txt_as_img, mean_flat
@@ -41,7 +39,6 @@ from omegaconf import ListConfig
 from torch.optim.lr_scheduler import LambdaLR
 from torchvision.utils import make_grid
 from tqdm import tqdm
-from ldm.modules.midas.api import MiDaSInference
 
 __conditioning_keys__ = {'concat': 'c_concat', 'crossattn': 'c_crossattn', 'adm': 'y'}
 
@@ -693,7 +690,7 @@ class LatentDiffusion(DDPM):
             self.make_cond_schedule()
 
     def instantiate_first_stage(self, config):
-        model = AutoencoderKL(**config.get("params", dict()))
+        model = instantiate_from_config(config)
         self.first_stage_model = model.eval()
         self.first_stage_model.train = disabled_train
         for param in self.first_stage_model.parameters():
@@ -709,7 +706,7 @@ class LatentDiffusion(DDPM):
             self.cond_stage_model = None
             # self.be_unconditional = True
         else:
-            model = FrozenOpenCLIPEmbedder(**config.get("params", dict()))
+            model = instantiate_from_config(config)
             self.cond_stage_model = model.eval()
             self.cond_stage_model.train = disabled_train
             for param in self.cond_stage_model.parameters():
@@ -717,7 +714,7 @@ class LatentDiffusion(DDPM):
         else:
             assert config != '__is_first_stage__'
             assert config != '__is_unconditional__'
-            model = FrozenOpenCLIPEmbedder(**config.get("params", dict()))
+            model = instantiate_from_config(config)
             self.cond_stage_model = model
 
     def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
@@ -1482,7 +1479,8 @@ class LatentDiffusion(DDPM):
 
         # opt = torch.optim.AdamW(params, lr=lr)
         if self.use_scheduler:
-            scheduler = LambdaLinearScheduler(**self.scheduler_config.get("params", dict()))
+            assert 'target' in self.scheduler_config
+            scheduler = instantiate_from_config(self.scheduler_config)
 
             rank_zero_info("Setting up LambdaLR scheduler...")
             scheduler = [{'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1}]
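
With the hardcoded LambdaLinearScheduler gone, the scheduler class now also comes from the config, hence the new `assert 'target' in self.scheduler_config`. A typical scheduler block, written here as a Python dict mirroring the YAML used in common Stable Diffusion training configs (values are the usual defaults, not taken from this repo):

# hypothetical scheduler_config; instantiate_from_config resolves `target`
# and passes `params` as keyword arguments to LambdaLinearScheduler
scheduler_config = {
    "target": "ldm.lr_scheduler.LambdaLinearScheduler",
    "params": {
        "warm_up_steps": [10000],
        "cycle_lengths": [10000000000000],  # effectively never restart
        "f_start": [1.0e-6],
        "f_max": [1.0],
        "f_min": [1.0],
    },
}

The instantiated object exposes a `schedule` method, which the `LambdaLR(opt, lr_lambda=scheduler.schedule)` line above installs as the per-step learning-rate multiplier.
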
@@ -1504,7 +1502,7 @@ class DiffusionWrapper(pl.LightningModule):
     def __init__(self, diff_model_config, conditioning_key):
         super().__init__()
         self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
-        self.diffusion_model = UNetModel(**diff_model_config.get("params", dict()))
+        self.diffusion_model = instantiate_from_config(diff_model_config)
         self.conditioning_key = conditioning_key
         assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
 
@@ -1553,7 +1551,7 @@ class LatentUpscaleDiffusion(LatentDiffusion):
         self.noise_level_key = noise_level_key
 
     def instantiate_low_stage(self, config):
-        model = ImageConcatWithNoiseAugmentation(**config.get("params", dict()))
+        model = instantiate_from_config(config)
         self.low_scale_model = model.eval()
         self.low_scale_model.train = disabled_train
         for param in self.low_scale_model.parameters():
@@ -1935,7 +1933,7 @@ class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
 
     def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
         super().__init__(concat_keys=concat_keys, *args, **kwargs)
-        self.depth_model = MiDaSInference(**depth_stage_config.get("params", dict()))
+        self.depth_model = instantiate_from_config(depth_stage_config)
         self.depth_stage_key = concat_keys[0]
 
     @torch.no_grad()
@@ -2008,7 +2006,7 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
         self.low_scale_key = low_scale_key
 
     def instantiate_low_stage(self, config):
-        model = ImageConcatWithNoiseAugmentation(**config.get("params", dict()))
+        model = instantiate_from_config(config)
         self.low_scale_model = model.eval()
         self.low_scale_model.train = disabled_train
         for param in self.low_scale_model.parameters():