Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-10-30 21:39:05 +00:00)
	[tutorial] added missing dummy dataloader (#1944)
examples/tutorial/.gitignore (vendored) | 2 +-
@@ -1 +1 @@
-data/
+./data/

examples/tutorial/sequence_parallel/data/dummy_dataloader.py (new file) | 39 ++++++++++
@@ -0,0 +1,39 @@
+import torch
+
+
+class DummyDataloader():
+
+    def __init__(self, batch_size, vocab_size, seq_length):
+        self.batch_size = batch_size
+        self.vocab_size = vocab_size
+        self.seq_length = seq_length
+        self.step = 0
+
+    def generate(self):
+        tokens = torch.randint(low=0, high=self.vocab_size, size=(
+            self.batch_size,
+            self.seq_length,
+        ))
+        types = torch.randint(low=0, high=3, size=(
+            self.batch_size,
+            self.seq_length,
+        ))
+        sentence_order = torch.randint(low=0, high=2, size=(self.batch_size,))
+        loss_mask = torch.randint(low=0, high=2, size=(
+            self.batch_size,
+            self.seq_length,
+        ))
+        lm_labels = torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.seq_length))
+        padding_mask = torch.randint(low=0, high=2, size=(self.batch_size, self.seq_length))
+        return dict(text=tokens,
+                    types=types,
+                    is_random=sentence_order,
+                    loss_mask=loss_mask,
+                    labels=lm_labels,
+                    padding_mask=padding_mask)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        return self.generate()
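For context, a minimal sketch of how this dummy loader might be driven in a training loop. The batch keys (text, types, is_random, loss_mask, labels, padding_mask) come from generate() in the diff above; the import path, sizes, and step count are illustrative assumptions, not part of the commit. Note that the loader is infinite (__next__ never raises StopIteration), so the consumer must bound the iteration itself.

# Illustrative usage sketch (assumption, not part of this commit).
# Assumes dummy_dataloader.py is importable from the tutorial's data package;
# adjust the import path to your working directory.
from data.dummy_dataloader import DummyDataloader

loader = DummyDataloader(batch_size=8, vocab_size=30522, seq_length=512)

# The loader never raises StopIteration, so bound the loop explicitly.
for step, batch in zip(range(10), loader):
    tokens = batch['text']      # (batch_size, seq_length) random token ids
    labels = batch['labels']    # same shape, random LM targets
    assert tokens.shape == (8, 512)
    # A real training step (forward/backward on a BERT-style model) would go here.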