mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-28 16:28:10 +00:00
* refactor latest code * update api * add dummy dataset * update Readme * add setup * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update files * add PP support * update arguments * update argument * reorg folder * update version * remove IB infor * update utils * update readme * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update save for zero * update save * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add apex * update --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
25 lines
781 B
Python
25 lines
781 B
Python
import torch
|
|
from torch.utils.data import Dataset
|
|
|
|
from colossalai.accelerator import get_accelerator
|
|
|
|
|
|
class RandomDataset(Dataset):
    """Map-style dataset of pre-generated random token sequences.

    All samples are drawn once in the constructor and kept as a single
    ``(num_samples, max_length)`` integer tensor, so indexing only slices
    existing data and never generates new values.

    Args:
        num_samples: Number of sequences to generate; also the value
            reported by ``len()``.
        max_length: Number of tokens in every sequence.
        vocab_size: Exclusive upper bound for the sampled token ids, i.e.
            ids are drawn from ``[0, vocab_size)``.
    """

    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
        self.num_samples = num_samples
        self.max_length = max_length
        # Sample the whole dataset up front, directly on the device reported
        # by the active accelerator, so no per-item generation is needed.
        self.input_ids = torch.randint(
            0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device()
        )
        # No padding is ever produced, so every position is marked as attended.
        self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self) -> int:
        """Return the number of samples requested at construction time."""
        return self.num_samples

    def __getitem__(self, idx) -> dict:
        """Return the sample at ``idx`` as a dict of tensors.

        The dict has the keys ``"input_ids"``, ``"attention_mask"`` and
        ``"labels"``; ``"labels"`` holds the same tokens as ``"input_ids"``.
        """
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.input_ids[idx],
        }