Polish Code

natalie_cao 2023-04-11 14:10:45 +08:00 committed by アマデウス
parent 152239bbfa
commit de84c0311a
15 changed files with 562 additions and 719 deletions

View File

@@ -1,6 +1,5 @@
 model:
   base_learning_rate: 1.0e-4
-  target: ldm.models.diffusion.ddpm.LatentDiffusion
   params:
     parameterization: "v"
     linear_start: 0.00085
@@ -19,8 +18,6 @@ model:
     use_ema: False # we set this to false because this is an inference only config

     unet_config:
-      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
-      params:
         use_checkpoint: True
         use_fp16: True
         image_size: 32 # unused
@@ -38,8 +35,6 @@ model:
         legacy: False

     first_stage_config:
-      target: ldm.models.autoencoder.AutoencoderKL
-      params:
         embed_dim: 4
         monitor: val/rec_loss
         ddconfig:
@@ -58,11 +53,7 @@ model:
           num_res_blocks: 2
           attn_resolutions: []
           dropout: 0.0
-        lossconfig:
-          target: torch.nn.Identity

     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
         freeze: True
         layer: "penultimate"

View File

@@ -1,6 +1,5 @@
 model:
   base_learning_rate: 1.0e-4
-  target: ldm.models.diffusion.ddpm.LatentDiffusion
   params:
     linear_start: 0.00085
     linear_end: 0.0120
@@ -18,8 +17,6 @@ model:
     use_ema: False # we set this to false because this is an inference only config

     unet_config:
-      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
-      params:
         use_checkpoint: True
         use_fp16: True
         image_size: 32 # unused
@@ -37,8 +34,6 @@ model:
         legacy: False

     first_stage_config:
-      target: ldm.models.autoencoder.AutoencoderKL
-      params:
         embed_dim: 4
         monitor: val/rec_loss
         ddconfig:
@@ -57,11 +52,7 @@ model:
           num_res_blocks: 2
           attn_resolutions: []
           dropout: 0.0
-        lossconfig:
-          target: torch.nn.Identity

     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
         freeze: True
         layer: "penultimate"

View File

@@ -19,8 +19,6 @@ model:
     use_ema: False

     unet_config:
-      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
-      params:
         use_checkpoint: True
         image_size: 32 # unused
         in_channels: 9
@@ -37,8 +35,6 @@ model:
         legacy: False

     first_stage_config:
-      target: ldm.models.autoencoder.AutoencoderKL
-      params:
         embed_dim: 4
         monitor: val/rec_loss
         ddconfig:
@@ -58,18 +54,13 @@ model:
           attn_resolutions: [ ]
           dropout: 0.0
         lossconfig:
-          target: torch.nn.Identity

     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
         freeze: True
         layer: "penultimate"

 data:
-  target: ldm.data.laion.WebDataModuleFromConfig
-  params:
     tar_base: null # for concat as in LAION-A
     p_unsafe_threshold: 0.1
     filter_word_list: "data/filters.yaml"
@@ -132,8 +123,6 @@ lightning:
         every_n_train_steps: 10000

     image_logger:
-      target: main.ImageLogger
-      params:
         enable_autocast: False
         disabled: False
         batch_frequency: 1000

View File

@@ -19,13 +19,9 @@ model:
     use_ema: False

     depth_stage_config:
-      target: ldm.modules.midas.api.MiDaSInference
-      params:
         model_type: "dpt_hybrid"

     unet_config:
-      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
-      params:
         use_checkpoint: True
         image_size: 32 # unused
         in_channels: 5
@@ -42,8 +38,6 @@ model:
         legacy: False

     first_stage_config:
-      target: ldm.models.autoencoder.AutoencoderKL
-      params:
         embed_dim: 4
         monitor: val/rec_loss
         ddconfig:
@@ -63,10 +57,7 @@ model:
           attn_resolutions: [ ]
           dropout: 0.0
         lossconfig:
-          target: torch.nn.Identity

     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
         freeze: True
         layer: "penultimate"

View File

@@ -20,16 +20,12 @@ model:
     use_ema: False

     low_scale_config:
