This commit is contained in:
YeAnbang
2024-07-18 07:54:11 +00:00
parent b3594d4d68
commit 09d5ffca1a
27 changed files with 1739 additions and 63 deletions

View File

@@ -1,10 +1,13 @@
from typing import Callable
import torch
from torch.utils.data import Dataset
class DummyLLMDataset(Dataset):
def __init__(self, keys, seq_len, size=500):
def __init__(self, keys, seq_len, size=500, gen_fn={}):
self.keys = keys
self.gen_fn = gen_fn
self.seq_len = seq_len
self.data = self._generate_data()
self.size = size
@@ -12,11 +15,17 @@ class DummyLLMDataset(Dataset):
def _generate_data(self):
data = {}
for key in self.keys:
data[key] = torch.ones(self.seq_len, dtype=torch.long)
if key in self.gen_fn:
data[key] = self.gen_fn[key]
else:
data[key] = torch.ones(self.seq_len, dtype=torch.long)
return data
def __len__(self):
return self.size
def __getitem__(self, idx):
return {key: self.data[key] for key in self.keys}
return {
key: self.data[key] if not isinstance(self.data[key], Callable) else self.data[key](idx)
for key in self.keys
}