Polish Code

Author: natalie_cao
Date: 2023-04-11 14:10:45 +08:00
Committed by: アマデウス
parent 152239bbfa
commit de84c0311a
15 changed files with 562 additions and 719 deletions


@@ -10,11 +10,8 @@ import time
import numpy as np
import torch
import torchvision
import lightning.pytorch as pl
try:
import lightning.pytorch as pl
except:
import pytorch_lightning as pl
from functools import partial
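
The first hunk replaces the try/except fallback between the unified lightning.pytorch namespace and the legacy pytorch_lightning package with a single hard import. For comparison, a minimal sketch of the two import styles (only the package names come from the diff; the ImportError handling is the idiomatic form of the removed bare except):

# Before: tolerate either package name at import time.
try:
    import lightning.pytorch as pl
except ImportError:
    import pytorch_lightning as pl

# After: require the unified package; a missing install now fails immediately.
import lightning.pytorch as pl
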
@@ -23,19 +20,15 @@ from packaging import version
from PIL import Image
from prefetch_generator import BackgroundGenerator
from torch.utils.data import DataLoader, Dataset, Subset, random_split
from ldm.models.diffusion.ddpm import LatentDiffusion
try:
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
LIGHTNING_PACK_NAME = "lightning.pytorch."
except:
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
LIGHTNING_PACK_NAME = "pytorch_lightning."
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger
from lightning.pytorch.strategies import ColossalAIStrategy, DDPStrategy
LIGHTNING_PACK_NAME = "lightning.pytorch."
from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config
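
Most of the changes below swap instantiate_from_config calls for direct constructor calls (LatentDiffusion, TensorBoardLogger, DDPStrategy, ModelCheckpoint, and so on). For context, a sketch of how a config-driven factory of this kind typically works in the ldm codebase; the exact upstream implementation may differ in details:

import importlib

def instantiate_from_config(config):
    # Resolve the dotted "target" string to a class, then call it with "params" as kwargs.
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**config.get("params", dict()))
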
@@ -687,153 +680,114 @@ if __name__ == "__main__":
config.model["params"].update({"ckpt": ckpt})
rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"]))
model = instantiate_from_config(config.model)
model = LatentDiffusion(**config.model.get("params", dict()))
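
The model is now built by calling LatentDiffusion directly with the params sub-dictionary. A hypothetical shape of the model section this line consumes (keys taken from this hunk; values purely illustrative):

example_model_cfg = {
    "base_learning_rate": 1.0e-4,  # read further down for learning-rate scaling
    "target": "ldm.models.diffusion.ddpm.LatentDiffusion",
    "params": {"ckpt": "path/to/model.ckpt"},  # "ckpt" is injected above via config.model["params"].update
}
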
# trainer and callbacks
trainer_kwargs = dict()
# Configure the logger
# Default logger configs to log training metrics during the training process.
# These loggers are specified as targets in the dictionary, along with the configuration settings specific to each logger.
default_logger_cfgs = {
"wandb": {
"target": LIGHTNING_PACK_NAME + "loggers.WandbLogger",
"params": {
"name": nowname,
"save_dir": logdir,
"offline": opt.debug,
"id": nowname,
}
},
"tensorboard": {
"target": LIGHTNING_PACK_NAME + "loggers.TensorBoardLogger",
"params": {
"save_dir": logdir,
"name": "diff_tb",
"log_graph": True
}
}
}
# Set up the logger for TensorBoard
default_logger_cfg = default_logger_cfgs["tensorboard"]
if "logger" in lightning_config:
logger_cfg = lightning_config.logger
trainer_kwargs["logger"] = WandbLogger(**logger_cfg)
else:
logger_cfg = default_logger_cfg
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
trainer_kwargs["logger"] = TensorBoardLogger(**logger_cfg)
# Configure the strategy; the default is DDP
if "strategy" in trainer_config:
strategy_cfg = trainer_config["strategy"]
strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"]
trainer_kwargs["strategy"] = ColossalAIStrategy(**strategy_cfg)
else:
strategy_cfg = {
"target": LIGHTNING_PACK_NAME + "strategies.DDPStrategy",
"params": {
"find_unused_parameters": False
}
}
trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)
strategy_cfg = {"find_unused_parameters": False}
trainer_kwargs["strategy"] = DDPStrategy(**strategy_cfg)
# Set up ModelCheckpoint callback to save best models
# modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
# specify which metric is used to determine best models
default_modelckpt_cfg = {
"target": LIGHTNING_PACK_NAME + "callbacks.ModelCheckpoint",
"params": {
"dirpath": ckptdir,
"filename": "{epoch:06}",
"verbose": True,
"save_last": True,
}
}
if hasattr(model, "monitor"):
default_modelckpt_cfg["params"]["monitor"] = model.monitor
default_modelckpt_cfg["params"]["save_top_k"] = 3
default_modelckpt_cfg["monitor"] = model.monitor
default_modelckpt_cfg["save_top_k"] = 3
if "modelcheckpoint" in lightning_config:
modelckpt_cfg = lightning_config.modelcheckpoint
modelckpt_cfg = lightning_config.modelcheckpoint["params"]
else:
modelckpt_cfg = OmegaConf.create()
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
if version.parse(pl.__version__) < version.parse('1.4.0'):
trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
trainer_kwargs["checkpoint_callback"] = ModelCheckpoint(**modelckpt_cfg)
# Set up various callbacks, including logging, learning rate monitoring, and CUDA management
# add callback which sets up log directory
default_callbacks_cfg = {
"setup_callback": { # callback to set up the training
"target": "main.SetupCallback",
"params": {
"resume": opt.resume, # resume training if applicable
"now": now,
"logdir": logdir, # directory to save the log file
"ckptdir": ckptdir, # directory to save the checkpoint file
"cfgdir": cfgdir, # directory to save the configuration file
"config": config, # configuration dictionary
"lightning_config": lightning_config, # LightningModule configuration
}
},
"image_logger": { # callback to log image data
"target": "main.ImageLogger",
"params": {
"batch_frequency": 750, # how frequently to log images
"max_images": 4, # maximum number of images to log
"clamp": True # whether to clamp pixel values to [0,1]
}
},
"learning_rate_logger": { # callback to log learning rate
"target": "main.LearningRateMonitor",
"params": {
"logging_interval": "step", # logging frequency (either 'step' or 'epoch')
# "log_momentum": True # whether to log momentum (currently commented out)
}
},
"cuda_callback": { # callback to handle CUDA-related operations
"target": "main.CUDACallback"
},
}
# Create an empty OmegaConf configuration object
# If the LightningModule configuration has specified callbacks, use those
# Otherwise, create an empty OmegaConf configuration object
if "callbacks" in lightning_config:
callbacks_cfg = lightning_config.callbacks
else:
callbacks_cfg = OmegaConf.create()
# If the 'metrics_over_trainsteps_checkpoint' callback is specified in the
# LightningModule configuration, update the default callbacks configuration
if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
print(
'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
default_metrics_over_trainsteps_ckpt_dict = {
'metrics_over_trainsteps_checkpoint': {
"target": LIGHTNING_PACK_NAME + 'callbacks.ModelCheckpoint',
'params': {
"dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
"filename": "{epoch:06}-{step:09}",
"verbose": True,
'save_top_k': -1,
'every_n_train_steps': 10000,
'save_weights_only': True
}
}
callbacks_cfg = OmegaConf.create()
# Instantiate items according to the configs
trainer_kwargs.setdefault("callbacks", [])
setup_callback_config = {
"resume": opt.resume, # resume training if applicable
"now": now,
"logdir": logdir, # directory to save the log file
"ckptdir": ckptdir, # directory to save the checkpoint file
"cfgdir": cfgdir, # directory to save the configuration file
"config": config, # configuration dictionary
"lightning_config": lightning_config, # LightningModule configuration
}
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
trainer_kwargs["callbacks"].append(SetupCallback(**setup_callback_config))
# Merge the default callbacks configuration with the specified callbacks configuration, and instantiate the callbacks
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
image_logger_config = {
"batch_frequency": 750, # how frequently to log images
"max_images": 4, # maximum number of images to log
"clamp": True # whether to clamp pixel values to [0,1]
}
trainer_kwargs["callbacks"].append(ImageLogger(**image_logger_config))
learning_rate_logger_config = {
"logging_interval": "step", # logging frequency (either 'step' or 'epoch')
# "log_momentum": True # whether to log momentum (currently commented out)
}
trainer_kwargs["callbacks"].append(LearningRateMonitor(**learning_rate_logger_config))
metrics_over_trainsteps_checkpoint_config = {
"dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
"filename": "{epoch:06}-{step:09}",
"verbose": True,
'save_top_k': -1,
'every_n_train_steps': 10000,
'save_weights_only': True
}
trainer_kwargs["callbacks"].append(ModelCheckpoint(**metrics_over_trainsteps_checkpoint_config))
trainer_kwargs["callbacks"].append(CUDACallback())
# Create a Trainer object with the specified command-line arguments and keyword arguments, and set the log directory
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
trainer.logdir = logdir
# Create a data module based on the configuration file
data = instantiate_from_config(config.data)
data = DataModuleFromConfig(**config.data)
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
# calling these ourselves should not be necessary but it is.
# lightning still takes care of proper multiprocessing though
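
In the upstream latent-diffusion script, the note above refers to invoking the data module's lifecycle hooks by hand right after construction; a sketch of what that typically looks like (the surrounding committed code may differ):

# Manually trigger the hooks Lightning would otherwise call on our behalf.
data.prepare_data()   # one-off download / preprocessing
data.setup()          # build the train/validation/test datasets
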
@@ -846,7 +800,7 @@ if __name__ == "__main__":
# Configure learning rate based on the batch size, base learning rate and number of GPUs
# If scale_lr is true, calculate the learning rate based on additional factors
bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
bs, base_lr = config.data.batch_size, config.model.base_learning_rate
if not cpu:
ngpu = trainer_config["devices"]
else:
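
The scaling described in the comment above is, in the upstream script, a linear rule over the global batch size. A sketch, assuming an opt.scale_lr flag and an accumulate_grad_batches value taken from the trainer settings (neither appears in this excerpt):

# Linear LR scaling (sketch): effective learning rate grows with the global batch size.
if opt.scale_lr:
    model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
else:
    model.learning_rate = base_lr
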