mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 22:10:37 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,8 +1,7 @@
|
||||
import torch
|
||||
|
||||
|
||||
class DummyDataloader():
|
||||
|
||||
class DummyDataloader:
|
||||
def __init__(self, batch_size, vocab_size, seq_length):
|
||||
self.batch_size = batch_size
|
||||
self.vocab_size = vocab_size
|
||||
@@ -10,30 +9,44 @@ class DummyDataloader():
|
||||
self.step = 0
|
||||
|
||||
def generate(self):
|
||||
tokens = torch.randint(low=0, high=self.vocab_size, size=(
|
||||
self.batch_size,
|
||||
self.seq_length,
|
||||
))
|
||||
types = torch.randint(low=0, high=3, size=(
|
||||
self.batch_size,
|
||||
self.seq_length,
|
||||
))
|
||||
tokens = torch.randint(
|
||||
low=0,
|
||||
high=self.vocab_size,
|
||||
size=(
|
||||
self.batch_size,
|
||||
self.seq_length,
|
||||
),
|
||||
)
|
||||
types = torch.randint(
|
||||
low=0,
|
||||
high=3,
|
||||
size=(
|
||||
self.batch_size,
|
||||
self.seq_length,
|
||||
),
|
||||
)
|
||||
sentence_order = torch.randint(low=0, high=2, size=(self.batch_size,))
|
||||
loss_mask = torch.randint(low=0, high=2, size=(
|
||||
self.batch_size,
|
||||
self.seq_length,
|
||||
))
|
||||
loss_mask = torch.randint(
|
||||
low=0,
|
||||
high=2,
|
||||
size=(
|
||||
self.batch_size,
|
||||
self.seq_length,
|
||||
),
|
||||
)
|
||||
lm_labels = torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.seq_length))
|
||||
padding_mask = torch.randint(low=0, high=2, size=(self.batch_size, self.seq_length))
|
||||
return dict(text=tokens,
|
||||
types=types,
|
||||
is_random=sentence_order,
|
||||
loss_mask=loss_mask,
|
||||
labels=lm_labels,
|
||||
padding_mask=padding_mask)
|
||||
return dict(
|
||||
text=tokens,
|
||||
types=types,
|
||||
is_random=sentence_order,
|
||||
loss_mask=loss_mask,
|
||||
labels=lm_labels,
|
||||
padding_mask=padding_mask,
|
||||
)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
return self.generate()
|
||||
return self.generate()
|
||||
|
Reference in New Issue
Block a user