mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-28 21:17:08 +00:00
* Update requirements.txt * Update environment.yaml * Update README.md * Update environment.yaml * Update README.md * Update README.md * Delete requirements_colossalai.txt * Update requirements.txt * Update README.md
This commit is contained in:
@@ -6,11 +6,10 @@ except:
|
||||
|
||||
import torch.nn.functional as F
|
||||
from contextlib import contextmanager
|
||||
from torch.nn import Identity
|
||||
|
||||
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.modules.ema import LitEma
|
||||
|
||||
|
||||
@@ -32,7 +31,7 @@ class AutoencoderKL(pl.LightningModule):
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
self.loss = Identity(**lossconfig.get("params", dict()))
|
||||
assert ddconfig["double_z"]
|
||||
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
|
Reference in New Issue
Block a user