-      target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation
-      params:
         noise_schedule_config: # image space
           linear_start: 0.0001
           linear_end: 0.02
         max_noise_level: 350

     unet_config:
-      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
-      params:
         use_checkpoint: True
         num_classes: 1000 # timesteps for noise conditioning (here constant, just need one)
         image_size: 128
@@ -49,8 +45,6 @@ model:
         use_linear_in_transformer: True

     first_stage_config:
-      target: ldm.models.autoencoder.AutoencoderKL
-      params:
         embed_dim: 4
         ddconfig:
           # attn_type: "vanilla-xformers" this model needs efficient attention to be feasible on HR data, also the decoder seems to break in half precision (UNet is fine though)
@@ -64,12 +58,9 @@ model:
           num_res_blocks: 2
           attn_resolutions: [ ]
           dropout: 0.0
         lossconfig:
-          target: torch.nn.Identity

     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
         freeze: True
         layer: "penultimate"

View File

@@ -1,6 +1,5 @@
 model:
   base_learning_rate: 1.0e-4
-  target: ldm.models.diffusion.ddpm.LatentDiffusion
   params:
     parameterization: "v"
     linear_start: 0.00085
@@ -20,8 +19,6 @@ model:
     use_ema: False

     scheduler_config: # 10000 warmup steps
-      target: ldm.lr_scheduler.LambdaLinearScheduler
-      params:
         warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
         cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
         f_start: [ 1.e-6 ]
@@ -30,8 +27,6 @@ model:

     unet_config:
-      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
-      params:
         use_checkpoint: True
         use_fp16: True
         image_size: 32 # unused
@@ -49,8 +44,6 @@ model:
         legacy: False

     first_stage_config:
-      target: ldm.models.autoencoder.AutoencoderKL
-      params:
         embed_dim: 4
         monitor: val/rec_loss
         ddconfig:
@@ -70,17 +63,12 @@ model:
           attn_resolutions: []
           dropout: 0.0
         lossconfig:
-          target: torch.nn.Identity

     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
         freeze: True
         layer: "penultimate"

 data:
-  target: main.DataModuleFromConfig
-  params:
     batch_size: 16
     num_workers: 4
     train:
@@ -105,8 +93,6 @@ lightning:
     precision: 16
     auto_select_gpus: False
     strategy:
-      target: strategies.ColossalAIStrategy
-      params:
         use_chunk: True
         enable_distributed_storage: True
         placement_policy: cuda
@@ -120,8 +106,6 @@ lightning:
   logger_config:
     wandb:
-      target: loggers.WandbLogger
-      params:
         name: nowname
         save_dir: "/tmp/diff_log/"
         offline: opt.debug

View File

@@ -1,6 +1,5 @@
 model:
   base_learning_rate: 1.0e-4
-  target: ldm.models.diffusion.ddpm.LatentDiffusion
   params:
     parameterization: "v"
     linear_start: 0.00085
@@ -19,8 +18,6 @@ model:
     use_ema: False # we set this to false because this is an inference only config

     scheduler_config: # 10000 warmup steps
-      target: ldm.lr_scheduler.LambdaLinearScheduler
-      params:
         warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
         cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
         f_start: [ 1.e-6 ]
@@ -29,8 +26,6 @@ model:

     unet_config:
-      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
-      params:
         use_checkpoint: True
         use_fp16: True
         image_size: 32 # unused
@@ -48,8 +43,6 @@ model:
         legacy: False

     first_stage_config:
-      target: ldm.models.autoencoder.AutoencoderKL
-      params:
         embed_dim: 4
         monitor: val/rec_loss
         ddconfig:
@@ -69,17 +62,13 @@ model:
           attn_resolutions: []
           dropout: 0.0
         lossconfig:
-          target: torch.nn.Identity

     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
         freeze: True
         layer: "penultimate"

 data:
-  target: main.DataModuleFromConfig
-  params:
     batch_size: 128
     wrap: False
     # num_workwers should be 2 * batch_size, and total num less than 1024
@@ -95,14 +84,12 @@ data:
 lightning:
   trainer:
     accelerator: 'gpu'
