Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-15 14:12:02 +00:00)
[example] add diffusion inference (#1986)
@@ -99,12 +99,12 @@ class DDPM(pl.LightningModule):
         self.use_positional_encodings = use_positional_encodings
         self.unet_config = unet_config
         self.conditioning_key = conditioning_key
-        # self.model = DiffusionWrapper(unet_config, conditioning_key)
-        # count_params(self.model, verbose=True)
+        self.model = DiffusionWrapper(unet_config, conditioning_key)
+        count_params(self.model, verbose=True)
         self.use_ema = use_ema
-        # if self.use_ema:
-        #     self.model_ema = LitEma(self.model)
-        #     print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+        if self.use_ema:
+            self.model_ema = LitEma(self.model)
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

         self.use_scheduler = scheduler_config is not None
         if self.use_scheduler:
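Note: the hunk above turns the previously commented-out model construction and EMA branch into live code, so `LitEma` now shadows the UNet weights whenever `use_ema` is set. As a rough illustration of what an exponential-moving-average tracker of this kind does (a minimal sketch only, not the actual `LitEma` implementation; the class name and default decay below are illustrative):

import torch

class SimpleEma:
    """Minimal EMA shadow of a model's trainable parameters (illustrative)."""

    def __init__(self, model: torch.nn.Module, decay: float = 0.9999):
        self.decay = decay
        # Keep a detached copy of every trainable parameter as the EMA state.
        self.shadow = {n: p.detach().clone()
                       for n, p in model.named_parameters() if p.requires_grad}

    @torch.no_grad()
    def update(self, model: torch.nn.Module):
        # shadow <- decay * shadow + (1 - decay) * current weights
        for n, p in model.named_parameters():
            if n in self.shadow:
                self.shadow[n].mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)

    @torch.no_grad()
    def copy_to(self, model: torch.nn.Module):
        # Load the averaged weights into the model, e.g. before sampling/inference.
        for n, p in model.named_parameters():
            if n in self.shadow:
                p.copy_(self.shadow[n])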
@@ -125,20 +125,20 @@ class DDPM(pl.LightningModule):
         self.linear_start = linear_start
         self.linear_end = linear_end
         self.cosine_s = cosine_s
-        # if ckpt_path is not None:
-        #     self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
-        #
-        # self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
-        #                        linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+
+        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
+                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)

         self.loss_type = loss_type

         self.learn_logvar = learn_logvar
         self.logvar_init = logvar_init
-        # self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
-        # if self.learn_logvar:
-        #     self.logvar = nn.Parameter(self.logvar, requires_grad=True)
-        #     self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+        self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+        if self.learn_logvar:
+            self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+            self.logvar = nn.Parameter(self.logvar, requires_grad=True)

         self.use_fp16 = use_fp16
         if use_fp16:
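The `register_schedule` call activated here precomputes the diffusion coefficients (betas, alphas, and their cumulative products) that the sampler and the training loss later index by timestep. A condensed sketch of the standard linear schedule this call is parameterised by (the helper names below are illustrative; the actual codebase registers several more derived buffers):

import numpy as np
import torch

def make_linear_beta_schedule(timesteps: int,
                              linear_start: float = 1e-4,
                              linear_end: float = 2e-2) -> torch.Tensor:
    # The DDPM "linear" schedule is linear in sqrt(beta), then squared.
    betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5,
                        timesteps, dtype=np.float64) ** 2
    return torch.tensor(betas, dtype=torch.float32)

def register_schedule_buffers(module: torch.nn.Module, betas: torch.Tensor) -> None:
    # Derived quantities used by q_sample / p_losses: alpha_t and its cumulative product.
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    module.register_buffer("betas", betas)
    module.register_buffer("alphas_cumprod", alphas_cumprod)
    module.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
    module.register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))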
@@ -312,14 +312,6 @@ class DDPM(pl.LightningModule):

     def get_loss(self, pred, target, mean=True):

-        if pred.isnan().any():
-            print("Warning: Prediction has nan values")
-            lr = self.optimizers().param_groups[0]['lr']
-            # self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
-            print(f"lr: {lr}")
-        if pred.isinf().any():
-            print("Warning: Prediction has inf values")
-
         if self.use_fp16:
             target = target.half()

@@ -334,15 +326,6 @@ class DDPM(pl.LightningModule):
                 loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
         else:
             raise NotImplementedError("unknown loss type '{loss_type}'")

-        if loss.isnan().any():
-            print("Warning: loss has nan values")
-            print("loss: ", loss[0][0][0])
-            raise ValueError("loss has nan values")
-        if loss.isinf().any():
-            print("Warning: loss has inf values")
-            print("loss: ", loss)
-            raise ValueError("loss has inf values")
-
         return loss

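With the NaN/Inf debugging stripped out by the two hunks above, `get_loss` reduces to selecting an l1 or l2 objective. A standalone sketch of that remaining logic (the function name is just for illustration; it mirrors the method shown above rather than copying it verbatim):

import torch

def diffusion_loss(pred: torch.Tensor, target: torch.Tensor,
                   loss_type: str = "l2", mean: bool = True) -> torch.Tensor:
    # l1 -> absolute error, l2 -> squared error; reduction is optional so callers
    # can apply per-timestep logvar weighting before averaging.
    if loss_type == "l1":
        loss = (target - pred).abs()
        if mean:
            loss = loss.mean()
    elif loss_type == "l2":
        reduction = "mean" if mean else "none"
        loss = torch.nn.functional.mse_loss(target, pred, reduction=reduction)
    else:
        raise NotImplementedError(f"unknown loss type '{loss_type}'")
    return loss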
@@ -382,11 +365,7 @@ class DDPM(pl.LightningModule):
         return self.p_losses(x, t, *args, **kwargs)

     def get_input(self, batch, k):
-        # print("+" * 30)
-        # print(batch['jpg'].shape)
-        # print(len(batch['txt']))
-        # print(k)
-        # print("=" * 30)
+
         if not isinstance(batch, torch.Tensor):
             x = batch[k]
         else:
@@ -534,8 +513,8 @@ class LatentDiffusion(DDPM):
         else:
             self.cond_stage_config["params"].update({"use_fp16": False})
         rank_zero_info("Using fp16 for conditioning stage = {}".format(self.cond_stage_config["params"]["use_fp16"]))
-        # self.instantiate_first_stage(first_stage_config)
-        # self.instantiate_cond_stage(cond_stage_config)
+        self.instantiate_first_stage(first_stage_config)
+        self.instantiate_cond_stage(cond_stage_config)
         self.cond_stage_forward = cond_stage_forward
         self.clip_denoised = False
         self.bbox_tokenizer = None
@@ -561,16 +540,11 @@ class LatentDiffusion(DDPM):
         self.logvar = torch.full(fill_value=self.logvar_init, size=(self.num_timesteps,))
         if self.learn_logvar:
             self.logvar = nn.Parameter(self.logvar, requires_grad=True)
-            # self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+            self.logvar = nn.Parameter(self.logvar, requires_grad=True)
         if self.ckpt_path is not None:
             self.init_from_ckpt(self.ckpt_path, self.ignore_keys)
             self.restarted_from_ckpt = True

-        # TODO()
-        # for p in self.model.modules():
-        #     if not p.parameters().data.is_contiguous:
-        #         p.data = p.data.contiguous()
-
         self.instantiate_first_stage(self.first_stage_config)
         self.instantiate_cond_stage(self.cond_stage_config)

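The last two hunks enable and relocate the calls that build the first stage (the autoencoder) and the conditioning stage (the text encoder) from their configs. In the latent-diffusion codebase these stages are instantiated from a config dict and then frozen so that only the UNet trains; a schematic of that pattern (the freezing details and helper below are paraphrased for illustration, not copied from the repository):

import importlib
import torch

def instantiate_from_config(config: dict):
    # config = {"target": "pkg.module.ClassName", "params": {...}}
    module_path, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**config.get("params", {}))

def instantiate_frozen_stage(config: dict) -> torch.nn.Module:
    # Build the stage, switch it to eval mode, and disable gradients so that
    # only the diffusion UNet receives parameter updates during training.
    model = instantiate_from_config(config)
    model.eval()
    for p in model.parameters():
        p.requires_grad = False
    return model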