[example] add diffusion inference (#1986)

Author: Fazzie-Maqianli
Date: 2022-11-20 18:35:29 +08:00
Committed by: GitHub
parent a01278e810
commit b5dbb46172
7 changed files with 343 additions and 45 deletions


@@ -99,12 +99,12 @@ class DDPM(pl.LightningModule):
         self.use_positional_encodings = use_positional_encodings
         self.unet_config = unet_config
         self.conditioning_key = conditioning_key
-        # self.model = DiffusionWrapper(unet_config, conditioning_key)
-        # count_params(self.model, verbose=True)
+        self.model = DiffusionWrapper(unet_config, conditioning_key)
+        count_params(self.model, verbose=True)
         self.use_ema = use_ema
-        # if self.use_ema:
-        # self.model_ema = LitEma(self.model)
-        # print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+        if self.use_ema:
+            self.model_ema = LitEma(self.model)
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
         self.use_scheduler = scheduler_config is not None
         if self.use_scheduler:
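
Note on the hunk above: it re-enables building the DiffusionWrapper (with count_params reporting its size) and the LitEma weight-averaging hook. LitEma's implementation is not part of this diff; the sketch below is only an illustration, with made-up names, of what an EMA tracker of this kind maintains: a decayed shadow copy of the parameters, updated after each optimizer step. LitEma keeps these averages as registered buffers, which is what the "Keeping EMAs of ..." count refers to.

    import torch
    import torch.nn as nn

    class SimpleEma:
        """Illustrative stand-in for an EMA tracker; not the repository's LitEma."""

        def __init__(self, model: nn.Module, decay: float = 0.9999):
            self.decay = decay
            # keep a detached shadow copy of every trainable parameter
            self.shadow = {name: p.detach().clone() for name, p in model.named_parameters()}

        @torch.no_grad()
        def update(self, model: nn.Module) -> None:
            # shadow <- decay * shadow + (1 - decay) * current weights
            for name, p in model.named_parameters():
                self.shadow[name].mul_(self.decay).add_(p, alpha=1.0 - self.decay)
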
@@ -125,20 +125,20 @@ class DDPM(pl.LightningModule):
         self.linear_start = linear_start
         self.linear_end = linear_end
         self.cosine_s = cosine_s
-        # if ckpt_path is not None:
-        # self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
-        #
-        # self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
-        # linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
+                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
         self.loss_type = loss_type
         self.learn_logvar = learn_logvar
         self.logvar_init = logvar_init
-        # self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
-        # if self.learn_logvar:
-        # self.logvar = nn.Parameter(self.logvar, requires_grad=True)
-        # self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+        self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+        if self.learn_logvar:
+            self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+            self.logvar = nn.Parameter(self.logvar, requires_grad=True)
         self.use_fp16 = use_fp16
         if use_fp16:
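
Note on the hunk above: checkpoint loading, the register_schedule call, and the per-timestep log-variance are switched back on. The restored block allocates one logvar entry per diffusion timestep and, when learn_logvar is set, turns the tensor into a trainable nn.Parameter (the assignment appears twice in the committed code). How that tensor is consumed is outside this diff; in the upstream latent-diffusion training loop it is typically indexed by the sampled timestep and folded into the simple loss, roughly as in this illustrative sketch:

    import torch

    num_timesteps, logvar_init = 1000, 0.0    # assumed example values
    logvar = torch.full((num_timesteps,), logvar_init)

    def weighted_loss(loss_simple: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # pick the log-variance entry for each sampled timestep t
        logvar_t = logvar[t]
        # down-weight the simple loss by exp(logvar) and add logvar as a regulariser
        return loss_simple / torch.exp(logvar_t) + logvar_t
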
@@ -312,14 +312,6 @@ class DDPM(pl.LightningModule):
     def get_loss(self, pred, target, mean=True):
-        if pred.isnan().any():
-            print("Warning: Prediction has nan values")
-            lr = self.optimizers().param_groups[0]['lr']
-            # self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
-            print(f"lr: {lr}")
-        if pred.isinf().any():
-            print("Warning: Prediction has inf values")
         if self.use_fp16:
             target = target.half()
@@ -334,15 +326,6 @@ class DDPM(pl.LightningModule):
                 loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
         else:
             raise NotImplementedError("unknown loss type '{loss_type}'")
-        if loss.isnan().any():
-            print("Warning: loss has nan values")
-            print("loss: ", loss[0][0][0])
-            raise ValueError("loss has nan values")
-        if loss.isinf().any():
-            print("Warning: loss has inf values")
-            print("loss: ", loss)
-            raise ValueError("loss has inf values")
         return loss
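
Note on the two hunks above: they strip the NaN/Inf debug prints and the ad-hoc learning-rate logging out of get_loss, leaving the plain l1/l2 dispatch. (The unchanged raise NotImplementedError("unknown loss type '{loss_type}'") context line is missing an f-string prefix, so the loss type is never interpolated into the message; this commit leaves it as is.) A simplified sketch of what the remaining logic amounts to, reconstructed from the context lines rather than copied from the file:

    import torch

    def get_loss(pred: torch.Tensor, target: torch.Tensor,
                 loss_type: str = "l2", mean: bool = True) -> torch.Tensor:
        if loss_type == "l1":
            loss = (target - pred).abs()
            return loss.mean() if mean else loss
        if loss_type == "l2":
            reduction = "mean" if mean else "none"
            return torch.nn.functional.mse_loss(target, pred, reduction=reduction)
        raise NotImplementedError(f"unknown loss type '{loss_type}'")
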
@@ -382,11 +365,7 @@ class DDPM(pl.LightningModule):
         return self.p_losses(x, t, *args, **kwargs)
     def get_input(self, batch, k):
-        # print("+" * 30)
-        # print(batch['jpg'].shape)
-        # print(len(batch['txt']))
-        # print(k)
-        # print("=" * 30)
         if not isinstance(batch, torch.Tensor):
             x = batch[k]
         else:
@@ -534,8 +513,8 @@ class LatentDiffusion(DDPM):
         else:
             self.cond_stage_config["params"].update({"use_fp16": False})
             rank_zero_info("Using fp16 for conditioning stage = {}".format(self.cond_stage_config["params"]["use_fp16"]))
-        # self.instantiate_first_stage(first_stage_config)
-        # self.instantiate_cond_stage(cond_stage_config)
+        self.instantiate_first_stage(first_stage_config)
+        self.instantiate_cond_stage(cond_stage_config)
         self.cond_stage_forward = cond_stage_forward
         self.clip_denoised = False
         self.bbox_tokenizer = None
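
Note on the hunk above: the instantiate_first_stage / instantiate_cond_stage calls are re-enabled right after the fp16 handling of the conditioning-stage config. Their bodies are outside this diff; in the latent-diffusion codebase they typically build the autoencoder and the conditioning model from their configs and (unless the conditioning stage is set trainable) freeze them for diffusion training. A minimal, generic sketch of that freezing step (illustrative helper, not the repository's code):

    import torch.nn as nn

    def freeze_module(module: nn.Module) -> nn.Module:
        # switch to eval mode and stop gradients from flowing into the weights
        module.eval()
        for param in module.parameters():
            param.requires_grad = False
        return module
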
@@ -561,16 +540,11 @@ class LatentDiffusion(DDPM):
         self.logvar = torch.full(fill_value=self.logvar_init, size=(self.num_timesteps,))
         if self.learn_logvar:
             self.logvar = nn.Parameter(self.logvar, requires_grad=True)
-            # self.logvar = nn.Parameter(self.logvar, requires_grad=True)
             self.logvar = nn.Parameter(self.logvar, requires_grad=True)
         if self.ckpt_path is not None:
             self.init_from_ckpt(self.ckpt_path, self.ignore_keys)
             self.restarted_from_ckpt = True
-        # TODO()
-        # for p in self.model.modules():
-        # if not p.parameters().data.is_contiguous:
-        # p.data = p.data.contiguous()
         self.instantiate_first_stage(self.first_stage_config)
         self.instantiate_cond_stage(self.cond_stage_config)
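
Note on the hunk above: it drops a dead commented-out TODO block (forcing parameters contiguous) and a commented duplicate of the logvar assignment, while keeping the checkpoint handling: when ckpt_path is set, weights are restored through init_from_ckpt with ignore_keys and restarted_from_ckpt is flagged. init_from_ckpt itself is outside this diff; its usual latent-diffusion behaviour is roughly the following sketch (assumed behaviour, illustrative names):

    import torch
    import torch.nn as nn

    def load_filtered(model: nn.Module, ckpt_path: str, ignore_keys=()) -> None:
        state = torch.load(ckpt_path, map_location="cpu")
        state = state.get("state_dict", state)
        # drop entries whose key starts with any ignored prefix, then load
        # non-strictly so the remaining mismatched keys do not abort loading
        state = {k: v for k, v in state.items()
                 if not any(k.startswith(prefix) for prefix in ignore_keys)}
        model.load_state_dict(state, strict=False)
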