support stable diffusion v2

Fazzie
2022-12-12 17:35:23 +08:00
parent e99edfcb51
commit cea4292ae5
51 changed files with 5597 additions and 3302 deletions


@@ -1,54 +1,48 @@
import argparse, os, sys, datetime, glob, importlib, csv
import numpy as np
import argparse
import csv
import datetime
import glob
import importlib
import os
import sys
import time
import numpy as np
import torch
import torchvision
import lightning.pytorch as pl
from packaging import version
from omegaconf import OmegaConf
from torch.utils.data import random_split, DataLoader, Dataset, Subset
try:
import lightning.pytorch as pl
except:
import pytorch_lightning as pl
from functools import partial
from omegaconf import OmegaConf
from packaging import version
from PIL import Image
# from lightning.pytorch.strategies.colossalai import ColossalAIStrategy
# from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from prefetch_generator import BackgroundGenerator
from torch.utils.data import DataLoader, Dataset, Subset, random_split
from lightning.pytorch import seed_everything
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
from lightning.pytorch.utilities.rank_zero import rank_zero_only
from lightning.pytorch.utilities import rank_zero_info
from diffusers.models.unet_2d import UNet2DModel
from clip.model import Bottleneck
from transformers.models.clip.modeling_clip import CLIPTextTransformer
try:
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
LIGHTNING_PACK_NAME = "lightning.pytorch."
except:
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
LIGHTNING_PACK_NAME = "pytorch_lightning."
from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config
import clip
from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel
import kornia
from ldm.modules.x_transformer import *
from ldm.modules.encoders.modules import *
from taming.modules.diffusionmodules.model import ResnetBlock
from taming.modules.transformer.mingpt import *
from taming.modules.transformer.permuter import *
# from ldm.modules.attention import enable_flash_attentions
from ldm.modules.ema import LitEma
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
from ldm.models.autoencoder import AutoencoderKL
from ldm.models.autoencoder import *
from ldm.models.diffusion.ddim import *
from ldm.modules.diffusionmodules.openaimodel import *
from ldm.modules.diffusionmodules.model import *
from ldm.modules.diffusionmodules.model import Decoder, Encoder, Up_module, Down_module, Mid_module, temb_module
from ldm.modules.attention import enable_flash_attention
class DataLoaderX(DataLoader):
def __iter__(self):
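
The hunk cuts off the body of __iter__. Given the BackgroundGenerator import above, the class almost certainly follows the standard prefetch_generator idiom; a minimal sketch under that assumption:

class DataLoaderX(DataLoader):
    # Prefetch the next batch on a background thread so host-side data
    # loading overlaps with device compute.
    def __iter__(self):
        return BackgroundGenerator(super().__iter__())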
@@ -56,6 +50,7 @@ class DataLoaderX(DataLoader):
def get_parser(**parser_kwargs):
def str2bool(v):
if isinstance(v, bool):
return v
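
The rest of str2bool is elided by the hunk; the conventional argparse helper it presumably matches is:

def str2bool(v):
    # Accept booleans directly and common true/false spellings from the CLI.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")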
@@ -91,7 +86,7 @@ def get_parser(**parser_kwargs):
nargs="*",
metavar="base_config.yaml",
help="paths to base configs. Loaded from left-to-right. "
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
default=list(),
)
parser.add_argument(
@@ -111,11 +106,7 @@ def get_parser(**parser_kwargs):
nargs="?",
help="disable test",
)
parser.add_argument(
"-p",
"--project",
help="name of new or path to existing project"
)
parser.add_argument("-p", "--project", help="name of new or path to existing project")
parser.add_argument(
"-d",
"--debug",
@@ -210,8 +201,17 @@ def worker_init_fn(_):
class DataModuleFromConfig(pl.LightningDataModule):
def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
def __init__(self,
batch_size,
train=None,
validation=None,
test=None,
predict=None,
wrap=False,
num_workers=None,
shuffle_test_loader=False,
use_worker_init_fn=False,
shuffle_val_dataloader=False):
super().__init__()
self.batch_size = batch_size
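
For orientation, these keyword arguments are normally filled in from the data section of the YAML config and built through instantiate_from_config; a sketch, with a hypothetical dataset target:

from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

cfg = OmegaConf.create({
    "data": {
        "target": "main.DataModuleFromConfig",
        "params": {
            "batch_size": 4,
            "num_workers": 4,
            # hypothetical dataset target, for illustration only
            "train": {"target": "my_project.data.MyTrainDataset", "params": {}},
        },
    }
})
data = instantiate_from_config(cfg.data)    # returns a DataModuleFromConfig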
@@ -237,9 +237,7 @@ class DataModuleFromConfig(pl.LightningDataModule):
instantiate_from_config(data_cfg)
def setup(self, stage=None):
self.datasets = dict(
(k, instantiate_from_config(self.dataset_configs[k]))
for k in self.dataset_configs)
self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
if self.wrap:
for k in self.datasets:
self.datasets[k] = WrappedDataset(self.datasets[k])
@@ -250,9 +248,11 @@ class DataModuleFromConfig(pl.LightningDataModule):
init_fn = worker_init_fn
else:
init_fn = None
return DataLoaderX(self.datasets["train"], batch_size=self.batch_size,
num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True,
worker_init_fn=init_fn)
return DataLoaderX(self.datasets["train"],
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False if is_iterable_dataset else True,
worker_init_fn=init_fn)
def _val_dataloader(self, shuffle=False):
if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
@@ -260,10 +260,10 @@ class DataModuleFromConfig(pl.LightningDataModule):
else:
init_fn = None
return DataLoaderX(self.datasets["validation"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
shuffle=shuffle)
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
shuffle=shuffle)
def _test_dataloader(self, shuffle=False):
is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
@@ -275,19 +275,25 @@ class DataModuleFromConfig(pl.LightningDataModule):
# do not shuffle dataloader for iterable dataset
shuffle = shuffle and (not is_iterable_dataset)
return DataLoaderX(self.datasets["test"], batch_size=self.batch_size,
num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle)
return DataLoaderX(self.datasets["test"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
shuffle=shuffle)
def _predict_dataloader(self, shuffle=False):
if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
return DataLoaderX(self.datasets["predict"], batch_size=self.batch_size,
num_workers=self.num_workers, worker_init_fn=init_fn)
return DataLoaderX(self.datasets["predict"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn)
class SetupCallback(Callback):
def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
super().__init__()
self.resume = resume
@@ -317,8 +323,7 @@ class SetupCallback(Callback):
os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
print("Project config")
print(OmegaConf.to_yaml(self.config))
OmegaConf.save(self.config,
os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
OmegaConf.save(self.config, os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
print("Lightning config")
print(OmegaConf.to_yaml(self.lightning_config))
@@ -338,8 +343,16 @@ class SetupCallback(Callback):
class ImageLogger(Callback):
def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True,
rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
def __init__(self,
batch_frequency,
max_images,
clamp=True,
increase_log_steps=True,
rescale=True,
disabled=False,
log_on_batch_idx=False,
log_first_step=False,
log_images_kwargs=None):
super().__init__()
self.rescale = rescale
@@ -348,7 +361,7 @@ class ImageLogger(Callback):
self.logger_log_images = {
pl.loggers.CSVLogger: self._testtube,
}
self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
if not increase_log_steps:
self.log_steps = [self.batch_freq]
self.clamp = clamp
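
To make the log_steps schedule concrete: with an assumed batch_frequency of 1000, images are logged at powers of two early in training (consumed later by check_frequency) and every 1000 steps thereafter:

import numpy as np

batch_freq = 1000    # assumed value, for illustration
log_steps = [2**n for n in range(int(np.log2(batch_freq)) + 1)]
print(log_steps)     # [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]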
@@ -361,39 +374,30 @@ class ImageLogger(Callback):
def _testtube(self, pl_module, images, batch_idx, split):
for k in images:
grid = torchvision.utils.make_grid(images[k])
grid = (grid + 1.0) / 2.0    # -1,1 -> 0,1; c,h,w
tag = f"{split}/{k}"
pl_module.logger.experiment.add_image(
tag, grid,
global_step=pl_module.global_step)
pl_module.logger.experiment.add_image(tag, grid, global_step=pl_module.global_step)
@rank_zero_only
def log_local(self, save_dir, split, images,
global_step, current_epoch, batch_idx):
def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
root = os.path.join(save_dir, "images", split)
for k in images:
grid = torchvision.utils.make_grid(images[k], nrow=4)
if self.rescale:
grid = (grid + 1.0) / 2.0    # -1,1 -> 0,1; c,h,w
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
grid = grid.numpy()
grid = (grid * 255).astype(np.uint8)
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
k,
global_step,
current_epoch,
batch_idx)
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
path = os.path.join(root, filename)
os.makedirs(os.path.split(path)[0], exist_ok=True)
Image.fromarray(grid).save(path)
def log_img(self, pl_module, batch, batch_idx, split="train"):
check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
hasattr(pl_module, "log_images") and
callable(pl_module.log_images) and
self.max_images > 0):
if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
hasattr(pl_module, "log_images") and callable(pl_module.log_images) and self.max_images > 0):
logger = type(pl_module.logger)
is_train = pl_module.training
@@ -411,8 +415,8 @@ class ImageLogger(Callback):
if self.clamp:
images[k] = torch.clamp(images[k], -1., 1.)
self.log_local(pl_module.logger.save_dir, split, images,
pl_module.global_step, pl_module.current_epoch, batch_idx)
self.log_local(pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch,
batch_idx)
logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
logger_log_images(pl_module, images, pl_module.global_step, split)
@@ -421,8 +425,8 @@ class ImageLogger(Callback):
pl_module.train()
def check_frequency(self, check_idx):
if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
check_idx > 0 or self.log_first_step):
if ((check_idx % self.batch_freq) == 0 or
(check_idx in self.log_steps)) and (check_idx > 0 or self.log_first_step):
try:
self.log_steps.pop(0)
except IndexError as e:
@@ -461,7 +465,7 @@ class CUDACallback(Callback):
def on_train_epoch_end(self, trainer, pl_module):
torch.cuda.synchronize(trainer.strategy.root_device.index)
max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2 ** 20
max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2**20
epoch_time = time.time() - self.start_time
try:
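
The try: here is cut off by the hunk. The remainder of this hook conventionally reduces both figures across ranks before logging; the calls below are an assumption based on that convention (the guard matches strategies that lack reduce):

try:
    max_memory = trainer.strategy.reduce(max_memory)
    epoch_time = trainer.strategy.reduce(epoch_time)
    rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
    rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
except AttributeError:
    pass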
@@ -528,13 +532,9 @@ if __name__ == "__main__":
opt, unknown = parser.parse_known_args()
if opt.name and opt.resume:
raise ValueError(
"-n/--name and -r/--resume cannot be specified both."
"If you want to resume training in a new log folder, "
"use -n/--name in combination with --resume_from_checkpoint"
)
if opt.flash:
enable_flash_attention()
raise ValueError("-n/--name and -r/--resume cannot be specified both."
"If you want to resume training in a new log folder, "
"use -n/--name in combination with --resume_from_checkpoint")
if opt.resume:
if not os.path.exists(opt.resume):
raise ValueError("Cannot find {}".format(opt.resume))
@@ -578,7 +578,7 @@ if __name__ == "__main__":
lightning_config = config.pop("lightning", OmegaConf.create())
# merge trainer cli with config
trainer_config = lightning_config.get("trainer", OmegaConf.create())
for k in nondefault_trainer_args(opt):
trainer_config[k] = getattr(opt, k)
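
nondefault_trainer_args is defined elsewhere in this file; the usual stable-diffusion helper, which this loop relies on, diffs the parsed options against a fresh parser's defaults. A sketch under that assumption:

def nondefault_trainer_args(opt):
    # Parse an empty command line to recover the Trainer defaults, then
    # report every flag the user actually changed.
    parser = argparse.ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args([])
    return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))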
@@ -601,7 +601,7 @@ if __name__ == "__main__":
else:
config.model["params"].update({"use_fp16": False})
print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))
model = instantiate_from_config(config.model)
# trainer and callbacks
trainer_kwargs = dict()
@@ -610,7 +610,7 @@ if __name__ == "__main__":
# default logger configs
default_logger_cfgs = {
"wandb": {
"target": "lightning.pytorch.loggers.WandbLogger",
"target": LIGHTNING_PACK_NAME + "loggers.WandbLogger",
"params": {
"name": nowname,
"save_dir": logdir,
@@ -618,9 +618,9 @@ if __name__ == "__main__":
"id": nowname,
}
},
"tensorboard":{
"target": "lightning.pytorch.loggers.TensorBoardLogger",
"params":{
"tensorboard": {
"target": LIGHTNING_PACK_NAME + "loggers.TensorBoardLogger",
"params": {
"save_dir": logdir,
"name": "diff_tb",
"log_graph": True
@@ -640,9 +640,10 @@ if __name__ == "__main__":
if "strategy" in trainer_config:
strategy_cfg = trainer_config["strategy"]
print("Using strategy: {}".format(strategy_cfg["target"]))
strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"]
else:
strategy_cfg = {
"target": "lightning.pytorch.strategies.DDPStrategy",
"target": LIGHTNING_PACK_NAME + "strategies.DDPStrategy",
"params": {
"find_unused_parameters": False
}
@@ -654,7 +655,7 @@ if __name__ == "__main__":
# modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
# specify which metric is used to determine best models
default_modelckpt_cfg = {
"target": "lightning.pytorch.callbacks.ModelCheckpoint",
"target": LIGHTNING_PACK_NAME + "callbacks.ModelCheckpoint",
"params": {
"dirpath": ckptdir,
"filename": "{epoch:06}",
@@ -670,7 +671,7 @@ if __name__ == "__main__":
if "modelcheckpoint" in lightning_config:
modelckpt_cfg = lightning_config.modelcheckpoint
else:
modelckpt_cfg = OmegaConf.create()
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
if version.parse(pl.__version__) < version.parse('1.4.0'):
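
OmegaConf.merge gives later arguments precedence, so settings from lightning_config.modelcheckpoint override default_modelckpt_cfg. A small self-contained example:

from omegaconf import OmegaConf

default = OmegaConf.create({"params": {"verbose": True, "save_last": True}})
user = OmegaConf.create({"params": {"verbose": False}})
merged = OmegaConf.merge(default, user)
assert merged.params.verbose is False    # user override wins
assert merged.params.save_last is True   # untouched defaults survive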
@@ -702,7 +703,7 @@ if __name__ == "__main__":
"target": "main.LearningRateMonitor",
"params": {
"logging_interval": "step",
# "log_momentum": True
# "log_momentum": True
}
},
"cuda_callback": {
@@ -721,17 +722,17 @@ if __name__ == "__main__":
print(
'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
default_metrics_over_trainsteps_ckpt_dict = {
'metrics_over_trainsteps_checkpoint':
{"target": 'lightning.pytorch.callbacks.ModelCheckpoint',
'params': {
"dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
"filename": "{epoch:06}-{step:09}",
"verbose": True,
'save_top_k': -1,
'every_n_train_steps': 10000,
'save_weights_only': True
}
}
'metrics_over_trainsteps_checkpoint': {
"target": LIGHTNING_PACK_NAME + 'callbacks.ModelCheckpoint',
'params': {
"dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
"filename": "{epoch:06}-{step:09}",
"verbose": True,
'save_top_k': -1,
'every_n_train_steps': 10000,
'save_weights_only': True
}
}
}
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
@@ -744,7 +745,7 @@ if __name__ == "__main__":
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
trainer.logdir = logdir    ###
# data
data = instantiate_from_config(config.data)
@@ -772,14 +773,13 @@ if __name__ == "__main__":
if opt.scale_lr:
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
print(
"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)"
.format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
else:
model.learning_rate = base_lr
print("++++ NOT USING LR SCALING ++++")
print(f"Setting learning rate to {model.learning_rate:.2e}")
# allow checkpointing via USR1
def melk(*args, **kwargs):
# run all checkpoint hooks
@@ -788,13 +788,11 @@ if __name__ == "__main__":
ckpt_path = os.path.join(ckptdir, "last.ckpt")
trainer.save_checkpoint(ckpt_path)
def divein(*args, **kwargs):
if trainer.global_rank == 0:
import pudb;
import pudb
pudb.set_trace()
import signal
signal.signal(signal.SIGUSR1, melk)
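
melk is bound to SIGUSR1 so a running job can be checkpointed externally; divein is conventionally bound to SIGUSR2 on the elided next line, though that binding is an assumption here:

import signal

signal.signal(signal.SIGUSR1, melk)      # checkpoint on `kill -USR1 <pid>`
signal.signal(signal.SIGUSR2, divein)    # assumed: debugger on `kill -USR2 <pid>`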
@@ -803,8 +801,6 @@ if __name__ == "__main__":
# run
if opt.train:
try:
for name, m in model.named_parameters():
print(name)
trainer.fit(model, data)
except Exception:
melk()
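
The except block is truncated at the end of the hunk; the conventional pattern re-raises after the rescue checkpoint so the failure still surfaces (assumed continuation):

try:
    trainer.fit(model, data)
except Exception:
    melk()    # rescue checkpoint before the failure propagates
    raise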