update lightning version (#1954)
@@ -3,23 +3,23 @@ import numpy as np
 import time
 import torch
 import torchvision
-import pytorch_lightning as pl
+import lightning.pytorch as pl
 
 from packaging import version
 from omegaconf import OmegaConf
 from torch.utils.data import random_split, DataLoader, Dataset, Subset
 from functools import partial
 from PIL import Image
-# from pytorch_lightning.strategies.colossalai import ColossalAIStrategy
+# from lightning.pytorch.strategies.colossalai import ColossalAIStrategy
 # from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
 from prefetch_generator import BackgroundGenerator
 
-from pytorch_lightning import seed_everything
-from pytorch_lightning.trainer import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
-from pytorch_lightning.utilities.rank_zero import rank_zero_only
-from pytorch_lightning.utilities import rank_zero_info
+from lightning.pytorch import seed_everything
+from lightning.pytorch.trainer import Trainer
+from lightning.pytorch.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
+from lightning.pytorch.utilities.rank_zero import rank_zero_only
+from lightning.pytorch.utilities import rank_zero_info
 from diffusers.models.unet_2d import UNet2DModel
 
 from clip.model import Bottleneck
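
Note: Lightning 2.0 renamed the import namespace from pytorch_lightning to lightning.pytorch, which is all this hunk changes. For code that must run against either version, a minimal compatibility shim (a sketch only, not part of this commit) could look like:

    # Sketch: prefer the Lightning >= 2.0 namespace, fall back to the
    # legacy pytorch_lightning package if that is the only one installed.
    try:
        import lightning.pytorch as pl
        from lightning.pytorch import seed_everything
        from lightning.pytorch.callbacks import ModelCheckpoint
    except ImportError:
        import pytorch_lightning as pl
        from pytorch_lightning import seed_everything
        from pytorch_lightning.callbacks import ModelCheckpoint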
@@ -610,7 +610,7 @@ if __name__ == "__main__":
         # default logger configs
         default_logger_cfgs = {
             "wandb": {
-                "target": "pytorch_lightning.loggers.WandbLogger",
+                "target": "lightning.pytorch.loggers.WandbLogger",
                 "params": {
                     "name": nowname,
                     "save_dir": logdir,
@@ -619,7 +619,7 @@ if __name__ == "__main__":
                 }
             },
             "tensorboard":{
-                "target": "pytorch_lightning.loggers.TensorBoardLogger",
+                "target": "lightning.pytorch.loggers.TensorBoardLogger",
                 "params":{
                     "save_dir": logdir,
                     "name": "diff_tb",
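
Note: the "target" strings in these logger configs are dotted import paths that the training script resolves to classes at runtime, which is why this part of the migration only rewrites strings rather than import statements. A minimal sketch of that resolution pattern (the helper names get_obj_from_str and instantiate_from_config follow the latent-diffusion convention and are assumptions here, not shown in this diff):

    import importlib

    def get_obj_from_str(name: str):
        # "lightning.pytorch.loggers.WandbLogger" -> the WandbLogger class
        module, cls = name.rsplit(".", 1)
        return getattr(importlib.import_module(module), cls)

    def instantiate_from_config(config: dict):
        # Instantiate config["target"] with config["params"] as kwargs.
        return get_obj_from_str(config["target"])(**config.get("params", {}))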
@@ -642,7 +642,7 @@ if __name__ == "__main__":
             print("Using strategy: {}".format(strategy_cfg["target"]))
         else:
             strategy_cfg = {
-                "target": "pytorch_lightning.strategies.DDPStrategy",
+                "target": "lightning.pytorch.strategies.DDPStrategy",
                 "params": {
                     "find_unused_parameters": False
                 }
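
Note: the strategy config above is resolved the same way and handed to the Trainer. A hedged sketch, reusing the assumed instantiate_from_config helper from the previous note:

    # Sketch: build the configured DDP strategy and pass it to the Trainer.
    strategy = instantiate_from_config(strategy_cfg)
    trainer = Trainer(strategy=strategy)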
@@ -654,7 +654,7 @@ if __name__ == "__main__":
         # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
         # specify which metric is used to determine best models
         default_modelckpt_cfg = {
-            "target": "pytorch_lightning.callbacks.ModelCheckpoint",
+            "target": "lightning.pytorch.callbacks.ModelCheckpoint",
             "params": {
                 "dirpath": ckptdir,
                 "filename": "{epoch:06}",
@@ -722,7 +722,7 @@ if __name__ == "__main__":
                 'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
             default_metrics_over_trainsteps_ckpt_dict = {
                 'metrics_over_trainsteps_checkpoint':
-                    {"target": 'pytorch_lightning.callbacks.ModelCheckpoint',
+                    {"target": 'lightning.pytorch.callbacks.ModelCheckpoint',
                      'params': {
                          "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
                          "filename": "{epoch:06}-{step:09}",
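
Note: ModelCheckpoint fills the {epoch} and {step} placeholders in these filename templates from trainer state and appends the .ckpt suffix. A quick check of what the template in this hunk produces:

    # e.g. epoch=3, step=1500 -> "000003-000001500" (saved as "000003-000001500.ckpt")
    print("{epoch:06}-{step:09}".format(epoch=3, step=1500))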