[feat] Dist Loader for Eval (#5950)

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tp error

* remove unused parameters

* remove unused

* update inference

* update docs

* update inference

---------

Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Tong Li
2024-08-02 10:06:25 +08:00
committed by GitHub
parent 62cdac6b7b
commit 19d1510ea2
15 changed files with 93 additions and 77 deletions

View File

@@ -197,9 +197,7 @@ class AGIEvalDataset(BaseDataset):
"""
@staticmethod
def load(
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
) -> List[Dict]:
def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
dataset = {"test": {}}
files = glob.glob(os.path.join(path, "*.jsonl"))

View File

@@ -1,6 +1,9 @@
from abc import abstractstaticmethod
from colossal_eval.utils import jdump
from torch.utils.data import Dataset
from colossalai.logging import DistributedLogger
class BaseDataset:
@@ -12,13 +15,24 @@ class BaseDataset:
logger: Logger for the dataset.
"""
def __init__(self, path, logger, few_shot, forward_only=False, load_train=False, load_reference=False):
self.dataset = self.load(path, logger, few_shot, forward_only, load_train, load_reference)
def __init__(self, path, logger, *args, **kwargs):
self.dataset = self.load(path, logger, *args, **kwargs)
def save(self, save_path):
"""Save the converted dataset"""
jdump(self.dataset, save_path)
@abstractstaticmethod
def load(path, logger):
def load(path, logger: DistributedLogger, *args, **kwargs):
"""Load the original dataset and convert it into the inference dataset"""
class DistributedDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]

View File

@@ -90,9 +90,7 @@ class CEvalDataset(BaseDataset):
"""
@staticmethod
def load(
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
) -> List[Dict]:
def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
dataset = {"dev": {}, "test": {}}
for split in ["dev", "test"]:
files = os.listdir(os.path.join(path, split))

View File

@@ -101,9 +101,7 @@ class CMMLUDataset(BaseDataset):
"""
@staticmethod
def load(
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
) -> List[Dict]:
def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
dataset = {"dev": {}, "test": {}}
for split in ["dev", "test"]:
files = os.listdir(os.path.join(path, split))

View File

@@ -37,7 +37,7 @@ class ColossalDataset(BaseDataset):
"""
@staticmethod
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
dataset = {"test": {}}
data = jload(path)
data_per_category = get_data_per_category(data)

View File

@@ -28,7 +28,7 @@ class CValuesDataset(BaseDataset):
"""
@staticmethod
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
dataset = {"test": {}}
file_path = os.path.join(path, "cvalues_responsibility_mc.jsonl")
data_list = []

View File

@@ -69,9 +69,7 @@ class GaoKaoBenchDataset(BaseDataset):
"""
@staticmethod
def load(
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
) -> List[Dict]:
def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
dataset = {"test": {}}
for category in ["Fill-in-the-blank_Questions", "Multiple-choice_Questions", "Open-ended_Questions"]:
files = os.listdir(os.path.join(path, "data", category))

View File

@@ -77,7 +77,7 @@ class LongBenchDataset(BaseDataset):
"""
@staticmethod
def load(path: str, logger: DistributedLogger) -> List[Dict]:
def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
dataset = {"test": {}}
files = os.listdir(path)

View File

@@ -31,9 +31,7 @@ class MMLUDataset(BaseDataset):
"""
@staticmethod
def load(
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
) -> List[Dict]:
def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
dataset = {"dev": {}, "test": {}}
for split in ["dev", "test"]:
files = os.listdir(os.path.join(path, split))

View File

@@ -27,12 +27,12 @@ class MTBenchDataset(BaseDataset):
This dataset class will convert the original dataset into the inference dataset.
"""
def __init__(self, path, logger, few_shot):
def __init__(self, path, logger: DistributedLogger, *args, **kwargs):
self.multiturn = True
self.dataset = self.load(path, logger, few_shot)
self.dataset = self.load(path, logger, *args, **kwargs)
@staticmethod
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
dataset = {"test": defaultdict(dict)}
file_path = os.path.join(path, "question.jsonl")

View File

@@ -130,7 +130,7 @@ class SafetyBenchENDataset(BaseDataset):
"""
@staticmethod
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
dataset = {"dev": {}, "test": {}}
data_files = [os.path.join(path, file_name) for file_name in FILES]
for file_path in data_files:

View File

@@ -130,7 +130,7 @@ class SafetyBenchZHDataset(BaseDataset):
"""
@staticmethod
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
dataset = {"dev": {}, "test": {}}
data_files = [os.path.join(path, file_name) for file_name in FILES]
for file_path in data_files: