Mirror of https://github.com/hpcaitech/ColossalAI.git
support stable diffusion v2
@@ -1,54 +1,48 @@
import argparse, os, sys, datetime, glob, importlib, csv
import numpy as np
import argparse
import csv
import datetime
import glob
import importlib
import os
import sys
import time

import numpy as np
import torch
import torchvision
import lightning.pytorch as pl

from packaging import version
from omegaconf import OmegaConf
from torch.utils.data import random_split, DataLoader, Dataset, Subset
try:
    import lightning.pytorch as pl
except:
    import pytorch_lightning as pl

from functools import partial

from omegaconf import OmegaConf
from packaging import version
from PIL import Image
# from lightning.pytorch.strategies.colossalai import ColossalAIStrategy
# from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from prefetch_generator import BackgroundGenerator
from torch.utils.data import DataLoader, Dataset, Subset, random_split

from lightning.pytorch import seed_everything
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
from lightning.pytorch.utilities.rank_zero import rank_zero_only
from lightning.pytorch.utilities import rank_zero_info
from diffusers.models.unet_2d import UNet2DModel

from clip.model import Bottleneck
from transformers.models.clip.modeling_clip import CLIPTextTransformer
try:
    from lightning.pytorch import seed_everything
    from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
    from lightning.pytorch.trainer import Trainer
    from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
    LIGHTNING_PACK_NAME = "lightning.pytorch."
except:
    from pytorch_lightning import seed_everything
    from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
    from pytorch_lightning.trainer import Trainer
    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
    LIGHTNING_PACK_NAME = "pytorch_lightning."

from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config
import clip
from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel
import kornia

from ldm.modules.x_transformer import *
from ldm.modules.encoders.modules import *
from taming.modules.diffusionmodules.model import ResnetBlock
from taming.modules.transformer.mingpt import *
from taming.modules.transformer.permuter import *
# from ldm.modules.attention import enable_flash_attentions


from ldm.modules.ema import LitEma
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
from ldm.models.autoencoder import AutoencoderKL
from ldm.models.autoencoder import *
from ldm.models.diffusion.ddim import *
from ldm.modules.diffusionmodules.openaimodel import *
from ldm.modules.diffusionmodules.model import *
from ldm.modules.diffusionmodules.model import Decoder, Encoder, Up_module, Down_module, Mid_module, temb_module
from ldm.modules.attention import enable_flash_attention


class DataLoaderX(DataLoader):

    def __iter__(self):
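The hunk above cuts off the body of DataLoaderX.__iter__. For context, here is a minimal sketch of the usual prefetch_generator pattern that such a DataLoader subclass relies on; it is an illustration based on the imports shown in this diff, not necessarily the exact body of the file:

from prefetch_generator import BackgroundGenerator
from torch.utils.data import DataLoader


class DataLoaderX(DataLoader):
    # DataLoader whose iterator prefetches upcoming batches in a background
    # thread, so data loading overlaps with the training step.
    def __iter__(self):
        return BackgroundGenerator(super().__iter__())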
@@ -56,6 +50,7 @@ class DataLoaderX(DataLoader):


def get_parser(**parser_kwargs):

    def str2bool(v):
        if isinstance(v, bool):
            return v
@@ -91,7 +86,7 @@ def get_parser(**parser_kwargs):
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list(),
    )
    parser.add_argument(
@@ -111,11 +106,7 @@ def get_parser(**parser_kwargs):
        nargs="?",
        help="disable test",
    )
    parser.add_argument(
        "-p",
        "--project",
        help="name of new or path to existing project"
    )
    parser.add_argument("-p", "--project", help="name of new or path to existing project")
    parser.add_argument(
        "-d",
        "--debug",
@@ -210,8 +201,17 @@ def worker_init_fn(_):


class DataModuleFromConfig(pl.LightningDataModule):
    def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
                 wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,

    def __init__(self,
                 batch_size,
                 train=None,
                 validation=None,
                 test=None,
                 predict=None,
                 wrap=False,
                 num_workers=None,
                 shuffle_test_loader=False,
                 use_worker_init_fn=False,
                 shuffle_val_dataloader=False):
        super().__init__()
        self.batch_size = batch_size
@@ -237,9 +237,7 @@ class DataModuleFromConfig(pl.LightningDataModule):
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        self.datasets = dict(
            (k, instantiate_from_config(self.dataset_configs[k]))
            for k in self.dataset_configs)
        self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
        if self.wrap:
            for k in self.datasets:
                self.datasets[k] = WrappedDataset(self.datasets[k])
@@ -250,9 +248,11 @@ class DataModuleFromConfig(pl.LightningDataModule):
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoaderX(self.datasets["train"], batch_size=self.batch_size,
                           num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True,
                           worker_init_fn=init_fn)
        return DataLoaderX(self.datasets["train"],
                           batch_size=self.batch_size,
                           num_workers=self.num_workers,
                           shuffle=False if is_iterable_dataset else True,
                           worker_init_fn=init_fn)
    def _val_dataloader(self, shuffle=False):
        if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
@@ -260,10 +260,10 @@ class DataModuleFromConfig(pl.LightningDataModule):
        else:
            init_fn = None
        return DataLoaderX(self.datasets["validation"],
                           batch_size=self.batch_size,
                           num_workers=self.num_workers,
                           worker_init_fn=init_fn,
                           shuffle=shuffle)
                           batch_size=self.batch_size,
                           num_workers=self.num_workers,
                           worker_init_fn=init_fn,
                           shuffle=shuffle)

    def _test_dataloader(self, shuffle=False):
        is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
@@ -275,19 +275,25 @@ class DataModuleFromConfig(pl.LightningDataModule):
        # do not shuffle dataloader for iterable dataset
        shuffle = shuffle and (not is_iterable_dataset)

        return DataLoaderX(self.datasets["test"], batch_size=self.batch_size,
                           num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle)
        return DataLoaderX(self.datasets["test"],
                           batch_size=self.batch_size,
                           num_workers=self.num_workers,
                           worker_init_fn=init_fn,
                           shuffle=shuffle)

    def _predict_dataloader(self, shuffle=False):
        if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoaderX(self.datasets["predict"], batch_size=self.batch_size,
                           num_workers=self.num_workers, worker_init_fn=init_fn)
        return DataLoaderX(self.datasets["predict"],
                           batch_size=self.batch_size,
                           num_workers=self.num_workers,
                           worker_init_fn=init_fn)


class SetupCallback(Callback):

    def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
        super().__init__()
        self.resume = resume
@@ -317,8 +323,7 @@ class SetupCallback(Callback):
            os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
            print("Project config")
            print(OmegaConf.to_yaml(self.config))
            OmegaConf.save(self.config,
                           os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
            OmegaConf.save(self.config, os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

            print("Lightning config")
            print(OmegaConf.to_yaml(self.lightning_config))
@@ -338,8 +343,16 @@ class SetupCallback(Callback):
class ImageLogger(Callback):
    def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True,
                 rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,

    def __init__(self,
                 batch_frequency,
                 max_images,
                 clamp=True,
                 increase_log_steps=True,
                 rescale=True,
                 disabled=False,
                 log_on_batch_idx=False,
                 log_first_step=False,
                 log_images_kwargs=None):
        super().__init__()
        self.rescale = rescale
@@ -348,7 +361,7 @@ class ImageLogger(Callback):
        self.logger_log_images = {
            pl.loggers.CSVLogger: self._testtube,
        }
        self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
        self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
@@ -361,39 +374,30 @@ class ImageLogger(Callback):
    def _testtube(self, pl_module, images, batch_idx, split):
        for k in images:
            grid = torchvision.utils.make_grid(images[k])
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            grid = (grid + 1.0) / 2.0    # -1,1 -> 0,1; c,h,w

            tag = f"{split}/{k}"
            pl_module.logger.experiment.add_image(
                tag, grid,
                global_step=pl_module.global_step)
            pl_module.logger.experiment.add_image(tag, grid, global_step=pl_module.global_step)

    @rank_zero_only
    def log_local(self, save_dir, split, images,
                  global_step, current_epoch, batch_idx):
    def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
        root = os.path.join(save_dir, "images", split)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)
            if self.rescale:
                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
                grid = (grid + 1.0) / 2.0    # -1,1 -> 0,1; c,h,w
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
                k,
                global_step,
                current_epoch,
                batch_idx)
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            Image.fromarray(grid).save(path)

    def log_img(self, pl_module, batch, batch_idx, split="train"):
        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
        if (self.check_frequency(check_idx) and  # batch_idx % self.batch_freq == 0
                hasattr(pl_module, "log_images") and
                callable(pl_module.log_images) and
                self.max_images > 0):
        if (self.check_frequency(check_idx) and    # batch_idx % self.batch_freq == 0
                hasattr(pl_module, "log_images") and callable(pl_module.log_images) and self.max_images > 0):
            logger = type(pl_module.logger)

            is_train = pl_module.training
@@ -411,8 +415,8 @@ class ImageLogger(Callback):
                if self.clamp:
                    images[k] = torch.clamp(images[k], -1., 1.)

            self.log_local(pl_module.logger.save_dir, split, images,
                           pl_module.global_step, pl_module.current_epoch, batch_idx)
            self.log_local(pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch,
                           batch_idx)

            logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
            logger_log_images(pl_module, images, pl_module.global_step, split)
@@ -421,8 +425,8 @@ class ImageLogger(Callback):
                pl_module.train()

    def check_frequency(self, check_idx):
        if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
                check_idx > 0 or self.log_first_step):
        if ((check_idx % self.batch_freq) == 0 or
                (check_idx in self.log_steps)) and (check_idx > 0 or self.log_first_step):
            try:
                self.log_steps.pop(0)
            except IndexError as e:
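The check_frequency logic above logs images densely at power-of-two steps early in training and then at a fixed interval. A small illustrative calculation, assuming a hypothetical batch_frequency of 1000 (the real value comes from the training config, not from this diff):

import numpy as np

batch_freq = 1000    # hypothetical value for illustration
log_steps = [2**n for n in range(int(np.log2(batch_freq)) + 1)]
print(log_steps)    # [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
# check_frequency() then fires at these steps and at every multiple of
# batch_freq (1000, 2000, ...), provided check_idx > 0 or log_first_step is set.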
@@ -461,7 +465,7 @@ class CUDACallback(Callback):

    def on_train_epoch_end(self, trainer, pl_module):
        torch.cuda.synchronize(trainer.strategy.root_device.index)
        max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2 ** 20
        max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2**20
        epoch_time = time.time() - self.start_time

        try:
@@ -528,13 +532,9 @@ if __name__ == "__main__":

    opt, unknown = parser.parse_known_args()
    if opt.name and opt.resume:
        raise ValueError(
            "-n/--name and -r/--resume cannot be specified both."
            "If you want to resume training in a new log folder, "
            "use -n/--name in combination with --resume_from_checkpoint"
        )
    if opt.flash:
        enable_flash_attention()
        raise ValueError("-n/--name and -r/--resume cannot be specified both."
                         "If you want to resume training in a new log folder, "
                         "use -n/--name in combination with --resume_from_checkpoint")
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
@@ -578,7 +578,7 @@ if __name__ == "__main__":
    lightning_config = config.pop("lightning", OmegaConf.create())
    # merge trainer cli with config
    trainer_config = lightning_config.get("trainer", OmegaConf.create())

    for k in nondefault_trainer_args(opt):
        trainer_config[k] = getattr(opt, k)

@@ -601,7 +601,7 @@ if __name__ == "__main__":
    else:
        config.model["params"].update({"use_fp16": False})
    print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))

    model = instantiate_from_config(config.model)
    # trainer and callbacks
    trainer_kwargs = dict()
@@ -610,7 +610,7 @@ if __name__ == "__main__":
    # default logger configs
    default_logger_cfgs = {
        "wandb": {
            "target": "lightning.pytorch.loggers.WandbLogger",
            "target": LIGHTNING_PACK_NAME + "loggers.WandbLogger",
            "params": {
                "name": nowname,
                "save_dir": logdir,
@@ -618,9 +618,9 @@ if __name__ == "__main__":
                "id": nowname,
            }
        },
        "tensorboard":{
            "target": "lightning.pytorch.loggers.TensorBoardLogger",
            "params":{
        "tensorboard": {
            "target": LIGHTNING_PACK_NAME + "loggers.TensorBoardLogger",
            "params": {
                "save_dir": logdir,
                "name": "diff_tb",
                "log_graph": True
@@ -640,9 +640,10 @@ if __name__ == "__main__":
    if "strategy" in trainer_config:
        strategy_cfg = trainer_config["strategy"]
        print("Using strategy: {}".format(strategy_cfg["target"]))
        strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"]
    else:
        strategy_cfg = {
            "target": "lightning.pytorch.strategies.DDPStrategy",
            "target": LIGHTNING_PACK_NAME + "strategies.DDPStrategy",
            "params": {
                "find_unused_parameters": False
            }
@@ -654,7 +655,7 @@ if __name__ == "__main__":
    # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
    # specify which metric is used to determine best models
    default_modelckpt_cfg = {
        "target": "lightning.pytorch.callbacks.ModelCheckpoint",
        "target": LIGHTNING_PACK_NAME + "callbacks.ModelCheckpoint",
        "params": {
            "dirpath": ckptdir,
            "filename": "{epoch:06}",
@@ -670,7 +671,7 @@ if __name__ == "__main__":
    if "modelcheckpoint" in lightning_config:
        modelckpt_cfg = lightning_config.modelcheckpoint
    else:
        modelckpt_cfg = OmegaConf.create()
        modelckpt_cfg = OmegaConf.create()
    modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
    print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
    if version.parse(pl.__version__) < version.parse('1.4.0'):
@@ -702,7 +703,7 @@ if __name__ == "__main__":
            "target": "main.LearningRateMonitor",
            "params": {
                "logging_interval": "step",
                # "log_momentum": True
                # "log_momentum": True
            }
        },
        "cuda_callback": {
@@ -721,17 +722,17 @@ if __name__ == "__main__":
        print(
            'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
        default_metrics_over_trainsteps_ckpt_dict = {
            'metrics_over_trainsteps_checkpoint':
                {"target": 'lightning.pytorch.callbacks.ModelCheckpoint',
                 'params': {
                     "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
                     "filename": "{epoch:06}-{step:09}",
                     "verbose": True,
                     'save_top_k': -1,
                     'every_n_train_steps': 10000,
                     'save_weights_only': True
                 }
                 }
            'metrics_over_trainsteps_checkpoint': {
                "target": LIGHTNING_PACK_NAME + 'callbacks.ModelCheckpoint',
                'params': {
                    "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
                    "filename": "{epoch:06}-{step:09}",
                    "verbose": True,
                    'save_top_k': -1,
                    'every_n_train_steps': 10000,
                    'save_weights_only': True
                }
            }
        }
        default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)

@@ -744,7 +745,7 @@ if __name__ == "__main__":
    trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]

    trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
    trainer.logdir = logdir  ###
    trainer.logdir = logdir    ###

    # data
    data = instantiate_from_config(config.data)
@@ -772,14 +773,13 @@ if __name__ == "__main__":
    if opt.scale_lr:
        model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
        print(
            "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
                model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
            "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)"
            .format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
    else:
        model.learning_rate = base_lr
        print("++++ NOT USING LR SCALING ++++")
        print(f"Setting learning rate to {model.learning_rate:.2e}")
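To make the scaling rule above concrete, a worked example with hypothetical values (none of these numbers come from this commit):

# hypothetical values for illustration only
accumulate_grad_batches, ngpu, bs, base_lr = 1, 8, 2, 1.0e-4

learning_rate = accumulate_grad_batches * ngpu * bs * base_lr    # same formula as the opt.scale_lr branch above
print(f"{learning_rate:.2e}")    # 1.60e-03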
    # allow checkpointing via USR1
    def melk(*args, **kwargs):
        # run all checkpoint hooks
@@ -788,13 +788,11 @@ if __name__ == "__main__":
            ckpt_path = os.path.join(ckptdir, "last.ckpt")
            trainer.save_checkpoint(ckpt_path)

    def divein(*args, **kwargs):
        if trainer.global_rank == 0:
            import pudb;
            import pudb
            pudb.set_trace()

    import signal

    signal.signal(signal.SIGUSR1, melk)
@@ -803,8 +801,6 @@ if __name__ == "__main__":
    # run
    if opt.train:
        try:
            for name, m in model.named_parameters():
                print(name)
            trainer.fit(model, data)
        except Exception:
            melk()