Polish Code

This commit is contained in:
natalie_cao
2023-04-11 14:10:45 +08:00
committed by アマデウス
parent 152239bbfa
commit de84c0311a
15 changed files with 562 additions and 719 deletions

View File

@@ -1,16 +1,13 @@
import torch
try:
import lightning.pytorch as pl
except:
import pytorch_lightning as pl
import lightning.pytorch as pl
import torch.nn.functional as F
from torch import nn
from torch.nn import functional as F
from torch.nn import Identity
from contextlib import contextmanager
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.util import instantiate_from_config
from ldm.modules.ema import LitEma
@@ -32,7 +29,7 @@ class AutoencoderKL(pl.LightningModule):
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
self.loss = Identity()
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)