update lightning version (#1954)

This commit is contained in:
Fazzie-Maqianli
2022-11-15 16:57:48 +08:00
committed by GitHub
parent 52c6ad26e0
commit 6bdd0a90ca
13 changed files with 29 additions and 35 deletions

View File

@@ -3,23 +3,23 @@ import numpy as np
import time
import torch
import torchvision
import pytorch_lightning as pl
import lightning.pytorch as pl
from packaging import version
from omegaconf import OmegaConf
from torch.utils.data import random_split, DataLoader, Dataset, Subset
from functools import partial
from PIL import Image
# from pytorch_lightning.strategies.colossalai import ColossalAIStrategy
# from lightning.pytorch.strategies.colossalai import ColossalAIStrategy
# from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from prefetch_generator import BackgroundGenerator
from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from pytorch_lightning.utilities import rank_zero_info
from lightning.pytorch import seed_everything
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
from lightning.pytorch.utilities.rank_zero import rank_zero_only
from lightning.pytorch.utilities import rank_zero_info
from diffusers.models.unet_2d import UNet2DModel
from clip.model import Bottleneck
@@ -610,7 +610,7 @@ if __name__ == "__main__":
# default logger configs
default_logger_cfgs = {
"wandb": {
"target": "pytorch_lightning.loggers.WandbLogger",
"target": "lightning.pytorch.loggers.WandbLogger",
"params": {
"name": nowname,
"save_dir": logdir,
@@ -619,7 +619,7 @@ if __name__ == "__main__":
}
},
"tensorboard":{
"target": "pytorch_lightning.loggers.TensorBoardLogger",
"target": "lightning.pytorch.loggers.TensorBoardLogger",
"params":{
"save_dir": logdir,
"name": "diff_tb",
@@ -642,7 +642,7 @@ if __name__ == "__main__":
print("Using strategy: {}".format(strategy_cfg["target"]))
else:
strategy_cfg = {
"target": "pytorch_lightning.strategies.DDPStrategy",
"target": "lightning.pytorch.strategies.DDPStrategy",
"params": {
"find_unused_parameters": False
}
@@ -654,7 +654,7 @@ if __name__ == "__main__":
# modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
# specify which metric is used to determine best models
default_modelckpt_cfg = {
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
"target": "lightning.pytorch.callbacks.ModelCheckpoint",
"params": {
"dirpath": ckptdir,
"filename": "{epoch:06}",
@@ -722,7 +722,7 @@ if __name__ == "__main__":
'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
default_metrics_over_trainsteps_ckpt_dict = {
'metrics_over_trainsteps_checkpoint':
{"target": 'pytorch_lightning.callbacks.ModelCheckpoint',
{"target": 'lightning.pytorch.callbacks.ModelCheckpoint',
'params': {
"dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
"filename": "{epoch:06}-{step:09}",