[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00
Committed by: GitHub
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions
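
For context: a change set like this is typically produced by running `pre-commit autoupdate` to bump the hook revisions pinned in .pre-commit-config.yaml, then `pre-commit run --all-files` to apply every hook across the repository. A minimal sketch of the clang-format fragment of such a config, matching the "ignore cuda for clang-format" bullet, might look as follows; the pinned revision and the exclude pattern here are illustrative assumptions, not the repository's actual values:

# Sketch of a .pre-commit-config.yaml fragment (assumed, for illustration only)
repos:
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v16.0.6        # hypothetical pinned revision
    hooks:
      - id: clang-format
        exclude: \.cu$  # assumed pattern: leave CUDA sources untouched

With a hook entry like this, `pre-commit run --all-files` reformats C/C++ sources while skipping .cu files.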


@@ -1,17 +1,15 @@
 import math
 import os
 from abc import abstractmethod

 import cv2
 import numpy as np
 import torch
-from torch.utils.data import ChainDataset, ConcatDataset, Dataset, IterableDataset
+from torch.utils.data import IterableDataset


 class Txt2ImgIterableBaseDataset(IterableDataset):
-    '''
+    """
     Define an interface to make the IterableDatasets for text2img data chainable
-    '''
+    """

     def __init__(self, file_path: str, rank, world_size):
         super().__init__()
@@ -20,8 +18,8 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
         self.file_list = []
         self.txt_list = []
         self.info = self._get_file_info(file_path)
-        self.start = self.info['start']
-        self.end = self.info['end']
+        self.start = self.info["start"]
+        self.end = self.info["end"]
         self.rank = rank
         self.world_size = world_size
@@ -33,7 +31,7 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
         self.num_records = self.end - self.start
         self.valid_ids = [i for i in range(self.end)]
-        print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
+        print(f"{self.__class__.__name__} dataset contains {self.__len__()} examples.")

     def __len__(self):
         # return self.iter_end - self.iter_start
@@ -48,7 +46,7 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
         for idx in range(start, end):
             file_name = self.file_list[idx]
             txt_name = self.txt_list[idx]
-            f_ = open(txt_name, 'r')
+            f_ = open(txt_name, "r")
             txt_ = f_.read()
             f_.close()
             image = cv2.imdecode(np.fromfile(file_name, dtype=np.uint8), 1)
@@ -57,18 +55,17 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
             yield {"txt": txt_, "image": image}

     def _get_file_info(self, file_path):
-        info = \
-        {
+        info = {
             "start": 1,
             "end": 0,
         }
-        self.folder_list = [file_path + i for i in os.listdir(file_path) if '.' not in i]
+        self.folder_list = [file_path + i for i in os.listdir(file_path) if "." not in i]
         for folder in self.folder_list:
-            files = [folder + '/' + i for i in os.listdir(folder) if 'jpg' in i]
-            txts = [k.replace('jpg', 'txt') for k in files]
+            files = [folder + "/" + i for i in os.listdir(folder) if "jpg" in i]
+            txts = [k.replace("jpg", "txt") for k in files]
             self.file_list.extend(files)
             self.txt_list.extend(txts)
-        info['end'] = len(self.file_list)
+        info["end"] = len(self.file_list)
         # with open(file_path, 'r') as fin:
         #     for _ in enumerate(fin):
         #         info['end'] += 1


@@ -1,15 +1,16 @@
-from typing import Dict
-import numpy as np
-from omegaconf import DictConfig, ListConfig
-import torch
-from torch.utils.data import Dataset
-from pathlib import Path
-import json
-from PIL import Image
-from torchvision import transforms
-from einops import rearrange
-from ldm.util import instantiate_from_config
-from datasets import load_dataset
+import json
+from pathlib import Path
+from typing import Dict
+
+import numpy as np
+import torch
+from datasets import load_dataset
+from einops import rearrange
+from ldm.util import instantiate_from_config
+from omegaconf import DictConfig, ListConfig
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms


 def make_multi_folder_data(paths, caption_files=None, **kwargs):
     """Make a concat dataset from multiple folders
@@ -19,10 +20,9 @@ def make_multi_folder_data(paths, caption_files=None, **kwargs):
     """
     list_of_paths = []
     if isinstance(paths, (Dict, DictConfig)):
-        assert caption_files is None, \
-            "Caption files not yet supported for repeats"
+        assert caption_files is None, "Caption files not yet supported for repeats"
         for folder_path, repeats in paths.items():
-            list_of_paths.extend([folder_path]*repeats)
+            list_of_paths.extend([folder_path] * repeats)
         paths = list_of_paths

     if caption_files is not None:
@@ -31,8 +31,10 @@ def make_multi_folder_data(paths, caption_files=None, **kwargs):
         datasets = [FolderData(p, **kwargs) for p in paths]
         return torch.utils.data.ConcatDataset(datasets)

+
 class FolderData(Dataset):
-    def __init__(self,
+    def __init__(
+        self,
         root_dir,
         caption_file=None,
         image_transforms=[],
@@ -40,7 +42,7 @@ class FolderData(Dataset):
         default_caption="",
         postprocess=None,
         return_paths=False,
-        ) -> None:
+    ) -> None:
         """Create a dataset from a folder of images.
         If you pass in a root directory it will be searched for images
         ending in ext (ext can be a list)
@@ -75,12 +77,12 @@ class FolderData(Dataset):
             self.paths.extend(list(self.root_dir.rglob(f"*.{e}")))
         if isinstance(image_transforms, ListConfig):
             image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
-        image_transforms.extend([transforms.ToTensor(),
-                                 transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
+        image_transforms.extend(
+            [transforms.ToTensor(), transforms.Lambda(lambda x: rearrange(x * 2.0 - 1.0, "c h w -> h w c"))]
+        )
         image_transforms = transforms.Compose(image_transforms)
         self.tform = image_transforms

     def __len__(self):
         if self.captions is not None:
             return len(self.captions.keys())
@@ -94,7 +96,7 @@ class FolderData(Dataset):
             caption = self.captions.get(chosen, None)
             if caption is None:
                 caption = self.default_caption
-            filename = self.root_dir/chosen
+            filename = self.root_dir / chosen
         else:
             filename = self.paths[index]
@@ -119,22 +121,23 @@ class FolderData(Dataset):
             im = im.convert("RGB")
         return self.tform(im)

+
 def hf_dataset(
     name,
     image_transforms=[],
     image_column="img",
     label_column="label",
     text_column="txt",
-    split='train',
-    image_key='image',
-    caption_key='txt',
-    ):
-    """Make huggingface dataset with appropriate list of transforms applied
-    """
+    split="train",
+    image_key="image",
+    caption_key="txt",
+):
+    """Make huggingface dataset with appropriate list of transforms applied"""
     ds = load_dataset(name, split=split)
     image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
-    image_transforms.extend([transforms.ToTensor(),
-                             transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
+    image_transforms.extend(
+        [transforms.ToTensor(), transforms.Lambda(lambda x: rearrange(x * 2.0 - 1.0, "c h w -> h w c"))]
+    )
     tform = transforms.Compose(image_transforms)

     assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}"
@@ -144,7 +147,18 @@ def hf_dataset(
         processed = {}
         processed[image_key] = [tform(im) for im in examples[image_column]]
-        label_to_text_dict = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer", 5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"}
+        label_to_text_dict = {
+            0: "airplane",
+            1: "automobile",
+            2: "bird",
+            3: "cat",
+            4: "deer",
+            5: "dog",
+            6: "frog",
+            7: "horse",
+            8: "ship",
+            9: "truck",
+        }

         processed[caption_key] = [label_to_text_dict[label] for label in examples[label_column]]
@@ -153,6 +167,7 @@ def hf_dataset(
     ds.set_transform(pre_process)
     return ds

+
 class TextOnly(Dataset):
     def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1):
         """Returns only captions with dummy images"""
@@ -166,7 +181,7 @@ class TextOnly(Dataset):
         if n_gpus > 1:
             # hack to make sure that all the captions appear on each gpu
-            repeated = [n_gpus*[x] for x in self.captions]
+            repeated = [n_gpus * [x] for x in self.captions]
             self.captions = []
             [self.captions.extend(x) for x in repeated]
@@ -175,10 +190,10 @@ class TextOnly(Dataset):
     def __getitem__(self, index):
         dummy_im = torch.zeros(3, self.output_size, self.output_size)
-        dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c')
+        dummy_im = rearrange(dummy_im * 2.0 - 1.0, "c h w -> h w c")
         return {self.image_key: dummy_im, self.caption_key: self.captions[index]}

     def _load_caption_file(self, filename):
-        with open(filename, 'rt') as f:
+        with open(filename, "rt") as f:
             captions = f.readlines()
-        return [x.strip('\n') for x in captions]
+        return [x.strip("\n") for x in captions]


@@ -1,32 +1,35 @@
-import os, yaml, pickle, shutil, tarfile, glob
-import cv2
-import albumentations
-import PIL
-import numpy as np
-import torchvision.transforms.functional as TF
-from omegaconf import OmegaConf
+import glob
+import os
+import pickle
+import shutil
+import tarfile
 from functools import partial
-from PIL import Image
-from tqdm import tqdm
-from torch.utils.data import Dataset, Subset

+import albumentations
+import cv2
+import numpy as np
+import PIL
 import taming.data.utils as tdu
-from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
-from taming.data.imagenet import ImagePaths
+import torchvision.transforms.functional as TF
+import yaml
 from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
+from omegaconf import OmegaConf
+from PIL import Image
+from taming.data.imagenet import ImagePaths, download, give_synsets_from_indices, retrieve, str_to_indices
+from torch.utils.data import Dataset, Subset
+from tqdm import tqdm


 def synset2idx(path_to_yaml="data/index_synset.yaml"):
     with open(path_to_yaml) as f:
         di2s = yaml.load(f)
-    return dict((v,k) for k,v in di2s.items())
+    return dict((v, k) for k, v in di2s.items())


 class ImageNetBase(Dataset):
     def __init__(self, config=None):
         self.config = config or OmegaConf.create()
-        if not type(self.config)==dict:
+        if not type(self.config) == dict:
             self.config = OmegaConf.to_container(self.config)
         self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
         self.process_images = True  # if False we skip loading & processing images and self.data contains filepaths
@@ -46,9 +49,11 @@ class ImageNetBase(Dataset):
         raise NotImplementedError()

     def _filter_relpaths(self, relpaths):
-        ignore = set([
-            "n06596364_9591.JPEG",
-        ])
+        ignore = set(
+            [
+                "n06596364_9591.JPEG",
+            ]
+        )
         relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
         if "sub_indices" in self.config:
             indices = str_to_indices(self.config["sub_indices"])
@@ -67,20 +72,19 @@ class ImageNetBase(Dataset):
         SIZE = 2655750
         URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
         self.human_dict = os.path.join(self.root, "synset_human.txt")
-        if (not os.path.exists(self.human_dict) or
-                not os.path.getsize(self.human_dict)==SIZE):
+        if not os.path.exists(self.human_dict) or not os.path.getsize(self.human_dict) == SIZE:
             download(URL, self.human_dict)

     def _prepare_idx_to_synset(self):
         URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
         self.idx2syn = os.path.join(self.root, "index_synset.yaml")
-        if (not os.path.exists(self.idx2syn)):
+        if not os.path.exists(self.idx2syn):
             download(URL, self.idx2syn)

     def _prepare_human_to_integer_label(self):
         URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
         self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
-        if (not os.path.exists(self.human2integer)):
+        if not os.path.exists(self.human2integer):
             download(URL, self.human2integer)
         with open(self.human2integer, "r") as f:
             lines = f.read().splitlines()
@@ -122,11 +126,12 @@ class ImageNetBase(Dataset):
         if self.process_images:
             self.size = retrieve(self.config, "size", default=256)
-            self.data = ImagePaths(self.abspaths,
-                                   labels=labels,
-                                   size=self.size,
-                                   random_crop=self.random_crop,
-                                   )
+            self.data = ImagePaths(
+                self.abspaths,
+                labels=labels,
+                size=self.size,
+                random_crop=self.random_crop,
+            )
         else:
             self.data = self.abspaths
@@ -157,8 +162,7 @@ class ImageNetTrain(ImageNetBase):
         self.datadir = os.path.join(self.root, "data")
         self.txt_filelist = os.path.join(self.root, "filelist.txt")
         self.expected_length = 1281167
-        self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
-                                    default=True)
+        self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", default=True)
         if not tdu.is_prepared(self.root):
             # prep
             print("Preparing dataset {} in {}".format(self.NAME, self.root))
@@ -166,8 +170,9 @@ class ImageNetTrain(ImageNetBase):
             datadir = self.datadir
             if not os.path.exists(datadir):
                 path = os.path.join(self.root, self.FILES[0])
-                if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
+                if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]:
                     import academictorrents as at
+
                     atpath = at.get(self.AT_HASH, datastore=self.root)
                     assert atpath == path
@@ -179,7 +184,7 @@ class ImageNetTrain(ImageNetBase):
                 print("Extracting sub-tars.")
                 subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
                 for subpath in tqdm(subpaths):
-                    subdir = subpath[:-len(".tar")]
+                    subdir = subpath[: -len(".tar")]
                     os.makedirs(subdir, exist_ok=True)
                     with tarfile.open(subpath, "r:") as tar:
                         tar.extractall(path=subdir)
@@ -187,7 +192,7 @@ class ImageNetTrain(ImageNetBase):
             filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
             filelist = [os.path.relpath(p, start=datadir) for p in filelist]
             filelist = sorted(filelist)
-            filelist = "\n".join(filelist)+"\n"
+            filelist = "\n".join(filelist) + "\n"
             with open(self.txt_filelist, "w") as f:
                 f.write(filelist)
@@ -222,8 +227,7 @@ class ImageNetValidation(ImageNetBase):
         self.datadir = os.path.join(self.root, "data")
         self.txt_filelist = os.path.join(self.root, "filelist.txt")
         self.expected_length = 50000
-        self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
-                                    default=False)
+        self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", default=False)
         if not tdu.is_prepared(self.root):
             # prep
             print("Preparing dataset {} in {}".format(self.NAME, self.root))
@@ -231,8 +235,9 @@ class ImageNetValidation(ImageNetBase):
             datadir = self.datadir
             if not os.path.exists(datadir):
                 path = os.path.join(self.root, self.FILES[0])
-                if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
+                if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]:
                     import academictorrents as at
+
                     atpath = at.get(self.AT_HASH, datastore=self.root)
                     assert atpath == path
@@ -242,7 +247,7 @@ class ImageNetValidation(ImageNetBase):
                     tar.extractall(path=datadir)

                 vspath = os.path.join(self.root, self.FILES[1])
-                if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
+                if not os.path.exists(vspath) or not os.path.getsize(vspath) == self.SIZES[1]:
                     download(self.VS_URL, vspath)

                 with open(vspath, "r") as f:
@@ -261,18 +266,15 @@ class ImageNetValidation(ImageNetBase):
             filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
             filelist = [os.path.relpath(p, start=datadir) for p in filelist]
             filelist = sorted(filelist)
-            filelist = "\n".join(filelist)+"\n"
+            filelist = "\n".join(filelist) + "\n"
             with open(self.txt_filelist, "w") as f:
                 f.write(filelist)

             tdu.mark_prepared(self.root)


 class ImageNetSR(Dataset):
-    def __init__(self, size=None,
-                 degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
-                 random_crop=True):
+    def __init__(self, size=None, degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.0, random_crop=True):
         """
         Imagenet Superresolution Dataloader
         Performs following ops in order:
@@ -296,12 +298,12 @@ class ImageNetSR(Dataset):
         self.LR_size = int(size / downscale_f)
         self.min_crop_f = min_crop_f
         self.max_crop_f = max_crop_f
-        assert(max_crop_f <= 1.)
+        assert max_crop_f <= 1.0
         self.center_crop = not random_crop

         self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)

-        self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
+        self.pil_interpolation = False  # gets reset later if incase interp_op is from pillow

         if degradation == "bsrgan":
             self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
@@ -311,17 +313,17 @@ class ImageNetSR(Dataset):
         else:
             interpolation_fn = {
-                "cv_nearest": cv2.INTER_NEAREST,
-                "cv_bilinear": cv2.INTER_LINEAR,
-                "cv_bicubic": cv2.INTER_CUBIC,
-                "cv_area": cv2.INTER_AREA,
-                "cv_lanczos": cv2.INTER_LANCZOS4,
-                "pil_nearest": PIL.Image.NEAREST,
-                "pil_bilinear": PIL.Image.BILINEAR,
-                "pil_bicubic": PIL.Image.BICUBIC,
-                "pil_box": PIL.Image.BOX,
-                "pil_hamming": PIL.Image.HAMMING,
-                "pil_lanczos": PIL.Image.LANCZOS,
+                "cv_nearest": cv2.INTER_NEAREST,
+                "cv_bilinear": cv2.INTER_LINEAR,
+                "cv_bicubic": cv2.INTER_CUBIC,
+                "cv_area": cv2.INTER_AREA,
+                "cv_lanczos": cv2.INTER_LANCZOS4,
+                "pil_nearest": PIL.Image.NEAREST,
+                "pil_bilinear": PIL.Image.BILINEAR,
+                "pil_bicubic": PIL.Image.BICUBIC,
+                "pil_box": PIL.Image.BOX,
+                "pil_hamming": PIL.Image.HAMMING,
+                "pil_lanczos": PIL.Image.LANCZOS,
             }[degradation]

             self.pil_interpolation = degradation.startswith("pil_")
@@ -330,8 +332,9 @@ class ImageNetSR(Dataset):
             self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)

         else:
-            self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
-                                                                      interpolation=interpolation_fn)
+            self.degradation_process = albumentations.SmallestMaxSize(
+                max_size=self.LR_size, interpolation=interpolation_fn
+            )

     def __len__(self):
         return len(self.base)
@@ -366,8 +369,8 @@ class ImageNetSR(Dataset):
         else:
             LR_image = self.degradation_process(image=image)["image"]

-        example["image"] = (image/127.5 - 1.0).astype(np.float32)
-        example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
+        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
+        example["LR_image"] = (LR_image / 127.5 - 1.0).astype(np.float32)

         return example
@@ -379,7 +382,9 @@ class ImageNetSRTrain(ImageNetSR):
     def get_base(self):
         with open("data/imagenet_train_hr_indices.p", "rb") as f:
             indices = pickle.load(f)
-        dset = ImageNetTrain(process_images=False,)
+        dset = ImageNetTrain(
+            process_images=False,
+        )
         return Subset(dset, indices)
@@ -390,5 +395,7 @@ class ImageNetSRValidation(ImageNetSR):
     def get_base(self):
         with open("data/imagenet_val_hr_indices.p", "rb") as f:
             indices = pickle.load(f)
-        dset = ImageNetValidation(process_images=False,)
+        dset = ImageNetValidation(
+            process_images=False,
+        )
         return Subset(dset, indices)


@@ -1,47 +1,49 @@
 import os
+
 import numpy as np
 import PIL
 from PIL import Image
 from torch.utils.data import Dataset
 from torchvision import transforms


 # This class is used to create a dataset of images from LSUN dataset for training
 class LSUNBase(Dataset):
-    def __init__(self,
-                 txt_file,                 # path to the text file containing the list of image paths
-                 data_root,                # root directory of the LSUN dataset
-                 size=None,                # the size of images to resize to
-                 interpolation="bicubic",  # interpolation method to be used while resizing
-                 flip_p=0.5                # probability of random horizontal flipping
-                 ):
-        self.data_paths = txt_file    # store path to text file containing list of images
-        self.data_root = data_root    # store path to root directory of the dataset
-        with open(self.data_paths, "r") as f:    # open and read the text file
-            self.image_paths = f.read().splitlines()    # read the lines of the file and store as list
-        self._length = len(self.image_paths)    # store the number of images
+    def __init__(
+        self,
+        txt_file,  # path to the text file containing the list of image paths
+        data_root,  # root directory of the LSUN dataset
+        size=None,  # the size of images to resize to
+        interpolation="bicubic",  # interpolation method to be used while resizing
+        flip_p=0.5,  # probability of random horizontal flipping
+    ):
+        self.data_paths = txt_file  # store path to text file containing list of images
+        self.data_root = data_root  # store path to root directory of the dataset
+        with open(self.data_paths, "r") as f:  # open and read the text file
+            self.image_paths = f.read().splitlines()  # read the lines of the file and store as list
+        self._length = len(self.image_paths)  # store the number of images

         # create dictionary to hold image path information
         self.labels = {
             "relative_file_path_": [l for l in self.image_paths],
-            "file_path_": [os.path.join(self.data_root, l)
-                           for l in self.image_paths],
+            "file_path_": [os.path.join(self.data_root, l) for l in self.image_paths],
         }

         # set the image size to be resized
-        self.size = size
+        self.size = size
         # set the interpolation method for resizing the image
-        self.interpolation = {"linear": PIL.Image.LINEAR,
-                              "bilinear": PIL.Image.BILINEAR,
-                              "bicubic": PIL.Image.BICUBIC,
-                              "lanczos": PIL.Image.LANCZOS,
-                              }[interpolation]
+        self.interpolation = {
+            "linear": PIL.Image.LINEAR,
+            "bilinear": PIL.Image.BILINEAR,
+            "bicubic": PIL.Image.BICUBIC,
+            "lanczos": PIL.Image.LANCZOS,
+        }[interpolation]
         # randomly flip the image horizontally with a given probability
         self.flip = transforms.RandomHorizontalFlip(p=flip_p)

     def __len__(self):
         # return the length of dataset
         return self._length

     def __getitem__(self, i):
         # get the image path for the given index
@@ -52,59 +54,71 @@ class LSUNBase(Dataset):
             image = image.convert("RGB")

         # default to score-sde preprocessing
-        img = np.array(image).astype(np.uint8)    # convert image to numpy array
-        crop = min(img.shape[0], img.shape[1])    # crop the image to a square shape
-        h, w, = img.shape[0], img.shape[1]    # get the height and width of image
-        img = img[(h - crop) // 2:(h + crop) // 2,
-                  (w - crop) // 2:(w + crop) // 2]    # crop the image to a square shape
-        image = Image.fromarray(img)    # create an image from numpy array
-        if self.size is not None:    # if image size is provided, resize the image
+        img = np.array(image).astype(np.uint8)  # convert image to numpy array
+        crop = min(img.shape[0], img.shape[1])  # crop the image to a square shape
+        (
+            h,
+            w,
+        ) = (
+            img.shape[0],
+            img.shape[1],
+        )  # get the height and width of image
+        img = img[
+            (h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2
+        ]  # crop the image to a square shape
+        image = Image.fromarray(img)  # create an image from numpy array
+        if self.size is not None:  # if image size is provided, resize the image
             image = image.resize((self.size, self.size), resample=self.interpolation)

-        image = self.flip(image)    # flip the image horizontally with the given probability
-        image = np.array(image).astype(np.uint8)
+        image = self.flip(image)  # flip the image horizontally with the given probability
+        image = np.array(image).astype(np.uint8)
         example["image"] = (image / 127.5 - 1.0).astype(np.float32)  # normalize the image values and convert to float32
-        return example    # return the example dictionary containing the image and its file paths
+        return example  # return the example dictionary containing the image and its file paths


-#A dataset class for LSUN Churches training set.
-# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments.
+# A dataset class for LSUN Churches training set.
+# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments.
 # The text file containing the paths to the images and the root directory where the images are stored are passed as arguments. Any additional keyword arguments passed to this class will be forwarded to the constructor of the parent class.
 class LSUNChurchesTrain(LSUNBase):
     def __init__(self, **kwargs):
         super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)


-#A dataset class for LSUN Churches validation set.
+# A dataset class for LSUN Churches validation set.
 # It is similar to LSUNChurchesTrain except that it uses a different text file and sets the flip probability to zero by default.
 class LSUNChurchesValidation(LSUNBase):
-    def __init__(self, flip_p=0., **kwargs):
-        super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
-                         flip_p=flip_p, **kwargs)
+    def __init__(self, flip_p=0.0, **kwargs):
+        super().__init__(
+            txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", flip_p=flip_p, **kwargs
+        )


-# A dataset class for LSUN Bedrooms training set.
-# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments.
+# A dataset class for LSUN Bedrooms training set.
+# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments.
 class LSUNBedroomsTrain(LSUNBase):
     def __init__(self, **kwargs):
         super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)


-# A dataset class for LSUN Bedrooms validation set.
+# A dataset class for LSUN Bedrooms validation set.
 # It is similar to LSUNBedroomsTrain except that it uses a different text file and sets the flip probability to zero by default.
 class LSUNBedroomsValidation(LSUNBase):
     def __init__(self, flip_p=0.0, **kwargs):
-        super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
-                         flip_p=flip_p, **kwargs)
+        super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", flip_p=flip_p, **kwargs)


-# A dataset class for LSUN Cats training set.
-# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments.
+# A dataset class for LSUN Cats training set.
+# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments.
 # The text file containing the paths to the images and the root directory where the images are stored are passed as arguments.
 class LSUNCatsTrain(LSUNBase):
     def __init__(self, **kwargs):
         super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)


-# A dataset class for LSUN Cats validation set.
+# A dataset class for LSUN Cats validation set.
 # It is similar to LSUNCatsTrain except that it uses a different text file and sets the flip probability to zero by default.
 class LSUNCatsValidation(LSUNBase):
-    def __init__(self, flip_p=0., **kwargs):
-        super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
-                         flip_p=flip_p, **kwargs)
+    def __init__(self, flip_p=0.0, **kwargs):
+        super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", flip_p=flip_p, **kwargs)


@@ -1,15 +1,16 @@
-from typing import Dict
-import numpy as np
-from omegaconf import DictConfig, ListConfig
-import torch
-from torch.utils.data import Dataset
-from pathlib import Path
-import json
-from PIL import Image
-from torchvision import transforms
-from einops import rearrange
-from ldm.util import instantiate_from_config
-from datasets import load_dataset
+import json
+from pathlib import Path
+from typing import Dict
+
+import numpy as np
+import torch
+from datasets import load_dataset
+from einops import rearrange
+from ldm.util import instantiate_from_config
+from omegaconf import DictConfig, ListConfig
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms


 def make_multi_folder_data(paths, caption_files=None, **kwargs):
     """Make a concat dataset from multiple folders
@@ -19,10 +20,9 @@ def make_multi_folder_data(paths, caption_files=None, **kwargs):
     """
     list_of_paths = []
     if isinstance(paths, (Dict, DictConfig)):
-        assert caption_files is None, \
-            "Caption files not yet supported for repeats"
+        assert caption_files is None, "Caption files not yet supported for repeats"
         for folder_path, repeats in paths.items():
-            list_of_paths.extend([folder_path]*repeats)
+            list_of_paths.extend([folder_path] * repeats)
         paths = list_of_paths

     if caption_files is not None:
@@ -31,8 +31,10 @@ def make_multi_folder_data(paths, caption_files=None, **kwargs):
         datasets = [FolderData(p, **kwargs) for p in paths]
         return torch.utils.data.ConcatDataset(datasets)

+
 class FolderData(Dataset):
-    def __init__(self,
+    def __init__(
+        self,
         root_dir,
         caption_file=None,
         image_transforms=[],
@@ -40,7 +42,7 @@ class FolderData(Dataset):
         default_caption="",
         postprocess=None,
         return_paths=False,
-        ) -> None:
+    ) -> None:
         """Create a dataset from a folder of images.
         If you pass in a root directory it will be searched for images
         ending in ext (ext can be a list)
@@ -75,12 +77,12 @@ class FolderData(Dataset):
             self.paths.extend(list(self.root_dir.rglob(f"*.{e}")))
         if isinstance(image_transforms, ListConfig):
             image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
-        image_transforms.extend([transforms.ToTensor(),
-                                 transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
+        image_transforms.extend(
+            [transforms.ToTensor(), transforms.Lambda(lambda x: rearrange(x * 2.0 - 1.0, "c h w -> h w c"))]
+        )
         image_transforms = transforms.Compose(image_transforms)
         self.tform = image_transforms

     def __len__(self):
         if self.captions is not None:
             return len(self.captions.keys())
@@ -94,7 +96,7 @@ class FolderData(Dataset):
             caption = self.captions.get(chosen, None)
             if caption is None:
                 caption = self.default_caption
-            filename = self.root_dir/chosen
+            filename = self.root_dir / chosen
         else:
             filename = self.paths[index]
@@ -119,23 +121,26 @@ class FolderData(Dataset):
             im = im.convert("RGB")
         return self.tform(im)

+
 def hf_dataset(
-    path = "Fazzie/Teyvat",
+    path="Fazzie/Teyvat",
     image_transforms=[],
     image_column="image",
     text_column="text",
-    image_key='image',
-    caption_key='txt',
-    ):
-    """Make huggingface dataset with appropriate list of transforms applied
-    """
+    image_key="image",
+    caption_key="txt",
+):
+    """Make huggingface dataset with appropriate list of transforms applied"""
     ds = load_dataset(path, name="train")
     ds = ds["train"]
     image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
-    image_transforms.extend([transforms.Resize((256, 256)),
-                             transforms.ToTensor(),
-                             transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]
-    )
+    image_transforms.extend(
+        [
+            transforms.Resize((256, 256)),
+            transforms.ToTensor(),
+            transforms.Lambda(lambda x: rearrange(x * 2.0 - 1.0, "c h w -> h w c")),
+        ]
+    )
     tform = transforms.Compose(image_transforms)

     assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}"
@@ -149,4 +154,4 @@ def hf_dataset(
         return processed

     ds.set_transform(pre_process)
-    return ds
\ No newline at end of file
+    return ds