Mirror of https://github.com/hpcaitech/ColossalAI.git
@@ -23,21 +23,19 @@ from packaging import version
from PIL import Image
from prefetch_generator import BackgroundGenerator
from torch.utils.data import DataLoader, Dataset, Subset, random_split
from ldm.models.diffusion.ddpm import LatentDiffusion
#try:
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger
from lightning.pytorch.strategies import ColossalAIStrategy,DDPStrategy
LIGHTNING_PACK_NAME = "lightning.pytorch."
# #except:
# from pytorch_lightning import seed_everything
# from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
# from pytorch_lightning.trainer import Trainer
# from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
# LIGHTNING_PACK_NAME = "pytorch_lightning."

try:
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
LIGHTNING_PACK_NAME = "lightning.pytorch."
except:
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
LIGHTNING_PACK_NAME = "pytorch_lightning."

from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config
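Much of the diff below switches between constructing classes directly and routing through `instantiate_from_config` from `ldm.util`. For orientation, here is a minimal sketch of what that helper does with a `target`/`params` mapping, assuming the conventional CompVis-style layout (illustrative only, not the exact upstream source):

    import importlib

    def get_obj_from_str(path):
        # "pkg.module.Class" -> the Class object itself
        module, cls = path.rsplit(".", 1)
        return getattr(importlib.import_module(module), cls)

    def instantiate_from_config(config):
        # config is a mapping like {"target": "pkg.module.Class", "params": {...}}
        if "target" not in config:
            raise KeyError("Expected key `target` to instantiate.")
        return get_obj_from_str(config["target"])(**config.get("params", dict()))

In other words, a config node carrying a dotted `target` path and an optional `params` dict is enough to build the logger, strategy, callback, model, or data module it names.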
@@ -577,7 +575,7 @@ if __name__ == "__main__":
# target: path to test dataset
# params:
# key: value
# lightning: (optional, has same defaults and can be specified on cmdline)
# lightning: (optional, has sane defaults and can be specified on cmdline)
# trainer:
# additional arguments to trainer
# logger:
@@ -655,7 +653,7 @@ if __name__ == "__main__":
# Sets the seed for the random number generator to ensure reproducibility
seed_everything(opt.seed)

# Intinalize and save configuration using the OmegaConf library.
# Intinalize and save configuratioon using teh OmegaConf library.
try:
# init and save configs
configs = [OmegaConf.load(cfg) for cfg in opt.base]
@@ -689,7 +687,7 @@ if __name__ == "__main__":
config.model["params"].update({"ckpt": ckpt})
rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"]))

model = LatentDiffusion(**config.model.get("params", dict()))
model = instantiate_from_config(config.model)
# trainer and callbacks
trainer_kwargs = dict()
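Given the helper sketched after the imports, the two `model = ...` lines in this hunk are interchangeable whenever the config's `target` points at `ldm.models.diffusion.ddpm.LatentDiffusion`; the config-driven form simply defers the class choice to the YAML file. A toy demonstration with a hypothetical stand-in class (not from the repo):

    # Relies on instantiate_from_config as sketched above.
    class DummyModel:
        def __init__(self, width=1):
            self.width = width

    cfg = {"target": "__main__.DummyModel", "params": {"width": 4}}
    a = instantiate_from_config(cfg)             # config-driven construction
    b = DummyModel(**cfg.get("params", dict()))  # direct construction
    assert a.width == b.width == 4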
@@ -698,7 +696,7 @@ if __name__ == "__main__":
# These loggers are specified as targets in the dictionary, along with the configuration settings specific to each logger.
default_logger_cfgs = {
"wandb": {
#"target": LIGHTNING_PACK_NAME + "loggers.WandbLogger",
"target": LIGHTNING_PACK_NAME + "loggers.WandbLogger",
"params": {
"name": nowname,
"save_dir": logdir,
@@ -707,7 +705,7 @@ if __name__ == "__main__":
}
},
"tensorboard": {
#"target": LIGHTNING_PACK_NAME + "loggers.TensorBoardLogger",
"target": LIGHTNING_PACK_NAME + "loggers.TensorBoardLogger",
"params": {
"save_dir": logdir,
"name": "diff_tb",
@@ -720,32 +718,30 @@ if __name__ == "__main__":
default_logger_cfg = default_logger_cfgs["tensorboard"]
if "logger" in lightning_config:
logger_cfg = lightning_config.logger
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
trainer_kwargs["logger"] = WandbLogger(**logger_cfg.get("params", dict()))
else:
logger_cfg = default_logger_cfg
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
trainer_kwargs["logger"] = TensorBoardLogger(**logger_cfg.get("params", dict()))

logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
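Both branches above funnel the logger settings through `OmegaConf.merge`, whose key-by-key precedence (later arguments win) is what lets a user-supplied `lightning.logger` section override only the fields it names while the defaults fill in the rest. A small self-contained illustration with made-up values:

    from omegaconf import OmegaConf

    default_logger_cfg = OmegaConf.create({
        "target": "lightning.pytorch.loggers.TensorBoardLogger",
        "params": {"save_dir": "logs", "name": "diff_tb"},
    })
    user_logger_cfg = OmegaConf.create({"params": {"name": "my_run"}})

    # Later arguments take precedence, so the user value replaces "name"
    # while "save_dir" and "target" fall through from the defaults.
    merged = OmegaConf.merge(default_logger_cfg, user_logger_cfg)
    assert merged.params.name == "my_run"
    assert merged.params.save_dir == "logs"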
# configure the strategy, default is ddp
if "strategy" in trainer_config:
|
||||
strategy_cfg = trainer_config["strategy"]
|
||||
trainer_kwargs["strategy"] = ColossalAIStrategy(**strategy_cfg.get("params", dict()))
|
||||
strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"]
|
||||
else:
|
||||
strategy_cfg = {
|
||||
#"target": LIGHTNING_PACK_NAME + "strategies.DDPStrategy",
|
||||
"target": LIGHTNING_PACK_NAME + "strategies.DDPStrategy",
|
||||
"params": {
|
||||
"find_unused_parameters": False
|
||||
}
|
||||
}
|
||||
trainer_kwargs["strategy"] = DDPStrategy(**strategy_cfg.get("params", dict()))
|
||||
|
||||
trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)
|
||||
|
||||
# Set up ModelCheckpoint callback to save best models
|
||||
# modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
|
||||
# specify which metric is used to determine best models
|
||||
default_modelckpt_cfg = {
|
||||
#"target": LIGHTNING_PACK_NAME + "callbacks.ModelCheckpoint",
|
||||
"target": LIGHTNING_PACK_NAME + "callbacks.ModelCheckpoint",
|
||||
"params": {
|
||||
"dirpath": ckptdir,
|
||||
"filename": "{epoch:06}",
|
||||
@@ -763,13 +759,13 @@ if __name__ == "__main__":
modelckpt_cfg = OmegaConf.create()
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
if version.parse(pl.__version__) < version.parse('1.4.0'):
trainer_kwargs["checkpoint_callback"] = ModelCheckpoint(**modelckpt_cfg.get("params", dict()))
trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)

# Set up various callbacks, including logging, learning rate monitoring, and CUDA management
# add callback which sets up log directory
default_callbacks_cfg = {
"setup_callback": { # callback to set up the training
#"target": "main.SetupCallback",
"target": "main.SetupCallback",
"params": {
"resume": opt.resume, # resume training if applicable
"now": now,
@@ -781,7 +777,7 @@ if __name__ == "__main__":
}
},
"image_logger": { # callback to log image data
#"target": "main.ImageLogger",
"target": "main.ImageLogger",
"params": {
"batch_frequency": 750, # how frequently to log images
"max_images": 4, # maximum number of images to log
@@ -789,14 +785,14 @@ if __name__ == "__main__":
}
},
"learning_rate_logger": { # callback to log learning rate
#"target": "main.LearningRateMonitor",
"target": "main.LearningRateMonitor",
"params": {
"logging_interval": "step", # logging frequency (either 'step' or 'epoch')
# "log_momentum": True # whether to log momentum (currently commented out)
}
},
"cuda_callback": { # callback to handle CUDA-related operations
#"target": "main.CUDACallback"
"target": "main.CUDACallback"
},
}
@@ -814,7 +810,7 @@ if __name__ == "__main__":
'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
default_metrics_over_trainsteps_ckpt_dict = {
'metrics_over_trainsteps_checkpoint': {
#"target": LIGHTNING_PACK_NAME + 'callbacks.ModelCheckpoint',
"target": LIGHTNING_PACK_NAME + 'callbacks.ModelCheckpoint',
'params': {
"dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
"filename": "{epoch:06}-{step:09}",
@@ -829,35 +825,15 @@ if __name__ == "__main__":
# Merge the default callbacks configuration with the specified callbacks configuration, and instantiate the callbacks
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)

#Instantiate items according to the configs
trainer_kwargs.setdefault("callbacks", [])

if "setup_callback" in callbacks_cfg:
setup_callback_config = callbacks_cfg["setup_callback"]
trainer_kwargs["callbacks"].append(SetupCallback(**setup_callback_config.get("params", dict())))
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]

if "image_logger" in callbacks_cfg:
image_logger_config = callbacks_cfg["image_logger"]
trainer_kwargs["callbacks"].append(ImageLogger(**image_logger_config.get("params", dict())))

if "learning_rate_logger" in callbacks_cfg:
learning_rate_logger_config = callbacks_cfg["learning_rate_logger"]
trainer_kwargs["callbacks"].append(LearningRateMonitor(**learning_rate_logger_config.get("params", dict())))

if "cuda_callback" in callbacks_cfg:
cuda_callback_config = callbacks_cfg["cuda_callback"]
trainer_kwargs["callbacks"].append(CUDACallback(**cuda_callback_config.get("params", dict())))

if "metrics_over_trainsteps_checkpoint" in callbacks_cfg:
metrics_over_config = callbacks_cfg['metrics_over_trainsteps_checkpoint']
trainer_kwargs["callbacks"].append(ModelCheckpoint(**metrics_over_config.get("params", dict())))
#trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
# Create a Trainer object with the specified command-line arguments and keyword arguments, and set the log directory
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
trainer.logdir = logdir

# Create a data module based on the configuration file
data = DataModuleFromConfig(**config.data.get("params", dict()))
data = instantiate_from_config(config.data)
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
# calling these ourselves should not be necessary but it is.
# lightning still takes care of proper multiprocessing though
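The NOTE above refers to LightningDataModule hooks: Lightning normally invokes `prepare_data()` and `setup()` itself inside `trainer.fit()`, but this script calls them explicitly so that the datasets are built (and their sizes known) before training starts. A minimal sketch of that pattern, assuming the standard hook interface and the `data` module created above (the follow-up code itself is not part of this hunk):

    # Lightning would call these for us during trainer.fit(); calling them
    # here builds the underlying datasets up front.
    data.prepare_data()
    data.setup()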