Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 11:32:10 +00:00)
[feat] Dist Loader for Eval (#5950)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix tp error
* remove unused parameters
* remove unused
* update inference
* update docs
* update inference

---------

Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -1,6 +1,9 @@
 from abc import abstractstaticmethod
 
 from colossal_eval.utils import jdump
+from torch.utils.data import Dataset
+
+from colossalai.logging import DistributedLogger
 
 
 class BaseDataset:
@@ -12,13 +15,24 @@ class BaseDataset:
         logger: Logger for the dataset.
     """
 
-    def __init__(self, path, logger, few_shot, forward_only=False, load_train=False, load_reference=False):
-        self.dataset = self.load(path, logger, few_shot, forward_only, load_train, load_reference)
+    def __init__(self, path, logger, *args, **kwargs):
+        self.dataset = self.load(path, logger, *args, **kwargs)
 
     def save(self, save_path):
         """Save the converted dataset"""
         jdump(self.dataset, save_path)
 
     @abstractstaticmethod
-    def load(path, logger):
+    def load(path, logger: DistributedLogger, *args, **kwargs):
         """Load the original dataset and convert it into the inference dataset"""
+
+
+class DistributedDataset(Dataset):
+    def __init__(self, data):
+        self.data = data
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        return self.data[idx]
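With `__init__` and `load` now taking `*args, **kwargs`, concrete dataset wrappers forward their own options (few-shot mode, which splits to load, and so on) straight through to their `load` implementation instead of threading them through a fixed signature. A minimal sketch of a hypothetical subclass, assuming `BaseDataset` is importable from `colossal_eval.dataset.base`; the class name and the returned dict layout are illustrative only, not part of this commit:

    # Hypothetical subclass, for illustration only.
    from colossal_eval.dataset.base import BaseDataset


    class MyEvalDataset(BaseDataset):
        @staticmethod
        def load(path, logger, few_shot=False, *args, **kwargs):
            if logger is not None:
                logger.info(f"Loading {path} (few_shot={few_shot})")
            # Convert the raw file at `path` into the inference format here.
            return {"test": {"default": {"data": []}}}


    # Extra arguments are forwarded unchanged from __init__ to load():
    dataset = MyEvalDataset("data.json", logger=None, few_shot=True)
    dataset.save("converted.json")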
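The new `DistributedDataset` is a thin `torch.utils.data.Dataset` wrapper around an in-memory list of samples, which is what allows the evaluation pipeline to hand each rank only its shard through an ordinary `DataLoader`. A minimal sketch of that pattern, assuming a torch.distributed process group is already initialized; the class body is copied from the diff above, everything else is illustrative:

    import torch.distributed as dist
    from torch.utils.data import DataLoader, Dataset
    from torch.utils.data.distributed import DistributedSampler


    class DistributedDataset(Dataset):
        # Same shape as the class added in this commit.
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            return self.data[idx]


    samples = [{"id": i, "instruction": f"question {i}"} for i in range(100)]
    dataset = DistributedDataset(samples)

    # shuffle=False keeps evaluation order deterministic within each shard.
    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank(),
        shuffle=False,
    )
    # collate_fn keeps each batch as a plain list of sample dicts.
    loader = DataLoader(dataset, batch_size=4, sampler=sampler, collate_fn=lambda batch: batch)

    for batch in loader:
        # Each rank iterates over a disjoint (possibly padded) subset of `samples`.
        ...

Note that `DistributedSampler` pads the last shard by repeating samples when the dataset size is not divisible by the world size, so duplicated results may need to be filtered out before metrics are aggregated.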