-    devices: 8
+    devices: 2
     log_gpu_memory: all
     max_epochs: 2
     precision: 16
     auto_select_gpus: False
     strategy:
-      target: strategies.ColossalAIStrategy
-      params:
         use_chunk: True
         enable_distributed_storage: True
         placement_policy: cuda
@@ -116,8 +103,6 @@ lightning:
   logger_config:
     wandb:
-      target: loggers.WandbLogger
-      params:
         name: nowname
         save_dir: "/tmp/diff_log/"
         offline: opt.debug

View File

@@ -1,6 +1,5 @@
 model:
   base_learning_rate: 1.0e-4
-  target: ldm.models.diffusion.ddpm.LatentDiffusion
   params:
     parameterization: "v"
     linear_start: 0.00085
@@ -19,8 +18,6 @@ model:
     use_ema: False # we set this to false because this is an inference only config

     scheduler_config: # 10000 warmup steps
-      target: ldm.lr_scheduler.LambdaLinearScheduler
-      params:
         warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
         cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
         f_start: [ 1.e-6 ]
@@ -29,8 +26,6 @@ model:

     unet_config:
-      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
-      params:
         use_checkpoint: True
         use_fp16: True
         image_size: 32 # unused
@@ -48,8 +43,6 @@ model:
         legacy: False

     first_stage_config:
-      target: ldm.models.autoencoder.AutoencoderKL
-      params:
         embed_dim: 4
         monitor: val/rec_loss
         ddconfig:
@@ -69,17 +62,12 @@ model:
           num_res_blocks: 2
           attn_resolutions: []
           dropout: 0.0
-        lossconfig:
-          target: torch.nn.Identity

     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
         freeze: True
         layer: "penultimate"

 data:
-  target: main.DataModuleFromConfig
-  params:
     batch_size: 4
     num_workers: 4
     train:
@@ -105,8 +93,6 @@ lightning:
     precision: 16
     auto_select_gpus: False
     strategy:
-      target: strategies.ColossalAIStrategy
-      params:
         use_chunk: True
         enable_distributed_storage: True
         placement_policy: cuda
@@ -120,8 +106,6 @@ lightning:
   logger_config:
     wandb:
-      target: loggers.WandbLogger
-      params:
         name: nowname
         save_dir: "/tmp/diff_log/"
         offline: opt.debug

View File

@@ -1,6 +1,5 @@
 model:
   base_learning_rate: 1.0e-4
-  target: ldm.models.diffusion.ddpm.LatentDiffusion
   params:
     parameterization: "v"
     linear_start: 0.00085
@@ -19,8 +18,6 @@ model:
     use_ema: False # we set this to false because this is an inference only config

     scheduler_config: # 10000 warmup steps
-      target: ldm.lr_scheduler.LambdaLinearScheduler
-      params:
         warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
         cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
         f_start: [ 1.e-6 ]
@@ -29,8 +26,6 @@ model:

     unet_config:
-      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
-      params:
         use_checkpoint: True
         use_fp16: True
         image_size: 32 # unused
@@ -48,8 +43,6 @@ model:
         legacy: False

     first_stage_config:
-      target: ldm.models.autoencoder.AutoencoderKL
-      params:
         embed_dim: 4
         monitor: val/rec_loss
         ddconfig:
@@ -68,18 +61,12 @@ model:
           num_res_blocks: 2
           attn_resolutions: []
           dropout: 0.0
-        lossconfig:
-          target: torch.nn.Identity

     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
         freeze: True
         layer: "penultimate"

 data:
-  target: main.DataModuleFromConfig
-  params:
     batch_size: 128
     # num_workwers should be 2 * batch_size, and the total num less than 1024
     # e.g. if use 8 devices, no more than 128
@@ -100,8 +87,6 @@ lightning:
     precision: 16
     auto_select_gpus: False
     strategy:
-      target: strategies.DDPStrategy
-      params:
         find_unused_parameters: False
     log_every_n_steps: 2
     # max_steps: 6o
