mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-29 12:15:39 +00:00
[tutorial] added missing dummy dataloader (#1944)
This commit is contained in:
parent
c6ea65011f
commit
de56b563b9
2
examples/tutorial/.gitignore
vendored
2
examples/tutorial/.gitignore
vendored
@ -1 +1 @@
|
|||||||
data/
|
./data/
|
||||||
|
39
examples/tutorial/sequence_parallel/data/dummy_dataloader.py
Normal file
39
examples/tutorial/sequence_parallel/data/dummy_dataloader.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class DummyDataloader():
|
||||||
|
|
||||||
|
def __init__(self, batch_size, vocab_size, seq_length):
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.seq_length = seq_length
|
||||||
|
self.step = 0
|
||||||
|
|
||||||
|
def generate(self):
|
||||||
|
tokens = torch.randint(low=0, high=self.vocab_size, size=(
|
||||||
|
self.batch_size,
|
||||||
|
self.seq_length,
|
||||||
|
))
|
||||||
|
types = torch.randint(low=0, high=3, size=(
|
||||||
|
self.batch_size,
|
||||||
|
self.seq_length,
|
||||||
|
))
|
||||||
|
sentence_order = torch.randint(low=0, high=2, size=(self.batch_size,))
|
||||||
|
loss_mask = torch.randint(low=0, high=2, size=(
|
||||||
|
self.batch_size,
|
||||||
|
self.seq_length,
|
||||||
|
))
|
||||||
|
lm_labels = torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.seq_length))
|
||||||
|
padding_mask = torch.randint(low=0, high=2, size=(self.batch_size, self.seq_length))
|
||||||
|
return dict(text=tokens,
|
||||||
|
types=types,
|
||||||
|
is_random=sentence_order,
|
||||||
|
loss_mask=loss_mask,
|
||||||
|
labels=lm_labels,
|
||||||
|
padding_mask=padding_mask)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
return self.generate()
|
Loading…
Reference in New Issue
Block a user