Polish Code

This commit is contained in:
natalie_cao
2023-04-11 14:10:45 +08:00
committed by アマデウス
parent 152239bbfa
commit de84c0311a
15 changed files with 562 additions and 719 deletions

View File

@@ -7,8 +7,9 @@ from datetime import datetime
from diffusers import StableDiffusionPipeline
import torch
from ldm.util import instantiate_from_config
from main import get_parser
from ldm.modules.diffusionmodules.openaimodel import UNetModel
if __name__ == "__main__":
with torch.no_grad():
@@ -17,7 +18,7 @@ if __name__ == "__main__":
config = f.read()
base_config = yaml.load(config, Loader=yaml.FullLoader)
unet_config = base_config['model']['params']['unet_config']
diffusion_model = instantiate_from_config(unet_config).to("cuda:0")
diffusion_model = UNetModel(**unet_config).to("cuda:0")
pipe = StableDiffusionPipeline.from_pretrained(
"/data/scratch/diffuser/stable-diffusion-v1-4"