refactor evaluation

This commit is contained in:
YeAnbang
2024-07-22 05:57:39 +00:00
parent c5f582f666
commit 12fe8b5858
22 changed files with 309 additions and 1388 deletions

View File

@@ -1,6 +1,5 @@
from typing import Callable
import torch
from torch.utils.data import Dataset
@@ -18,7 +17,7 @@ class DummyLLMDataset(Dataset):
if key in self.gen_fn:
data[key] = self.gen_fn[key]
else:
data[key] = torch.ones(self.seq_len, dtype=torch.long)
data[key] = [1] * self.seq_len
return data
def __len__(self):