mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 05:49:55 +00:00
Polish Code
@@ -10,11 +10,8 @@ import time
 import numpy as np
 import torch
 import torchvision
+import lightning.pytorch as pl
 
-try:
-    import lightning.pytorch as pl
-except:
-    import pytorch_lightning as pl
-
 from functools import partial
 
@@ -23,19 +20,15 @@ from packaging import version
 from PIL import Image
 from prefetch_generator import BackgroundGenerator
 from torch.utils.data import DataLoader, Dataset, Subset, random_split
+from ldm.models.diffusion.ddpm import LatentDiffusion
 
-try:
-    from lightning.pytorch import seed_everything
-    from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
-    from lightning.pytorch.trainer import Trainer
-    from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
-    LIGHTNING_PACK_NAME = "lightning.pytorch."
-except:
-    from pytorch_lightning import seed_everything
-    from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
-    from pytorch_lightning.trainer import Trainer
-    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
-    LIGHTNING_PACK_NAME = "pytorch_lightning."
+from lightning.pytorch import seed_everything
+from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
+from lightning.pytorch.trainer import Trainer
+from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
+from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger
+from lightning.pytorch.strategies import ColossalAIStrategy, DDPStrategy
+LIGHTNING_PACK_NAME = "lightning.pytorch."
 
 from ldm.data.base import Txt2ImgIterableBaseDataset
 from ldm.util import instantiate_from_config
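Note on the import hunks: the removed try/except kept the script importable under either package name, while after this commit the example hard-requires the unified lightning.pytorch namespace. For anyone who still needs to support both, a minimal version-agnostic shim could look like the sketch below (a hypothetical helper, not part of this repo):

    import importlib

    def _resolve_lightning():
        # Prefer the unified namespace, fall back to the legacy package name.
        for name in ("lightning.pytorch", "pytorch_lightning"):
            try:
                return importlib.import_module(name), name + "."
            except ImportError:
                continue
        raise ImportError("install `lightning` or `pytorch_lightning`")

    pl, LIGHTNING_PACK_NAME = _resolve_lightning()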
@@ -687,153 +680,114 @@ if __name__ == "__main__":
             config.model["params"].update({"ckpt": ckpt})
             rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"]))
 
-        model = instantiate_from_config(config.model)
+        model = LatentDiffusion(**config.model.get("params", dict()))
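Both sides of this change build the same object; the difference is dispatch. instantiate_from_config (still imported from ldm.util above) resolves a dotted "target" string and calls it with "params", roughly like this sketch:

    import importlib

    def instantiate_from_config(config):
        # Resolve the "pkg.mod.Class" named by config["target"], then call it
        # with the keyword arguments stored under config["params"].
        module_name, cls_name = config["target"].rsplit(".", 1)
        cls = getattr(importlib.import_module(module_name), cls_name)
        return cls(**config.get("params", dict()))

The direct LatentDiffusion(**params) call trades that config-driven flexibility for an explicit, greppable constructor.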
         # trainer and callbacks
         trainer_kwargs = dict()
 
         # config the logger
         # Default logger configs to log training metrics during the training process.
         # These loggers are specified as targets in the dictionary, along with the configuration settings specific to each logger.
         default_logger_cfgs = {
             "wandb": {
-                "target": LIGHTNING_PACK_NAME + "loggers.WandbLogger",
-                "params": {
-                    "name": nowname,
-                    "save_dir": logdir,
-                    "offline": opt.debug,
-                    "id": nowname,
-                }
+                "name": nowname,
+                "save_dir": logdir,
+                "offline": opt.debug,
+                "id": nowname,
             },
             "tensorboard": {
-                "target": LIGHTNING_PACK_NAME + "loggers.TensorBoardLogger",
-                "params": {
-                    "save_dir": logdir,
-                    "name": "diff_tb",
-                    "log_graph": True
-                }
+                "save_dir": logdir,
+                "name": "diff_tb",
+                "log_graph": True
             }
         }
 
         # Set up the logger for TensorBoard
         default_logger_cfg = default_logger_cfgs["tensorboard"]
         if "logger" in lightning_config:
             logger_cfg = lightning_config.logger
+            trainer_kwargs["logger"] = WandbLogger(**logger_cfg)
         else:
             logger_cfg = default_logger_cfg
-        logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
-        trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
+            trainer_kwargs["logger"] = TensorBoardLogger(**logger_cfg)
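As a sanity check on the new else-branch, the flattened tensorboard dict maps one-to-one onto TensorBoardLogger's constructor. A standalone sketch, with "logs/" standing in for the run-specific logdir computed earlier in main.py:

    from lightning.pytorch.loggers import TensorBoardLogger

    # Equivalent of TensorBoardLogger(**default_logger_cfgs["tensorboard"]).
    logger = TensorBoardLogger(save_dir="logs/", name="diff_tb", log_graph=True)

One consequence of the rewrite worth noting: a user-supplied lightning_config.logger is now always fed to WandbLogger, so its keys must be valid WandbLogger arguments, whereas the old target-based path could name any logger class.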
 
         # config the strategy, default is ddp
         if "strategy" in trainer_config:
             strategy_cfg = trainer_config["strategy"]
-            strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"]
+            trainer_kwargs["strategy"] = ColossalAIStrategy(**strategy_cfg)
         else:
-            strategy_cfg = {
-                "target": LIGHTNING_PACK_NAME + "strategies.DDPStrategy",
-                "params": {
-                    "find_unused_parameters": False
-                }
-            }
-
-        trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)
+            strategy_cfg = {"find_unused_parameters": False}
+            trainer_kwargs["strategy"] = DDPStrategy(**strategy_cfg)
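A sketch of what the default branch now constructs, runnable on its own:

    from lightning.pytorch.strategies import DDPStrategy

    # Skipping the unused-parameter scan on each backward pass is faster when
    # the whole graph is always exercised, as in this training loop.
    strategy = DDPStrategy(find_unused_parameters=False)

Note that the if-branch now hard-codes ColossalAIStrategy, so a "strategy" block in trainer_config can only carry ColossalAI options; the old "target" key could select any strategy class.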
 
         # Set up ModelCheckpoint callback to save best models
         # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
         # specify which metric is used to determine best models
         default_modelckpt_cfg = {
-            "target": LIGHTNING_PACK_NAME + "callbacks.ModelCheckpoint",
-            "params": {
-                "dirpath": ckptdir,
-                "filename": "{epoch:06}",
-                "verbose": True,
-                "save_last": True,
-            }
+            "dirpath": ckptdir,
+            "filename": "{epoch:06}",
+            "verbose": True,
+            "save_last": True,
         }
         if hasattr(model, "monitor"):
-            default_modelckpt_cfg["params"]["monitor"] = model.monitor
-            default_modelckpt_cfg["params"]["save_top_k"] = 3
+            default_modelckpt_cfg["monitor"] = model.monitor
+            default_modelckpt_cfg["save_top_k"] = 3
 
         if "modelcheckpoint" in lightning_config:
-            modelckpt_cfg = lightning_config.modelcheckpoint
+            modelckpt_cfg = lightning_config.modelcheckpoint["params"]
         else:
             modelckpt_cfg = OmegaConf.create()
         modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
         if version.parse(pl.__version__) < version.parse('1.4.0'):
-            trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
+            trainer_kwargs["checkpoint_callback"] = ModelCheckpoint(**modelckpt_cfg)
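Flattened out, the merged defaults translate directly into keyword arguments. A self-contained sketch, where "checkpoints/" and "val/loss" are illustrative placeholders for ckptdir and model.monitor, not values from the diff:

    from lightning.pytorch.callbacks import ModelCheckpoint

    ckpt_cb = ModelCheckpoint(
        dirpath="checkpoints/",  # placeholder for ckptdir
        filename="{epoch:06}",   # zero-padded epoch number in the file name
        verbose=True,
        save_last=True,          # always refresh last.ckpt
        monitor="val/loss",      # placeholder for model.monitor
        save_top_k=3,            # keep the three best checkpoints by that metric
    )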
 
         # Set up various callbacks, including logging, learning rate monitoring, and CUDA management
         # add callback which sets up log directory
-        default_callbacks_cfg = {
-            "setup_callback": {  # callback to set up the training
-                "target": "main.SetupCallback",
-                "params": {
-                    "resume": opt.resume,  # resume training if applicable
-                    "now": now,
-                    "logdir": logdir,  # directory to save the log file
-                    "ckptdir": ckptdir,  # directory to save the checkpoint file
-                    "cfgdir": cfgdir,  # directory to save the configuration file
-                    "config": config,  # configuration dictionary
-                    "lightning_config": lightning_config,  # LightningModule configuration
-                }
-            },
-            "image_logger": {  # callback to log image data
-                "target": "main.ImageLogger",
-                "params": {
-                    "batch_frequency": 750,  # how frequently to log images
-                    "max_images": 4,  # maximum number of images to log
-                    "clamp": True  # whether to clamp pixel values to [0,1]
-                }
-            },
-            "learning_rate_logger": {  # callback to log learning rate
-                "target": "main.LearningRateMonitor",
-                "params": {
-                    "logging_interval": "step",  # logging frequency (either 'step' or 'epoch')
-                    # "log_momentum": True  # whether to log momentum (currently commented out)
-                }
-            },
-            "cuda_callback": {  # callback to handle CUDA-related operations
-                "target": "main.CUDACallback"
-            },
-        }
-
+        # Create an empty OmegaConf configuration object
         # If the LightningModule configuration has specified callbacks, use those
         # Otherwise, create an empty OmegaConf configuration object
         if "callbacks" in lightning_config:
             callbacks_cfg = lightning_config.callbacks
         else:
             callbacks_cfg = OmegaConf.create()
 
         # If the 'metrics_over_trainsteps_checkpoint' callback is specified in the
         # LightningModule configuration, update the default callbacks configuration
         if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
             print(
                 'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
-            default_metrics_over_trainsteps_ckpt_dict = {
-                'metrics_over_trainsteps_checkpoint': {
-                    "target": LIGHTNING_PACK_NAME + 'callbacks.ModelCheckpoint',
-                    'params': {
-                        "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
-                        "filename": "{epoch:06}-{step:09}",
-                        "verbose": True,
-                        'save_top_k': -1,
-                        'every_n_train_steps': 10000,
-                        'save_weights_only': True
-                    }
-                }
-            }
-            default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
-
-        # Merge the default callbacks configuration with the specified callbacks configuration, and instantiate the callbacks
-        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
-        trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
+            callbacks_cfg = OmegaConf.create()
+
+        # Instantiate items according to the configs
+        trainer_kwargs.setdefault("callbacks", [])
+        setup_callback_config = {
+            "resume": opt.resume,  # resume training if applicable
+            "now": now,
+            "logdir": logdir,  # directory to save the log file
+            "ckptdir": ckptdir,  # directory to save the checkpoint file
+            "cfgdir": cfgdir,  # directory to save the configuration file
+            "config": config,  # configuration dictionary
+            "lightning_config": lightning_config,  # LightningModule configuration
+        }
+        trainer_kwargs["callbacks"].append(SetupCallback(**setup_callback_config))
+
+        image_logger_config = {
+            "batch_frequency": 750,  # how frequently to log images
+            "max_images": 4,  # maximum number of images to log
+            "clamp": True  # whether to clamp pixel values to [0,1]
+        }
+        trainer_kwargs["callbacks"].append(ImageLogger(**image_logger_config))
+
+        learning_rate_logger_config = {
+            "logging_interval": "step",  # logging frequency (either 'step' or 'epoch')
+            # "log_momentum": True  # whether to log momentum (currently commented out)
+        }
+        trainer_kwargs["callbacks"].append(LearningRateMonitor(**learning_rate_logger_config))
+
+        metrics_over_trainsteps_checkpoint_config = {
+            "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
+            "filename": "{epoch:06}-{step:09}",
+            "verbose": True,
+            'save_top_k': -1,
+            'every_n_train_steps': 10000,
+            'save_weights_only': True
+        }
+        trainer_kwargs["callbacks"].append(ModelCheckpoint(**metrics_over_trainsteps_checkpoint_config))
+        trainer_kwargs["callbacks"].append(CUDACallback())
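For readers unfamiliar with the objects being appended here: each is a lightning.pytorch.callbacks.Callback subclass defined in this main.py. As a shape reference only, a minimal callback in the spirit of CUDACallback might look like this (hypothetical body, not the repo's implementation):

    import torch
    from lightning.pytorch.callbacks import Callback

    class PeakMemoryCallback(Callback):
        # Report peak CUDA memory after every training epoch.
        def on_train_epoch_end(self, trainer, pl_module):
            if torch.cuda.is_available():
                peak = torch.cuda.max_memory_allocated() / 2**20
                print(f"peak GPU memory this epoch: {peak:.0f} MiB")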
 
         # Create a Trainer object with the specified command-line arguments and keyword arguments, and set the log directory
         trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
         trainer.logdir = logdir
 
         # Create a data module based on the configuration file
-        data = instantiate_from_config(config.data)
+        data = DataModuleFromConfig(**config.data)
 
         # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
         # calling these ourselves should not be necessary but it is.
         # lightning still takes care of proper multiprocessing though
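The NOTE refers to explicit lifecycle calls the script makes on the data module itself rather than leaving them to the Trainer; the hunk cuts off before them, but in LightningDataModule terms they are roughly:

    # Assumed from the surrounding upstream script, not shown in this hunk.
    data.prepare_data()  # download / preprocess once per node
    data.setup()         # build train/val/test datasets on every process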
@@ -846,7 +800,7 @@ if __name__ == "__main__":
 
         # Configure learning rate based on the batch size, base learning rate and number of GPUs
        # If scale_lr is true, calculate the learning rate based on additional factors
-        bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
+        bs, base_lr = config.data.batch_size, config.model.base_learning_rate
         if not cpu:
             ngpu = trainer_config["devices"]
         else:
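For reference, the scaling the comment describes follows the usual linear rule in the stable-diffusion trainer (an assumption about the code below the hunk, which is truncated here): effective LR = accumulate_grad_batches x n_gpus x batch_size x base_lr. With purely illustrative numbers:

    # Hypothetical values, only to show the arithmetic of linear LR scaling.
    accumulate_grad_batches, ngpu, bs, base_lr = 1, 8, 2, 1.0e-4
    learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
    print(learning_rate)  # 0.0016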