ColossalAI/examples/inference/stable_diffusion/compute_metric.py
Runyu Lu bcf0181ecd
[Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source

* comp comm overlap optimization

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement
2024-07-30 10:43:26 +08:00

81 lines
2.9 KiB
Python

# Code from https://github.com/mit-han-lab/distrifuser/blob/main/scripts/compute_metrics.py
import argparse
import os
import numpy as np
import torch
from cleanfid import fid
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchmetrics.image import LearnedPerceptualImagePatchSimilarity, PeakSignalNoiseRatio
from torchvision.transforms import Resize
from tqdm import tqdm
def read_image(path: str):
"""
input: path
output: tensor (C, H, W)
"""
img = np.asarray(Image.open(path))
if len(img.shape) == 2:
img = np.repeat(img[:, :, None], 3, axis=2)
img = torch.from_numpy(img).permute(2, 0, 1)
return img
class MultiImageDataset(Dataset):
def __init__(self, root0, root1, is_gt=False):
super().__init__()
self.root0 = root0
self.root1 = root1
file_names0 = os.listdir(root0)
file_names1 = os.listdir(root1)
self.image_names0 = sorted([name for name in file_names0 if name.endswith(".png") or name.endswith(".jpg")])
self.image_names1 = sorted([name for name in file_names1 if name.endswith(".png") or name.endswith(".jpg")])
self.is_gt = is_gt
assert len(self.image_names0) == len(self.image_names1)
def __len__(self):
return len(self.image_names0)
def __getitem__(self, idx):
img0 = read_image(os.path.join(self.root0, self.image_names0[idx]))
if self.is_gt:
# resize to 1024 x 1024
img0 = Resize((1024, 1024))(img0)
img1 = read_image(os.path.join(self.root1, self.image_names1[idx]))
batch_list = [img0, img1]
return batch_list
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--is_gt", action="store_true")
parser.add_argument("--input_root0", type=str, required=True)
parser.add_argument("--input_root1", type=str, required=True)
args = parser.parse_args()
psnr = PeakSignalNoiseRatio(data_range=(0, 1), reduction="elementwise_mean", dim=(1, 2, 3)).to("cuda")
lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to("cuda")
dataset = MultiImageDataset(args.input_root0, args.input_root1, is_gt=args.is_gt)
dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)
progress_bar = tqdm(dataloader)
with torch.inference_mode():
for i, batch in enumerate(progress_bar):
batch = [img.to("cuda") / 255 for img in batch]
batch_size = batch[0].shape[0]
psnr.update(batch[0], batch[1])
lpips.update(batch[0], batch[1])
fid_score = fid.compute_fid(args.input_root0, args.input_root1)
print("PSNR:", psnr.compute().item())
print("LPIPS:", lpips.compute().item())
print("FID:", fid_score)