mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 21:09:18 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,17 +1,15 @@
|
||||
import math
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import ChainDataset, ConcatDataset, Dataset, IterableDataset
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
|
||||
class Txt2ImgIterableBaseDataset(IterableDataset):
|
||||
'''
|
||||
"""
|
||||
Define an interface to make the IterableDatasets for text2img data chainable
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str, rank, world_size):
|
||||
super().__init__()
|
||||
@@ -20,8 +18,8 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
|
||||
self.file_list = []
|
||||
self.txt_list = []
|
||||
self.info = self._get_file_info(file_path)
|
||||
self.start = self.info['start']
|
||||
self.end = self.info['end']
|
||||
self.start = self.info["start"]
|
||||
self.end = self.info["end"]
|
||||
self.rank = rank
|
||||
|
||||
self.world_size = world_size
|
||||
@@ -33,7 +31,7 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
|
||||
self.num_records = self.end - self.start
|
||||
self.valid_ids = [i for i in range(self.end)]
|
||||
|
||||
print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
|
||||
print(f"{self.__class__.__name__} dataset contains {self.__len__()} examples.")
|
||||
|
||||
def __len__(self):
|
||||
# return self.iter_end - self.iter_start
|
||||
@@ -48,7 +46,7 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
|
||||
for idx in range(start, end):
|
||||
file_name = self.file_list[idx]
|
||||
txt_name = self.txt_list[idx]
|
||||
f_ = open(txt_name, 'r')
|
||||
f_ = open(txt_name, "r")
|
||||
txt_ = f_.read()
|
||||
f_.close()
|
||||
image = cv2.imdecode(np.fromfile(file_name, dtype=np.uint8), 1)
|
||||
@@ -57,18 +55,17 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
|
||||
yield {"txt": txt_, "image": image}
|
||||
|
||||
def _get_file_info(self, file_path):
|
||||
info = \
|
||||
{
|
||||
info = {
|
||||
"start": 1,
|
||||
"end": 0,
|
||||
}
|
||||
self.folder_list = [file_path + i for i in os.listdir(file_path) if '.' not in i]
|
||||
self.folder_list = [file_path + i for i in os.listdir(file_path) if "." not in i]
|
||||
for folder in self.folder_list:
|
||||
files = [folder + '/' + i for i in os.listdir(folder) if 'jpg' in i]
|
||||
txts = [k.replace('jpg', 'txt') for k in files]
|
||||
files = [folder + "/" + i for i in os.listdir(folder) if "jpg" in i]
|
||||
txts = [k.replace("jpg", "txt") for k in files]
|
||||
self.file_list.extend(files)
|
||||
self.txt_list.extend(txts)
|
||||
info['end'] = len(self.file_list)
|
||||
info["end"] = len(self.file_list)
|
||||
# with open(file_path, 'r') as fin:
|
||||
# for _ in enumerate(fin):
|
||||
# info['end'] += 1
|
||||
|
Reference in New Issue
Block a user