mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[example] update ViT example using booster api (#3940)
This commit is contained in:
32
examples/images/vit/data.py
Normal file
32
examples/images/vit/data.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from datasets import load_dataset
|
||||
|
||||
class BeansDataset(Dataset):
|
||||
|
||||
def __init__(self, image_processor, split='train'):
|
||||
|
||||
super().__init__()
|
||||
self.image_processor = image_processor
|
||||
self.ds = load_dataset('beans')[split]
|
||||
self.label_names = self.ds.features['labels'].names
|
||||
self.num_labels = len(self.label_names)
|
||||
self.inputs = []
|
||||
for example in self.ds:
|
||||
self.inputs.append(self.process_example(example))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.inputs)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.inputs[idx]
|
||||
|
||||
def process_example(self, example):
|
||||
input = self.image_processor(example['image'], return_tensors='pt')
|
||||
input['labels'] = example['labels']
|
||||
return input
|
||||
|
||||
|
||||
def beans_collator(batch):
|
||||
return {'pixel_values': torch.cat([data['pixel_values'] for data in batch], dim=0),
|
||||
'labels': torch.tensor([data['labels'] for data in batch], dtype=torch.int64)}
|
Reference in New Issue
Block a user