@@ -111,8 +96,6 @@ lightning:
   logger_config:
     wandb:
-      target: loggers.WandbLogger
-      params:
         name: nowname
         save_dir: "/data2/tmp/diff_log/"
         offline: opt.debug

View File

@@ -1,16 +1,13 @@
 import torch
-try:
-    import lightning.pytorch as pl
-except:
-    import pytorch_lightning as pl
-import torch.nn.functional as F
+import lightning.pytorch as pl
+from torch import nn
+from torch.nn import functional as F
+from torch.nn import Identity
 from contextlib import contextmanager

 from ldm.modules.diffusionmodules.model import Encoder, Decoder
 from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
-from ldm.util import instantiate_from_config
 from ldm.modules.ema import LitEma
@@ -32,7 +29,7 @@ class AutoencoderKL(pl.LightningModule):
         self.image_key = image_key
         self.encoder = Encoder(**ddconfig)
         self.decoder = Decoder(**ddconfig)
-        self.loss = instantiate_from_config(lossconfig)
+        self.loss = Identity()
         assert ddconfig["double_z"]
         self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
         self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)

View File

@@ -9,9 +9,10 @@ from copy import deepcopy
 from einops import rearrange
 from glob import glob
 from natsort import natsorted

+from ldm.models.diffusion.ddpm import LatentDiffusion
+from ldm.lr_scheduler import LambdaLinearScheduler
 from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
-from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
+from ldm.util import log_txt_as_img, default, ismap

 __models__ = {
     'class_label': EncoderUNetModel,
@@ -86,7 +87,7 @@ class NoisyLatentImageClassifier(pl.LightningModule):
             print(f"Unexpected Keys: {unexpected}")

     def load_diffusion(self):
-        model = instantiate_from_config(self.diffusion_config)
+        model = LatentDiffusion(**self.diffusion_config.get('params',dict()))
         self.diffusion_model = model.eval()
         self.diffusion_model.train = disabled_train
         for param in self.diffusion_model.parameters():
@@ -221,7 +222,7 @@ class NoisyLatentImageClassifier(pl.LightningModule):
         optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)

         if self.use_scheduler:
