add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint

This commit is contained in:
YeAnbang
2024-07-10 10:17:08 +00:00
parent 16f3451fe2
commit d888c3787c
13 changed files with 1175 additions and 26 deletions

View File

@@ -0,0 +1,21 @@
import torch
from torch.utils.data import Dataset, DataLoader
class DummyLLMDataset(Dataset):
def __init__(self, keys, seq_len, size=500):
self.keys = keys
self.seq_len = seq_len
self.data = self._generate_data()
self.size = size
def _generate_data(self):
data = {}
for key in self.keys:
data[key] = torch.ones(self.seq_len, dtype = torch.long)
return data
def __len__(self):
return self.size
def __getitem__(self, idx):
return {key: self.data[key] for key in self.keys}