mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 03:20:52 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
This commit is contained in:
@@ -4,13 +4,11 @@ from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class BeansDataset(Dataset):
|
||||
|
||||
def __init__(self, image_processor, tp_size=1, split='train'):
|
||||
|
||||
def __init__(self, image_processor, tp_size=1, split="train"):
|
||||
super().__init__()
|
||||
self.image_processor = image_processor
|
||||
self.ds = load_dataset('beans')[split]
|
||||
self.label_names = self.ds.features['labels'].names
|
||||
self.ds = load_dataset("beans")[split]
|
||||
self.label_names = self.ds.features["labels"].names
|
||||
while len(self.label_names) % tp_size != 0:
|
||||
# ensure that the number of labels is multiple of tp_size
|
||||
self.label_names.append(f"pad_label_{len(self.label_names)}")
|
||||
@@ -26,13 +24,13 @@ class BeansDataset(Dataset):
|
||||
return self.inputs[idx]
|
||||
|
||||
def process_example(self, example):
    """Run the image processor on one example and attach its label.

    Args:
        example: Mapping with an ``"image"`` entry and a ``"labels"`` entry.

    Returns:
        The object returned by ``self.image_processor`` (called with
        ``return_tensors="pt"``), with its ``"labels"`` item set to
        ``example["labels"]``.
    """
    # Named ``encoded`` rather than ``input`` so the builtin is not shadowed.
    encoded = self.image_processor(example["image"], return_tensors="pt")
    encoded["labels"] = example["labels"]
    return encoded
|
||||
|
||||
|
||||
def beans_collator(batch):
    """Collate a list of processed examples into a single batch dictionary.

    Args:
        batch: Sequence of mappings, each holding a ``"pixel_values"`` tensor
            and a ``"labels"`` value.

    Returns:
        A dict with two entries:
            ``"pixel_values"``: the per-example tensors concatenated along
            dimension 0.
            ``"labels"``: an ``int64`` tensor built from the per-example labels.
    """
    return {
        # ``cat`` joins along the existing dim 0 instead of adding a new one;
        # presumably each example already has a leading batch dim of size 1
        # from the image processor -- TODO confirm against the processor used.
        "pixel_values": torch.cat([data["pixel_values"] for data in batch], dim=0),
        "labels": torch.tensor([data["labels"] for data in batch], dtype=torch.int64),
    }
|
||||
|
Reference in New Issue
Block a user