mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
This commit is contained in:
@@ -9,7 +9,7 @@ from diffusers import StableDiffusionPipeline
|
||||
import torch
|
||||
from ldm.util import instantiate_from_config
|
||||
from main import get_parser
|
||||
from ldm.modules.diffusionmodules.openaimodel import UNetModel
|
||||
|
||||
if __name__ == "__main__":
|
||||
with torch.no_grad():
|
||||
yaml_path = "../../train_colossalai.yaml"
|
||||
@@ -17,7 +17,7 @@ if __name__ == "__main__":
|
||||
config = f.read()
|
||||
base_config = yaml.load(config, Loader=yaml.FullLoader)
|
||||
unet_config = base_config['model']['params']['unet_config']
|
||||
diffusion_model = UNetModel(**unet_config.get("params", dict())).to("cuda:0")
|
||||
diffusion_model = instantiate_from_config(unet_config).to("cuda:0")
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"/data/scratch/diffuser/stable-diffusion-v1-4"
|
||||
|
Reference in New Issue
Block a user