mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-29 21:49:54 +00:00
[example] support Dreamblooth (#2188)
This commit is contained in:
@@ -1,16 +1,18 @@
|
||||
import math
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import ChainDataset, ConcatDataset, Dataset, IterableDataset
|
||||
|
||||
|
||||
class Txt2ImgIterableBaseDataset(IterableDataset):
|
||||
'''
|
||||
Define an interface to make the IterableDatasets for text2img data chainable
|
||||
'''
|
||||
|
||||
def __init__(self, file_path: str, rank, world_size):
|
||||
super().__init__()
|
||||
self.file_path = file_path
|
||||
@@ -52,8 +54,7 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
|
||||
image = cv2.imdecode(np.fromfile(file_name, dtype=np.uint8), 1)
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
image = torch.from_numpy(image) / 255
|
||||
yield {"caption": txt_, "image":image}
|
||||
|
||||
yield {"txt": txt_, "image": image}
|
||||
|
||||
def _get_file_info(self, file_path):
|
||||
info = \
|
||||
@@ -72,4 +73,4 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
|
||||
# for _ in enumerate(fin):
|
||||
# info['end'] += 1
|
||||
# self.txt_list = [k.replace('jpg', 'txt') for k in self.file_list]
|
||||
return info
|
||||
return info
|
||||
|
||||
Reference in New Issue
Block a user