mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 18:40:28 +00:00
add kto
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class DummyLLMDataset(Dataset):
    """Synthetic dataset that serves the same dictionary of fields for every index.

    Each name in ``keys`` maps either to a fixed ``torch.long`` tensor of ones
    with ``seq_len`` elements or, when the name has an entry in ``gen_fn``, to
    that entry. Callable entries are invoked with the sample index on every
    lookup, so they can produce index-dependent values.
    """

    def __init__(self, keys, seq_len, size=500, gen_fn=None):
        """Build the per-key data once and remember the reported dataset size.

        Args:
            keys: Field names that every returned sample contains. It is
                iterated on construction and again on every lookup, so it
                must be re-iterable (e.g. a list or tuple).
            seq_len: Number of elements in the default tensor of ones.
            size: Value reported by ``__len__``.
            gen_fn: Optional mapping from field name to a replacement for the
                default tensor. A callable value is called as ``value(idx)``
                in ``__getitem__``; any other value is returned unchanged.
        """
        self.keys = keys
        # A `{}` default would be one dict shared by every instance created
        # without `gen_fn`; use None as the sentinel and allocate per instance.
        self.gen_fn = {} if gen_fn is None else gen_fn
        self.seq_len = seq_len
        self.data = self._generate_data()
        self.size = size

    def _generate_data(self):
        """Return a dict mapping each key to its ``gen_fn`` entry or a ones tensor."""
        data = {}
        for key in self.keys:
            if key in self.gen_fn:
                # Stored as-is; callables are resolved lazily in __getitem__.
                data[key] = self.gen_fn[key]
            else:
                data[key] = torch.ones(self.seq_len, dtype=torch.long)
        return data

    def __len__(self):
        """Return the configured dataset size."""
        return self.size

    def __getitem__(self, idx):
        """Return one sample dict; callable entries are evaluated with ``idx``."""
        sample = {}
        for key in self.keys:
            value = self.data[key]
            # Builtin callable() replaces the isinstance(..., typing.Callable) check.
            sample[key] = value(idx) if callable(value) else value
        return sample
|
||||
|
Reference in New Issue
Block a user