mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-28 13:05:26 +00:00
Polish Code
This commit is contained in:
@@ -1,16 +1,13 @@
|
||||
import torch
|
||||
try:
|
||||
import lightning.pytorch as pl
|
||||
except:
|
||||
import pytorch_lightning as pl
|
||||
import lightning.pytorch as pl
|
||||
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn import Identity
|
||||
from contextlib import contextmanager
|
||||
|
||||
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.modules.ema import LitEma
|
||||
|
||||
|
||||
@@ -32,7 +29,7 @@ class AutoencoderKL(pl.LightningModule):
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
self.loss = Identity()
|
||||
assert ddconfig["double_z"]
|
||||
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
|
Reference in New Issue
Block a user