Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
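This commit applies repository-wide auto-formatting rather than functional changes: isort reorders imports, black normalizes quoting, wrapping, and spacing, and clang-format is configured to skip CUDA sources. As a rough sketch of the kind of pre-commit setup the commit message describes (the hook revisions and the CUDA exclude pattern below are illustrative assumptions, not the repository's actual configuration):

    # .pre-commit-config.yaml -- hypothetical sketch, not the repo's real file
    repos:
      - repo: https://github.com/PyCQA/isort
        rev: 5.12.0          # assumed revision
        hooks:
          - id: isort        # sorts and groups imports
      - repo: https://github.com/psf/black
        rev: 23.9.1          # assumed revision
        hooks:
          - id: black        # normalizes quotes, wrapping, spacing
      - repo: https://github.com/pre-commit/mirrors-clang-format
        rev: v16.0.6         # assumed revision
        hooks:
          - id: clang-format
            exclude: '\.cu$' # "[misc] ignore cuda for clang-format"

With the hooks updated, running pre-commit run --all-files sweeps the whole tree in one pass, which is what produces the purely mechanical diff below.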
@@ -1,17 +1,15 @@
 import math
 import os
 from abc import abstractmethod

 import cv2
 import numpy as np
 import torch
-from torch.utils.data import ChainDataset, ConcatDataset, Dataset, IterableDataset
+from torch.utils.data import IterableDataset


 class Txt2ImgIterableBaseDataset(IterableDataset):
-    '''
+    """
     Define an interface to make the IterableDatasets for text2img data chainable
-    '''
+    """

     def __init__(self, file_path: str, rank, world_size):
         super().__init__()
@@ -20,8 +18,8 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
         self.file_list = []
         self.txt_list = []
         self.info = self._get_file_info(file_path)
-        self.start = self.info['start']
-        self.end = self.info['end']
+        self.start = self.info["start"]
+        self.end = self.info["end"]
         self.rank = rank

         self.world_size = world_size
@@ -33,7 +31,7 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
         self.num_records = self.end - self.start
         self.valid_ids = [i for i in range(self.end)]

-        print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
+        print(f"{self.__class__.__name__} dataset contains {self.__len__()} examples.")

     def __len__(self):
         # return self.iter_end - self.iter_start
@@ -48,7 +46,7 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
         for idx in range(start, end):
             file_name = self.file_list[idx]
             txt_name = self.txt_list[idx]
-            f_ = open(txt_name, 'r')
+            f_ = open(txt_name, "r")
             txt_ = f_.read()
             f_.close()
             image = cv2.imdecode(np.fromfile(file_name, dtype=np.uint8), 1)
@@ -57,18 +55,17 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
             yield {"txt": txt_, "image": image}

     def _get_file_info(self, file_path):
-        info = \
-        {
+        info = {
             "start": 1,
             "end": 0,
         }
-        self.folder_list = [file_path + i for i in os.listdir(file_path) if '.' not in i]
+        self.folder_list = [file_path + i for i in os.listdir(file_path) if "." not in i]
         for folder in self.folder_list:
-            files = [folder + '/' + i for i in os.listdir(folder) if 'jpg' in i]
-            txts = [k.replace('jpg', 'txt') for k in files]
+            files = [folder + "/" + i for i in os.listdir(folder) if "jpg" in i]
+            txts = [k.replace("jpg", "txt") for k in files]
             self.file_list.extend(files)
             self.txt_list.extend(txts)
-        info['end'] = len(self.file_list)
+        info["end"] = len(self.file_list)
         # with open(file_path, 'r') as fin:
         #     for _ in enumerate(fin):
         #         info['end'] += 1
@@ -1,15 +1,16 @@
-from typing import Dict
-import numpy as np
-from omegaconf import DictConfig, ListConfig
-import torch
-from torch.utils.data import Dataset
-from pathlib import Path
-import json
-from PIL import Image
-from torchvision import transforms
+import json
+from pathlib import Path
+from typing import Dict
+
+import numpy as np
+import torch
+from datasets import load_dataset
+from einops import rearrange
+from ldm.util import instantiate_from_config
+from omegaconf import DictConfig, ListConfig
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms


 def make_multi_folder_data(paths, caption_files=None, **kwargs):
     """Make a concat dataset from multiple folders
@@ -19,10 +20,9 @@ def make_multi_folder_data(paths, caption_files=None, **kwargs):
     """
     list_of_paths = []
     if isinstance(paths, (Dict, DictConfig)):
-        assert caption_files is None, \
-            "Caption files not yet supported for repeats"
+        assert caption_files is None, "Caption files not yet supported for repeats"
         for folder_path, repeats in paths.items():
-            list_of_paths.extend([folder_path]*repeats)
+            list_of_paths.extend([folder_path] * repeats)
         paths = list_of_paths

     if caption_files is not None:
@@ -31,8 +31,10 @@ def make_multi_folder_data(paths, caption_files=None, **kwargs):
     datasets = [FolderData(p, **kwargs) for p in paths]
     return torch.utils.data.ConcatDataset(datasets)

+
 class FolderData(Dataset):
-    def __init__(self,
+    def __init__(
+        self,
         root_dir,
         caption_file=None,
         image_transforms=[],
@@ -40,7 +42,7 @@ class FolderData(Dataset):
         default_caption="",
         postprocess=None,
         return_paths=False,
-        ) -> None:
+    ) -> None:
         """Create a dataset from a folder of images.

         If you pass in a root directory it will be searched for images
         ending in ext (ext can be a list)
@@ -75,12 +77,12 @@ class FolderData(Dataset):
             self.paths.extend(list(self.root_dir.rglob(f"*.{e}")))
         if isinstance(image_transforms, ListConfig):
             image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
-        image_transforms.extend([transforms.ToTensor(),
-                                 transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
+        image_transforms.extend(
+            [transforms.ToTensor(), transforms.Lambda(lambda x: rearrange(x * 2.0 - 1.0, "c h w -> h w c"))]
+        )
         image_transforms = transforms.Compose(image_transforms)
         self.tform = image_transforms

     def __len__(self):
         if self.captions is not None:
             return len(self.captions.keys())
@@ -94,7 +96,7 @@ class FolderData(Dataset):
             caption = self.captions.get(chosen, None)
             if caption is None:
                 caption = self.default_caption
-            filename = self.root_dir/chosen
+            filename = self.root_dir / chosen
         else:
             filename = self.paths[index]
@@ -119,22 +121,23 @@ class FolderData(Dataset):
             im = im.convert("RGB")
         return self.tform(im)


 def hf_dataset(
     name,
     image_transforms=[],
     image_column="img",
     label_column="label",
     text_column="txt",
-    split='train',
-    image_key='image',
-    caption_key='txt',
-):
-    """Make huggingface dataset with appropriate list of transforms applied
-    """
+    split="train",
+    image_key="image",
+    caption_key="txt",
+):
+    """Make huggingface dataset with appropriate list of transforms applied"""
     ds = load_dataset(name, split=split)
     image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
-    image_transforms.extend([transforms.ToTensor(),
-                             transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
+    image_transforms.extend(
+        [transforms.ToTensor(), transforms.Lambda(lambda x: rearrange(x * 2.0 - 1.0, "c h w -> h w c"))]
+    )
     tform = transforms.Compose(image_transforms)

     assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}"
@@ -144,7 +147,18 @@ def hf_dataset(
         processed = {}
         processed[image_key] = [tform(im) for im in examples[image_column]]

-        label_to_text_dict = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer", 5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"}
+        label_to_text_dict = {
+            0: "airplane",
+            1: "automobile",
+            2: "bird",
+            3: "cat",
+            4: "deer",
+            5: "dog",
+            6: "frog",
+            7: "horse",
+            8: "ship",
+            9: "truck",
+        }

         processed[caption_key] = [label_to_text_dict[label] for label in examples[label_column]]
@@ -153,6 +167,7 @@ def hf_dataset(
     ds.set_transform(pre_process)
     return ds

+
 class TextOnly(Dataset):
     def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1):
         """Returns only captions with dummy images"""
@@ -166,7 +181,7 @@ class TextOnly(Dataset):

         if n_gpus > 1:
             # hack to make sure that all the captions appear on each gpu
-            repeated = [n_gpus*[x] for x in self.captions]
+            repeated = [n_gpus * [x] for x in self.captions]
             self.captions = []
             [self.captions.extend(x) for x in repeated]
@@ -175,10 +190,10 @@ class TextOnly(Dataset):

     def __getitem__(self, index):
         dummy_im = torch.zeros(3, self.output_size, self.output_size)
-        dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c')
+        dummy_im = rearrange(dummy_im * 2.0 - 1.0, "c h w -> h w c")
         return {self.image_key: dummy_im, self.caption_key: self.captions[index]}

     def _load_caption_file(self, filename):
-        with open(filename, 'rt') as f:
+        with open(filename, "rt") as f:
             captions = f.readlines()
-        return [x.strip('\n') for x in captions]
+        return [x.strip("\n") for x in captions]
@@ -1,32 +1,35 @@
-import os, yaml, pickle, shutil, tarfile, glob
-import cv2
-import albumentations
-import PIL
-import numpy as np
-import torchvision.transforms.functional as TF
-from omegaconf import OmegaConf
+import glob
+import os
+import pickle
+import shutil
+import tarfile
 from functools import partial
-from PIL import Image
-from tqdm import tqdm
-from torch.utils.data import Dataset, Subset

+import albumentations
+import cv2
+import numpy as np
+import PIL
 import taming.data.utils as tdu
-from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
-from taming.data.imagenet import ImagePaths
+import torchvision.transforms.functional as TF
+import yaml
 from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
+from omegaconf import OmegaConf
+from PIL import Image
+from taming.data.imagenet import ImagePaths, download, give_synsets_from_indices, retrieve, str_to_indices
+from torch.utils.data import Dataset, Subset
+from tqdm import tqdm


 def synset2idx(path_to_yaml="data/index_synset.yaml"):
     with open(path_to_yaml) as f:
         di2s = yaml.load(f)
-    return dict((v,k) for k,v in di2s.items())
+    return dict((v, k) for k, v in di2s.items())


 class ImageNetBase(Dataset):
     def __init__(self, config=None):
         self.config = config or OmegaConf.create()
-        if not type(self.config)==dict:
+        if not type(self.config) == dict:
             self.config = OmegaConf.to_container(self.config)
         self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
         self.process_images = True  # if False we skip loading & processing images and self.data contains filepaths
@@ -46,9 +49,11 @@ class ImageNetBase(Dataset):
         raise NotImplementedError()

     def _filter_relpaths(self, relpaths):
-        ignore = set([
-            "n06596364_9591.JPEG",
-        ])
+        ignore = set(
+            [
+                "n06596364_9591.JPEG",
+            ]
+        )
         relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
         if "sub_indices" in self.config:
             indices = str_to_indices(self.config["sub_indices"])
@@ -67,20 +72,19 @@ class ImageNetBase(Dataset):
         SIZE = 2655750
         URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
         self.human_dict = os.path.join(self.root, "synset_human.txt")
-        if (not os.path.exists(self.human_dict) or
-                not os.path.getsize(self.human_dict)==SIZE):
+        if not os.path.exists(self.human_dict) or not os.path.getsize(self.human_dict) == SIZE:
             download(URL, self.human_dict)

     def _prepare_idx_to_synset(self):
         URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
         self.idx2syn = os.path.join(self.root, "index_synset.yaml")
-        if (not os.path.exists(self.idx2syn)):
+        if not os.path.exists(self.idx2syn):
             download(URL, self.idx2syn)

     def _prepare_human_to_integer_label(self):
         URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
         self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
-        if (not os.path.exists(self.human2integer)):
+        if not os.path.exists(self.human2integer):
             download(URL, self.human2integer)
         with open(self.human2integer, "r") as f:
             lines = f.read().splitlines()
@@ -122,11 +126,12 @@ class ImageNetBase(Dataset):

         if self.process_images:
             self.size = retrieve(self.config, "size", default=256)
-            self.data = ImagePaths(self.abspaths,
-                                   labels=labels,
-                                   size=self.size,
-                                   random_crop=self.random_crop,
-                                   )
+            self.data = ImagePaths(
+                self.abspaths,
+                labels=labels,
+                size=self.size,
+                random_crop=self.random_crop,
+            )
         else:
             self.data = self.abspaths
@@ -157,8 +162,7 @@ class ImageNetTrain(ImageNetBase):
         self.datadir = os.path.join(self.root, "data")
         self.txt_filelist = os.path.join(self.root, "filelist.txt")
         self.expected_length = 1281167
-        self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
-                                    default=True)
+        self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", default=True)
         if not tdu.is_prepared(self.root):
             # prep
             print("Preparing dataset {} in {}".format(self.NAME, self.root))
@@ -166,8 +170,9 @@ class ImageNetTrain(ImageNetBase):
             datadir = self.datadir
             if not os.path.exists(datadir):
                 path = os.path.join(self.root, self.FILES[0])
-                if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
+                if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]:
                     import academictorrents as at
+
                     atpath = at.get(self.AT_HASH, datastore=self.root)
                     assert atpath == path
@@ -179,7 +184,7 @@ class ImageNetTrain(ImageNetBase):
             print("Extracting sub-tars.")
             subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
             for subpath in tqdm(subpaths):
-                subdir = subpath[:-len(".tar")]
+                subdir = subpath[: -len(".tar")]
                 os.makedirs(subdir, exist_ok=True)
                 with tarfile.open(subpath, "r:") as tar:
                     tar.extractall(path=subdir)
@@ -187,7 +192,7 @@ class ImageNetTrain(ImageNetBase):
             filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
             filelist = [os.path.relpath(p, start=datadir) for p in filelist]
             filelist = sorted(filelist)
-            filelist = "\n".join(filelist)+"\n"
+            filelist = "\n".join(filelist) + "\n"
             with open(self.txt_filelist, "w") as f:
                 f.write(filelist)
@@ -222,8 +227,7 @@ class ImageNetValidation(ImageNetBase):
         self.datadir = os.path.join(self.root, "data")
         self.txt_filelist = os.path.join(self.root, "filelist.txt")
         self.expected_length = 50000
-        self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
-                                    default=False)
+        self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", default=False)
         if not tdu.is_prepared(self.root):
             # prep
             print("Preparing dataset {} in {}".format(self.NAME, self.root))
@@ -231,8 +235,9 @@ class ImageNetValidation(ImageNetBase):
             datadir = self.datadir
             if not os.path.exists(datadir):
                 path = os.path.join(self.root, self.FILES[0])
-                if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
+                if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]:
                     import academictorrents as at
+
                     atpath = at.get(self.AT_HASH, datastore=self.root)
                     assert atpath == path
@@ -242,7 +247,7 @@ class ImageNetValidation(ImageNetBase):
                 tar.extractall(path=datadir)

             vspath = os.path.join(self.root, self.FILES[1])
-            if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
+            if not os.path.exists(vspath) or not os.path.getsize(vspath) == self.SIZES[1]:
                 download(self.VS_URL, vspath)

             with open(vspath, "r") as f:
@@ -261,18 +266,15 @@ class ImageNetValidation(ImageNetBase):
             filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
             filelist = [os.path.relpath(p, start=datadir) for p in filelist]
             filelist = sorted(filelist)
-            filelist = "\n".join(filelist)+"\n"
+            filelist = "\n".join(filelist) + "\n"
             with open(self.txt_filelist, "w") as f:
                 f.write(filelist)

             tdu.mark_prepared(self.root)


-
 class ImageNetSR(Dataset):
-    def __init__(self, size=None,
-                 degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
-                 random_crop=True):
+    def __init__(self, size=None, degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.0, random_crop=True):
         """
         Imagenet Superresolution Dataloader
         Performs following ops in order:
@@ -296,12 +298,12 @@ class ImageNetSR(Dataset):
         self.LR_size = int(size / downscale_f)
         self.min_crop_f = min_crop_f
         self.max_crop_f = max_crop_f
-        assert(max_crop_f <= 1.)
+        assert max_crop_f <= 1.0
        self.center_crop = not random_crop

         self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)

-        self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
+        self.pil_interpolation = False  # gets reset later if incase interp_op is from pillow

         if degradation == "bsrgan":
             self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
@@ -311,17 +313,17 @@ class ImageNetSR(Dataset):

         else:
             interpolation_fn = {
-            "cv_nearest": cv2.INTER_NEAREST,
-            "cv_bilinear": cv2.INTER_LINEAR,
-            "cv_bicubic": cv2.INTER_CUBIC,
-            "cv_area": cv2.INTER_AREA,
-            "cv_lanczos": cv2.INTER_LANCZOS4,
-            "pil_nearest": PIL.Image.NEAREST,
-            "pil_bilinear": PIL.Image.BILINEAR,
-            "pil_bicubic": PIL.Image.BICUBIC,
-            "pil_box": PIL.Image.BOX,
-            "pil_hamming": PIL.Image.HAMMING,
-            "pil_lanczos": PIL.Image.LANCZOS,
+                "cv_nearest": cv2.INTER_NEAREST,
+                "cv_bilinear": cv2.INTER_LINEAR,
+                "cv_bicubic": cv2.INTER_CUBIC,
+                "cv_area": cv2.INTER_AREA,
+                "cv_lanczos": cv2.INTER_LANCZOS4,
+                "pil_nearest": PIL.Image.NEAREST,
+                "pil_bilinear": PIL.Image.BILINEAR,
+                "pil_bicubic": PIL.Image.BICUBIC,
+                "pil_box": PIL.Image.BOX,
+                "pil_hamming": PIL.Image.HAMMING,
+                "pil_lanczos": PIL.Image.LANCZOS,
             }[degradation]

             self.pil_interpolation = degradation.startswith("pil_")
@@ -330,8 +332,9 @@ class ImageNetSR(Dataset):
             self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)

         else:
-            self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
-                                                                      interpolation=interpolation_fn)
+            self.degradation_process = albumentations.SmallestMaxSize(
+                max_size=self.LR_size, interpolation=interpolation_fn
+            )

     def __len__(self):
         return len(self.base)
@@ -366,8 +369,8 @@ class ImageNetSR(Dataset):
         else:
             LR_image = self.degradation_process(image=image)["image"]

-        example["image"] = (image/127.5 - 1.0).astype(np.float32)
-        example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
+        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
+        example["LR_image"] = (LR_image / 127.5 - 1.0).astype(np.float32)

         return example
@@ -379,7 +382,9 @@ class ImageNetSRTrain(ImageNetSR):
     def get_base(self):
         with open("data/imagenet_train_hr_indices.p", "rb") as f:
             indices = pickle.load(f)
-        dset = ImageNetTrain(process_images=False,)
+        dset = ImageNetTrain(
+            process_images=False,
+        )
         return Subset(dset, indices)
@@ -390,5 +395,7 @@ class ImageNetSRValidation(ImageNetSR):
     def get_base(self):
         with open("data/imagenet_val_hr_indices.p", "rb") as f:
             indices = pickle.load(f)
-        dset = ImageNetValidation(process_images=False,)
+        dset = ImageNetValidation(
+            process_images=False,
+        )
         return Subset(dset, indices)
@@ -1,47 +1,49 @@
 import os

 import numpy as np
 import PIL
 from PIL import Image
 from torch.utils.data import Dataset
 from torchvision import transforms


 # This class is used to create a dataset of images from LSUN dataset for training
 class LSUNBase(Dataset):
-    def __init__(self,
-                 txt_file,  # path to the text file containing the list of image paths
-                 data_root,  # root directory of the LSUN dataset
-                 size=None,  # the size of images to resize to
-                 interpolation="bicubic",  # interpolation method to be used while resizing
-                 flip_p=0.5  # probability of random horizontal flipping
-                 ):
-        self.data_paths = txt_file  # store path to text file containing list of images
-        self.data_root = data_root  # store path to root directory of the dataset
-        with open(self.data_paths, "r") as f:  # open and read the text file
-            self.image_paths = f.read().splitlines()  # read the lines of the file and store as list
-        self._length = len(self.image_paths)  # store the number of images
-
+    def __init__(
+        self,
+        txt_file,  # path to the text file containing the list of image paths
+        data_root,  # root directory of the LSUN dataset
+        size=None,  # the size of images to resize to
+        interpolation="bicubic",  # interpolation method to be used while resizing
+        flip_p=0.5,  # probability of random horizontal flipping
+    ):
+        self.data_paths = txt_file  # store path to text file containing list of images
+        self.data_root = data_root  # store path to root directory of the dataset
+        with open(self.data_paths, "r") as f:  # open and read the text file
+            self.image_paths = f.read().splitlines()  # read the lines of the file and store as list
+        self._length = len(self.image_paths)  # store the number of images

         # create dictionary to hold image path information
         self.labels = {
             "relative_file_path_": [l for l in self.image_paths],
-            "file_path_": [os.path.join(self.data_root, l)
-                           for l in self.image_paths],
+            "file_path_": [os.path.join(self.data_root, l) for l in self.image_paths],
         }

         # set the image size to be resized
         self.size = size
         # set the interpolation method for resizing the image
-        self.interpolation = {"linear": PIL.Image.LINEAR,
-                              "bilinear": PIL.Image.BILINEAR,
-                              "bicubic": PIL.Image.BICUBIC,
-                              "lanczos": PIL.Image.LANCZOS,
-                              }[interpolation]
+        self.interpolation = {
+            "linear": PIL.Image.LINEAR,
+            "bilinear": PIL.Image.BILINEAR,
+            "bicubic": PIL.Image.BICUBIC,
+            "lanczos": PIL.Image.LANCZOS,
+        }[interpolation]
         # randomly flip the image horizontally with a given probability
         self.flip = transforms.RandomHorizontalFlip(p=flip_p)

     def __len__(self):
         # return the length of dataset
         return self._length

     def __getitem__(self, i):
         # get the image path for the given index
@@ -52,59 +54,71 @@ class LSUNBase(Dataset):
         image = image.convert("RGB")

         # default to score-sde preprocessing
-        img = np.array(image).astype(np.uint8) # convert image to numpy array
-        crop = min(img.shape[0], img.shape[1]) # crop the image to a square shape
-        h, w, = img.shape[0], img.shape[1] # get the height and width of image
-        img = img[(h - crop) // 2:(h + crop) // 2,
-                  (w - crop) // 2:(w + crop) // 2]  # crop the image to a square shape
-
-        image = Image.fromarray(img) # create an image from numpy array
-        if self.size is not None: # if image size is provided, resize the image
+        img = np.array(image).astype(np.uint8)  # convert image to numpy array
+        crop = min(img.shape[0], img.shape[1])  # crop the image to a square shape
+        (
+            h,
+            w,
+        ) = (
+            img.shape[0],
+            img.shape[1],
+        )  # get the height and width of image
+        img = img[
+            (h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2
+        ]  # crop the image to a square shape
+
+        image = Image.fromarray(img)  # create an image from numpy array
+        if self.size is not None:  # if image size is provided, resize the image
             image = image.resize((self.size, self.size), resample=self.interpolation)

-        image = self.flip(image) # flip the image horizontally with the given probability
-        image = np.array(image).astype(np.uint8)
+        image = self.flip(image)  # flip the image horizontally with the given probability
+        image = np.array(image).astype(np.uint8)
         example["image"] = (image / 127.5 - 1.0).astype(np.float32)  # normalize the image values and convert to float32
-        return example # return the example dictionary containing the image and its file paths
+        return example  # return the example dictionary containing the image and its file paths


-#A dataset class for LSUN Churches training set.
-# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments.
+# A dataset class for LSUN Churches training set.
+# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments.
 # The text file containing the paths to the images and the root directory where the images are stored are passed as arguments. Any additional keyword arguments passed to this class will be forwarded to the constructor of the parent class.
 class LSUNChurchesTrain(LSUNBase):
     def __init__(self, **kwargs):
         super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)


-#A dataset class for LSUN Churches validation set.
+# A dataset class for LSUN Churches validation set.
 # It is similar to LSUNChurchesTrain except that it uses a different text file and sets the flip probability to zero by default.
 class LSUNChurchesValidation(LSUNBase):
-    def __init__(self, flip_p=0., **kwargs):
-        super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
-                         flip_p=flip_p, **kwargs)
+    def __init__(self, flip_p=0.0, **kwargs):
+        super().__init__(
+            txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", flip_p=flip_p, **kwargs
+        )


 # A dataset class for LSUN Bedrooms training set.
 # It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments.
 class LSUNBedroomsTrain(LSUNBase):
     def __init__(self, **kwargs):
         super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)


 # A dataset class for LSUN Bedrooms validation set.
 # It is similar to LSUNBedroomsTrain except that it uses a different text file and sets the flip probability to zero by default.
 class LSUNBedroomsValidation(LSUNBase):
     def __init__(self, flip_p=0.0, **kwargs):
-        super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
-                         flip_p=flip_p, **kwargs)
+        super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", flip_p=flip_p, **kwargs)


 # A dataset class for LSUN Cats training set.
 # It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments.
 # The text file containing the paths to the images and the root directory where the images are stored are passed as arguments.
 class LSUNCatsTrain(LSUNBase):
     def __init__(self, **kwargs):
         super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)


 # A dataset class for LSUN Cats validation set.
 # It is similar to LSUNCatsTrain except that it uses a different text file and sets the flip probability to zero by default.
 class LSUNCatsValidation(LSUNBase):
-    def __init__(self, flip_p=0., **kwargs):
-        super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
-                         flip_p=flip_p, **kwargs)
+    def __init__(self, flip_p=0.0, **kwargs):
+        super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", flip_p=flip_p, **kwargs)
@@ -1,15 +1,16 @@
-from typing import Dict
-import numpy as np
-from omegaconf import DictConfig, ListConfig
-import torch
-from torch.utils.data import Dataset
-from pathlib import Path
-import json
-from PIL import Image
-from torchvision import transforms
+import json
+from pathlib import Path
+from typing import Dict
+
+import numpy as np
+import torch
+from datasets import load_dataset
+from einops import rearrange
+from ldm.util import instantiate_from_config
+from omegaconf import DictConfig, ListConfig
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms


 def make_multi_folder_data(paths, caption_files=None, **kwargs):
     """Make a concat dataset from multiple folders
@@ -19,10 +20,9 @@ def make_multi_folder_data(paths, caption_files=None, **kwargs):
     """
     list_of_paths = []
     if isinstance(paths, (Dict, DictConfig)):
-        assert caption_files is None, \
-            "Caption files not yet supported for repeats"
+        assert caption_files is None, "Caption files not yet supported for repeats"
         for folder_path, repeats in paths.items():
-            list_of_paths.extend([folder_path]*repeats)
+            list_of_paths.extend([folder_path] * repeats)
         paths = list_of_paths

     if caption_files is not None:
@@ -31,8 +31,10 @@ def make_multi_folder_data(paths, caption_files=None, **kwargs):
     datasets = [FolderData(p, **kwargs) for p in paths]
     return torch.utils.data.ConcatDataset(datasets)

+
 class FolderData(Dataset):
-    def __init__(self,
+    def __init__(
+        self,
         root_dir,
         caption_file=None,
         image_transforms=[],
@@ -40,7 +42,7 @@ class FolderData(Dataset):
         default_caption="",
         postprocess=None,
         return_paths=False,
-        ) -> None:
+    ) -> None:
         """Create a dataset from a folder of images.

         If you pass in a root directory it will be searched for images
         ending in ext (ext can be a list)
@@ -75,12 +77,12 @@ class FolderData(Dataset):
             self.paths.extend(list(self.root_dir.rglob(f"*.{e}")))
         if isinstance(image_transforms, ListConfig):
             image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
-        image_transforms.extend([transforms.ToTensor(),
-                                 transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
+        image_transforms.extend(
+            [transforms.ToTensor(), transforms.Lambda(lambda x: rearrange(x * 2.0 - 1.0, "c h w -> h w c"))]
+        )
         image_transforms = transforms.Compose(image_transforms)
         self.tform = image_transforms

     def __len__(self):
         if self.captions is not None:
             return len(self.captions.keys())
@@ -94,7 +96,7 @@ class FolderData(Dataset):
             caption = self.captions.get(chosen, None)
             if caption is None:
                 caption = self.default_caption
-            filename = self.root_dir/chosen
+            filename = self.root_dir / chosen
         else:
             filename = self.paths[index]
@@ -119,23 +121,26 @@ class FolderData(Dataset):
             im = im.convert("RGB")
         return self.tform(im)


 def hf_dataset(
-    path = "Fazzie/Teyvat",
+    path="Fazzie/Teyvat",
     image_transforms=[],
     image_column="image",
     text_column="text",
-    image_key='image',
-    caption_key='txt',
-):
-    """Make huggingface dataset with appropriate list of transforms applied
-    """
+    image_key="image",
+    caption_key="txt",
+):
+    """Make huggingface dataset with appropriate list of transforms applied"""
     ds = load_dataset(path, name="train")
     ds = ds["train"]
     image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
-    image_transforms.extend([transforms.Resize((256, 256)),
-                             transforms.ToTensor(),
-                             transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]
-    )
+    image_transforms.extend(
+        [
+            transforms.Resize((256, 256)),
+            transforms.ToTensor(),
+            transforms.Lambda(lambda x: rearrange(x * 2.0 - 1.0, "c h w -> h w c")),
+        ]
+    )
     tform = transforms.Compose(image_transforms)

     assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}"
@@ -149,4 +154,4 @@ def hf_dataset(
         return processed

     ds.set_transform(pre_process)
-    return ds
+    return ds