mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 21:22:49 +00:00
[feat] Dist Loader for Eval (#5950)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci
* fix tp error
* remove unused parameters
* remove unused
* update inference
* update docs
* update inference

---------

Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -197,9 +197,7 @@ class AGIEvalDataset(BaseDataset):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(
|
||||
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
|
||||
) -> List[Dict]:
|
||||
def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
|
||||
dataset = {"test": {}}
|
||||
|
||||
files = glob.glob(os.path.join(path, "*.jsonl"))
|
||||
|
@@ -1,6 +1,9 @@
|
||||
from abc import abstractstaticmethod
|
||||
|
||||
from colossal_eval.utils import jdump
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from colossalai.logging import DistributedLogger
|
||||
|
||||
|
||||
class BaseDataset:
|
||||
@@ -12,13 +15,24 @@ class BaseDataset:
|
||||
logger: Logger for the dataset.
|
||||
"""
|
||||
|
||||
def __init__(self, path, logger, few_shot, forward_only=False, load_train=False, load_reference=False):
|
||||
self.dataset = self.load(path, logger, few_shot, forward_only, load_train, load_reference)
|
||||
def __init__(self, path, logger, *args, **kwargs):
|
||||
self.dataset = self.load(path, logger, *args, **kwargs)
|
||||
|
||||
def save(self, save_path):
    """Write the converted dataset held in ``self.dataset`` to ``save_path``.

    Args:
        save_path: Destination handed straight to ``jdump``.
    """
    # Serialization is delegated entirely to the project helper `jdump`.
    jdump(self.dataset, save_path)
|
||||
|
||||
@abstractstaticmethod
|
||||
def load(path, logger):
|
||||
def load(path, logger: DistributedLogger, *args, **kwargs):
|
||||
"""Load the original dataset and convert it into the inference dataset"""
|
||||
|
||||
|
||||
class DistributedDataset(Dataset):
    """Thin ``torch.utils.data.Dataset`` wrapper around an indexable container.

    The wrapped object is stored unchanged on ``self.data``; length and item
    lookups are forwarded to it directly.
    """

    def __init__(self, data):
        """Keep a reference to ``data`` (no copy is made)."""
        self.data = data

    def __len__(self):
        """Return the number of items in the wrapped container."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the item stored at position/key ``idx`` of the wrapped container."""
        return self.data[idx]
|
||||
|
@@ -90,9 +90,7 @@ class CEvalDataset(BaseDataset):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(
|
||||
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
|
||||
) -> List[Dict]:
|
||||
def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
|
||||
dataset = {"dev": {}, "test": {}}
|
||||
for split in ["dev", "test"]:
|
||||
files = os.listdir(os.path.join(path, split))
|
||||
|
@@ -101,9 +101,7 @@ class CMMLUDataset(BaseDataset):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(
|
||||
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
|
||||
) -> List[Dict]:
|
||||
def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
|
||||
dataset = {"dev": {}, "test": {}}
|
||||
for split in ["dev", "test"]:
|
||||
files = os.listdir(os.path.join(path, split))
|
||||
|
@@ -37,7 +37,7 @@ class ColossalDataset(BaseDataset):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
|
||||
def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
|
||||
dataset = {"test": {}}
|
||||
data = jload(path)
|
||||
data_per_category = get_data_per_category(data)
|
||||
|
@@ -28,7 +28,7 @@ class CValuesDataset(BaseDataset):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
|
||||
def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
|
||||
dataset = {"test": {}}
|
||||
file_path = os.path.join(path, "cvalues_responsibility_mc.jsonl")
|
||||
data_list = []
|
||||
|
@@ -69,9 +69,7 @@ class GaoKaoBenchDataset(BaseDataset):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(
|
||||
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
|
||||
) -> List[Dict]:
|
||||
def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
|
||||
dataset = {"test": {}}
|
||||
for category in ["Fill-in-the-blank_Questions", "Multiple-choice_Questions", "Open-ended_Questions"]:
|
||||
files = os.listdir(os.path.join(path, "data", category))
|
||||
|
@@ -77,7 +77,7 @@ class LongBenchDataset(BaseDataset):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(path: str, logger: DistributedLogger) -> List[Dict]:
|
||||
def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
|
||||
dataset = {"test": {}}
|
||||
|
||||
files = os.listdir(path)
|
||||
|
@@ -31,9 +31,7 @@ class MMLUDataset(BaseDataset):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(
|
||||
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
|
||||
) -> List[Dict]:
|
||||
def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
|
||||
dataset = {"dev": {}, "test": {}}
|
||||
for split in ["dev", "test"]:
|
||||
files = os.listdir(os.path.join(path, split))
|
||||
|
@@ -27,12 +27,12 @@ class MTBenchDataset(BaseDataset):
|
||||
This dataset class will convert the original dataset into the inference dataset.
|
||||
"""
|
||||
|
||||
def __init__(self, path, logger, few_shot):
|
||||
def __init__(self, path, logger: DistributedLogger, *args, **kwargs):
|
||||
self.multiturn = True
|
||||
self.dataset = self.load(path, logger, few_shot)
|
||||
self.dataset = self.load(path, logger, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
|
||||
def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
|
||||
dataset = {"test": defaultdict(dict)}
|
||||
|
||||
file_path = os.path.join(path, "question.jsonl")
|
||||
|
@@ -130,7 +130,7 @@ class SafetyBenchENDataset(BaseDataset):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
|
||||
def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
|
||||
dataset = {"dev": {}, "test": {}}
|
||||
data_files = [os.path.join(path, file_name) for file_name in FILES]
|
||||
for file_path in data_files:
|
||||
|
@@ -130,7 +130,7 @@ class SafetyBenchZHDataset(BaseDataset):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
|
||||
def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
|
||||
dataset = {"dev": {}, "test": {}}
|
||||
data_files = [os.path.join(path, file_name) for file_name in FILES]
|
||||
for file_path in data_files:
|
||||
|
Reference in New Issue
Block a user