mirror of
				https://github.com/hpcaitech/ColossalAI.git
				synced 2025-10-30 21:39:05 +00:00 
			
		
		
		
	* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
		
			
				
	
	
		
			53 lines
		
	
	
		
			1.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			53 lines
		
	
	
		
			1.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import torch
 | |
| 
 | |
| 
 | |
class DummyDataloader:
    """Endless iterator that fabricates batches of random integer tensors.

    Every batch is a dict of tensors produced by ``torch.randint``; no real
    data is read. Iteration never raises ``StopIteration``.

    Args:
        batch_size: Size of the first dimension of every generated tensor.
        vocab_size: Exclusive upper bound for the ``text`` and ``labels`` values.
        seq_length: Size of the second dimension of the per-token tensors.
    """

    def __init__(self, batch_size, vocab_size, seq_length):
        self.batch_size, self.vocab_size, self.seq_length = batch_size, vocab_size, seq_length
        # Starts at zero; nothing in this class advances it.
        self.step = 0

    def generate(self):
        """Draw one random batch.

        Returns:
            dict: Maps ``text``, ``types``, ``is_random``, ``loss_mask``,
            ``labels`` and ``padding_mask`` to integer tensors. ``is_random``
            has shape ``(batch_size,)``; every other entry has shape
            ``(batch_size, seq_length)``. Values lie in ``[0, high)`` where
            ``high`` is ``vocab_size`` for ``text``/``labels``, 3 for
            ``types`` and 2 for the remaining entries.
        """
        per_token = (self.batch_size, self.seq_length)
        per_sample = (self.batch_size,)

        # (key, exclusive upper bound, shape). The order is significant: the
        # tensors are drawn in this sequence, so it fixes both the dict key
        # order and how the RNG stream is consumed.
        specs = (
            ("text", self.vocab_size, per_token),
            ("types", 3, per_token),
            ("is_random", 2, per_sample),
            ("loss_mask", 2, per_token),
            ("labels", self.vocab_size, per_token),
            ("padding_mask", 2, per_token),
        )
        return {key: torch.randint(low=0, high=upper, size=shape) for key, upper, shape in specs}

    def __iter__(self):
        """Return the loader itself; it is its own iterator."""
        return self

    def __next__(self):
        """Produce a freshly generated batch on every call."""
        return self.generate()
 |