mirror of
				https://github.com/hpcaitech/ColossalAI.git
				synced 2025-10-30 13:31:12 +00:00 
			
		
		
		
	[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
This commit is contained in:
		| @@ -1,8 +1,7 @@ | ||||
| import torch | ||||
|  | ||||
|  | ||||
| class DummyDataloader(): | ||||
|  | ||||
| class DummyDataloader: | ||||
|     def __init__(self, batch_size, vocab_size, seq_length): | ||||
|         self.batch_size = batch_size | ||||
|         self.vocab_size = vocab_size | ||||
| @@ -10,30 +9,44 @@ class DummyDataloader(): | ||||
|         self.step = 0 | ||||
|  | ||||
|     def generate(self): | ||||
|         tokens = torch.randint(low=0, high=self.vocab_size, size=( | ||||
|             self.batch_size, | ||||
|             self.seq_length, | ||||
|         )) | ||||
|         types = torch.randint(low=0, high=3, size=( | ||||
|             self.batch_size, | ||||
|             self.seq_length, | ||||
|         )) | ||||
|         tokens = torch.randint( | ||||
|             low=0, | ||||
|             high=self.vocab_size, | ||||
|             size=( | ||||
|                 self.batch_size, | ||||
|                 self.seq_length, | ||||
|             ), | ||||
|         ) | ||||
|         types = torch.randint( | ||||
|             low=0, | ||||
|             high=3, | ||||
|             size=( | ||||
|                 self.batch_size, | ||||
|                 self.seq_length, | ||||
|             ), | ||||
|         ) | ||||
|         sentence_order = torch.randint(low=0, high=2, size=(self.batch_size,)) | ||||
|         loss_mask = torch.randint(low=0, high=2, size=( | ||||
|             self.batch_size, | ||||
|             self.seq_length, | ||||
|         )) | ||||
|         loss_mask = torch.randint( | ||||
|             low=0, | ||||
|             high=2, | ||||
|             size=( | ||||
|                 self.batch_size, | ||||
|                 self.seq_length, | ||||
|             ), | ||||
|         ) | ||||
|         lm_labels = torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.seq_length)) | ||||
|         padding_mask = torch.randint(low=0, high=2, size=(self.batch_size, self.seq_length)) | ||||
|         return dict(text=tokens, | ||||
|                     types=types, | ||||
|                     is_random=sentence_order, | ||||
|                     loss_mask=loss_mask, | ||||
|                     labels=lm_labels, | ||||
|                     padding_mask=padding_mask) | ||||
|         return dict( | ||||
|             text=tokens, | ||||
|             types=types, | ||||
|             is_random=sentence_order, | ||||
|             loss_mask=loss_mask, | ||||
|             labels=lm_labels, | ||||
|             padding_mask=padding_mask, | ||||
|         ) | ||||
|  | ||||
|     def __iter__(self): | ||||
|         return self | ||||
|  | ||||
|     def __next__(self): | ||||
|         return self.generate() | ||||
|         return self.generate() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user