-            scheduler = instantiate_from_config(self.scheduler_config)
+            scheduler = LambdaLinearScheduler(**self.scheduler_config.get('params',dict()))
             print("Setting up LambdaLR scheduler...")
             scheduler = [

View File

@@ -22,19 +22,22 @@ from contextlib import contextmanager, nullcontext
 from functools import partial
 from einops import rearrange, repeat

+from ldm.lr_scheduler import LambdaLinearScheduler
 from ldm.models.autoencoder import *
+from ldm.models.autoencoder import AutoencoderKL, IdentityFirstStage
 from ldm.models.diffusion.ddim import *
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.modules.midas.api import MiDaSInference
 from ldm.modules.diffusionmodules.model import *
+from ldm.modules.diffusionmodules.model import Decoder, Encoder, Model
 from ldm.modules.diffusionmodules.openaimodel import *
-from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d
+from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d, UNetModel
 from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule, noise_like
 from ldm.modules.distributions.distributions import DiagonalGaussianDistribution, normal_kl
+from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
 from ldm.modules.ema import LitEma
 from ldm.modules.encoders.modules import *
-from ldm.util import count_params, default, exists, instantiate_from_config, isimage, ismap, log_txt_as_img, mean_flat
+from ldm.util import count_params, default, exists, isimage, ismap, log_txt_as_img, mean_flat
 from omegaconf import ListConfig
 from torch.optim.lr_scheduler import LambdaLR
 from torchvision.utils import make_grid
@@ -690,7 +693,7 @@ class LatentDiffusion(DDPM):
             self.make_cond_schedule()

     def instantiate_first_stage(self, config):
-        model = instantiate_from_config(config)
+        model = AutoencoderKL(**config)
         self.first_stage_model = model.eval()
         self.first_stage_model.train = disabled_train
         for param in self.first_stage_model.parameters():
@@ -706,15 +709,13 @@ class LatentDiffusion(DDPM):
                 self.cond_stage_model = None
                 # self.be_unconditional = True
             else:
-                model = instantiate_from_config(config)
+                model = FrozenOpenCLIPEmbedder(**config)
                 self.cond_stage_model = model.eval()
                 self.cond_stage_model.train = disabled_train
                 for param in self.cond_stage_model.parameters():
                     param.requires_grad = False
         else:
-            assert config != '__is_first_stage__'
-            assert config != '__is_unconditional__'
-            model = instantiate_from_config(config)
+            model = FrozenOpenCLIPEmbedder(**config)
             self.cond_stage_model = model

     def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
@@ -1479,8 +1480,7 @@ class LatentDiffusion(DDPM):
         # opt = torch.optim.AdamW(params, lr=lr)

         if self.use_scheduler:
-            assert 'target' in self.scheduler_config
-            scheduler = instantiate_from_config(self.scheduler_config)
+            scheduler = LambdaLinearScheduler(**self.scheduler_config)

             rank_zero_info("Setting up LambdaLR scheduler...")
             scheduler = [{'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1}]
@@ -1502,7 +1502,7 @@ class DiffusionWrapper(pl.LightningModule):
     def __init__(self, diff_model_config, conditioning_key):
         super().__init__()
         self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
-        self.diffusion_model = instantiate_from_config(diff_model_config)
+        self.diffusion_model = UNetModel(**diff_model_config)
         self.conditioning_key = conditioning_key
         assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
@@ -1551,7 +1551,7 @@ class LatentUpscaleDiffusion(LatentDiffusion):
         self.noise_level_key = noise_level_key

     def instantiate_low_stage(self, config):
-        model = instantiate_from_config(config)
+        model = ImageConcatWithNoiseAugmentation(**config)
         self.low_scale_model = model.eval()
         self.low_scale_model.train = disabled_train
         for param in self.low_scale_model.parameters():
@@ -1933,7 +1933,7 @@ class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
     def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
         super().__init__(concat_keys=concat_keys, *args, **kwargs)
-        self.depth_model = instantiate_from_config(depth_stage_config)
+        self.depth_model = MiDaSInference(**depth_stage_config)
         self.depth_stage_key = concat_keys[0]

     @torch.no_grad()
@@ -2006,7 +2006,7 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
         self.low_scale_key = low_scale_key

     def instantiate_low_stage(self, config):
-        model = instantiate_from_config(config)
+        model = ImageConcatWithNoiseAugmentation(**config)
         self.low_scale_model = model.eval()
         self.low_scale_model.train = disabled_train
         for param in self.low_scale_model.parameters():

View File

@@ -10,11 +10,8 @@ import time

 import numpy as np
 import torch
 import torchvision
-try:
-    import lightning.pytorch as pl
-except:
-    import pytorch_lightning as pl
+import lightning.pytorch as pl

 from functools import partial
@@ -23,19 +20,15 @@ from packaging import version
 from PIL import Image
 from prefetch_generator import BackgroundGenerator
 from torch.utils.data import DataLoader, Dataset, Subset, random_split
+from ldm.models.diffusion.ddpm import LatentDiffusion

-try:
-    from lightning.pytorch import seed_everything
-    from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
-    from lightning.pytorch.trainer import Trainer
-    from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
-    LIGHTNING_PACK_NAME = "lightning.pytorch."
-except:
-    from pytorch_lightning import seed_everything
-    from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
-    from pytorch_lightning.trainer import Trainer
-    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
-    LIGHTNING_PACK_NAME = "pytorch_lightning."
+from lightning.pytorch import seed_everything
+from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
+from lightning.pytorch.trainer import Trainer
+from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
+from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger
+from lightning.pytorch.strategies import ColossalAIStrategy,DDPStrategy
+LIGHTNING_PACK_NAME = "lightning.pytorch."

 from ldm.data.base import Txt2ImgIterableBaseDataset
 from ldm.util import instantiate_from_config
@@ -687,86 +680,72 @@ if __name__ == "__main__":
             config.model["params"].update({"ckpt": ckpt})
             rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"]))

-    model = instantiate_from_config(config.model)
+    model = LatentDiffusion(**config.model.get("params", dict()))

     # trainer and callbacks
     trainer_kwargs = dict()

     # config the logger
     # Default logger configs to log training metrics during the training process.
-    # These loggers are specified as targets in the dictionary, along with the configuration settings specific to each logger.
     default_logger_cfgs = {
         "wandb": {
-            "target": LIGHTNING_PACK_NAME + "loggers.WandbLogger",
-            "params": {
             "name": nowname,
             "save_dir": logdir,
             "offline": opt.debug,
             "id": nowname,
         }
-        },
+        ,
         "tensorboard": {
-            "target": LIGHTNING_PACK_NAME + "loggers.TensorBoardLogger",
-            "params": {
             "save_dir": logdir,
             "name": "diff_tb",
             "log_graph": True
         }
     }
-    }

    # Set up the logger for TensorBoard
     default_logger_cfg = default_logger_cfgs["tensorboard"]
     if "logger" in lightning_config:
         logger_cfg = lightning_config.logger
+        trainer_kwargs["logger"] = WandbLogger(**logger_cfg)
     else:
         logger_cfg = default_logger_cfg
-    logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
-    trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
+        trainer_kwargs["logger"] = TensorBoardLogger(**logger_cfg)

     # config the strategy, defualt is ddp
     if "strategy" in trainer_config:
         strategy_cfg = trainer_config["strategy"]
-        strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"]
+        trainer_kwargs["strategy"] = ColossalAIStrategy(**strategy_cfg)
     else:
-        strategy_cfg = {
-            "target": LIGHTNING_PACK_NAME + "strategies.DDPStrategy",
-            "params": {
-                "find_unused_parameters": False
-            }
-        }
-    trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)
+        strategy_cfg = {"find_unused_parameters": False}
+        trainer_kwargs["strategy"] = DDPStrategy(**strategy_cfg)

     # Set up ModelCheckpoint callback to save best models
     # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
     # specify which metric is used to determine best models
     default_modelckpt_cfg = {
-        "target": LIGHTNING_PACK_NAME + "callbacks.ModelCheckpoint",
-        "params": {
         "dirpath": ckptdir,
         "filename": "{epoch:06}",
         "verbose": True,
         "save_last": True,
     }
-    }
     if hasattr(model, "monitor"):
-        default_modelckpt_cfg["params"]["monitor"] = model.monitor
-        default_modelckpt_cfg["params"]["save_top_k"] = 3
+        default_modelckpt_cfg["monitor"] = model.monitor
+        default_modelckpt_cfg["save_top_k"] = 3

     if "modelcheckpoint" in lightning_config:
-        modelckpt_cfg = lightning_config.modelcheckpoint
+        modelckpt_cfg = lightning_config.modelcheckpoint["params"]
     else:
         modelckpt_cfg = OmegaConf.create()
     modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
     if version.parse(pl.__version__) < version.parse('1.4.0'):
-        trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
+        trainer_kwargs["checkpoint_callback"] = ModelCheckpoint(**modelckpt_cfg)

-    # Set up various callbacks, including logging, learning rate monitoring, and CUDA management
-    # add callback which sets up log directory
-    default_callbacks_cfg = {
-        "setup_callback": { # callback to set up the training
-            "target": "main.SetupCallback",
-            "params": {
+    #Create an empty OmegaConf configuration object
+    callbacks_cfg = OmegaConf.create()
+
+    #Instantiate items according to the configs
+    trainer_kwargs.setdefault("callbacks", [])
+
+    setup_callback_config = {
         "resume": opt.resume, # resume training if applicable
         "now": now,
         "logdir": logdir, # directory to save the log file
@@ -775,43 +754,23 @@ if __name__ == "__main__":
         "config": config, # configuration dictionary
         "lightning_config": lightning_config, # LightningModule configuration
     }
-        },
-        "image_logger": { # callback to log image data
-            "target": "main.ImageLogger",
-            "params": {
+    trainer_kwargs["callbacks"].append(SetupCallback(**setup_callback_config))
+
+    image_logger_config = {
         "batch_frequency": 750, # how frequently to log images
         "max_images": 4, # maximum number of images to log
         "clamp": True # whether to clamp pixel values to [0,1]
     }
-        },
-        "learning_rate_logger": { # callback to log learning rate
-            "target": "main.LearningRateMonitor",
-            "params": {
+    trainer_kwargs["callbacks"].append(ImageLogger(**image_logger_config))
+
+    learning_rate_logger_config = {
         "logging_interval": "step", # logging frequency (either 'step' or 'epoch')
         # "log_momentum": True # whether to log momentum (currently commented out)
     }
-        },
-        "cuda_callback": { # callback to handle CUDA-related operations
-            "target": "main.CUDACallback"
-        },
-    }
+    trainer_kwargs["callbacks"].append(LearningRateMonitor(**learning_rate_logger_config))

-    # If the LightningModule configuration has specified callbacks, use those
-    # Otherwise, create an empty OmegaConf configuration object
-    if "callbacks" in lightning_config:
-        callbacks_cfg = lightning_config.callbacks
-    else:
-        callbacks_cfg = OmegaConf.create()
-
-    # If the 'metrics_over_trainsteps_checkpoint' callback is specified in the
-    # LightningModule configuration, update the default callbacks configuration
-    if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
-        print(
-            'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
-        default_metrics_over_trainsteps_ckpt_dict = {
-            'metrics_over_trainsteps_checkpoint': {
-                "target": LIGHTNING_PACK_NAME + 'callbacks.ModelCheckpoint',
-                'params': {
+    metrics_over_trainsteps_checkpoint_config= {
         "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
         "filename": "{epoch:06}-{step:09}",
         "verbose": True,
@@ -819,21 +778,16 @@ if __name__ == "__main__":
         'every_n_train_steps': 10000,
         'save_weights_only': True
     }
-            }
-        }
-        default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
-
-    # Merge the default callbacks configuration with the specified callbacks configuration, and instantiate the callbacks
-    callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
-    trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
+    trainer_kwargs["callbacks"].append(ModelCheckpoint(**metrics_over_trainsteps_checkpoint_config))
+    trainer_kwargs["callbacks"].append(CUDACallback())

     # Create a Trainer object with the specified command-line arguments and keyword arguments, and set the log directory
     trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
     trainer.logdir = logdir

     # Create a data module based on the configuration file
-    data = instantiate_from_config(config.data)
+    data = DataModuleFromConfig(**config.data)

     # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
     # calling these ourselves should not be necessary but it is.
     # lightning still takes care of proper multiprocessing though
@@ -846,7 +800,7 @@ if __name__ == "__main__":
     # Configure learning rate based on the batch size, base learning rate and number of GPUs
     # If scale_lr is true, calculate the learning rate based on additional factors
-    bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
+    bs, base_lr = config.data.batch_size, config.model.base_learning_rate
     if not cpu:
         ngpu = trainer_config["devices"]
     else:

View File

@@ -7,8 +7,9 @@ from datetime import datetime
 from diffusers import StableDiffusionPipeline
 import torch

-from ldm.util import instantiate_from_config
 from main import get_parser
+from ldm.modules.diffusionmodules.openaimodel import UNetModel

 if __name__ == "__main__":
     with torch.no_grad():
@@ -17,7 +18,7 @@ if __name__ == "__main__":
             config = f.read()
         base_config = yaml.load(config, Loader=yaml.FullLoader)
         unet_config = base_config['model']['params']['unet_config']
-        diffusion_model = instantiate_from_config(unet_config).to("cuda:0")
+        diffusion_model = UNetModel(**unet_config).to("cuda:0")

     pipe = StableDiffusionPipeline.from_pretrained(
         "/data/scratch/diffuser/stable-diffusion-v1-4"

View File

@@ -3,3 +3,4 @@ TRANSFORMERS_OFFLINE=1
 DIFFUSERS_OFFLINE=1
+python main.py --logdir /tmp --train --base configs/Teyvat/train_colossalai_teyvat.yaml --ckpt diffuser_root_dir/512-base-ema.ckpt