[example] support Dreamblooth (#2188)

This commit is contained in:
Fazzie-Maqianli
2022-12-23 16:47:30 +08:00
committed by GitHub
parent 1cf6d92d7c
commit ce3c4eca7b
11 changed files with 2399 additions and 8 deletions

View File

@@ -1,16 +1,18 @@
import math
import os
from abc import abstractmethod
import torch
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
import os
import numpy as np
import cv2
import numpy as np
import torch
from torch.utils.data import ChainDataset, ConcatDataset, Dataset, IterableDataset
class Txt2ImgIterableBaseDataset(IterableDataset):
'''
Define an interface to make the IterableDatasets for text2img data chainable
'''
def __init__(self, file_path: str, rank, world_size):
super().__init__()
self.file_path = file_path
@@ -52,8 +54,7 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
image = cv2.imdecode(np.fromfile(file_name, dtype=np.uint8), 1)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = torch.from_numpy(image) / 255
yield {"caption": txt_, "image":image}
yield {"txt": txt_, "image": image}
def _get_file_info(self, file_path):
info = \
@@ -72,4 +73,4 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
# for _ in enumerate(fin):
# info['end'] += 1
# self.txt_list = [k.replace('jpg', 'txt') for k in self.file_list]
return info
return info