mirror of https://github.com/hpcaitech/ColossalAI.git
import torch

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn import TransformerSelfAttentionRing
from colossalai.utils import get_current_device


def check_selfattention():
    # size of the sequence-parallel group; the full sequence is split across these ranks
    WORLD_SIZE = gpc.get_world_size(ParallelMode.SEQUENCE)
    SUB_SEQ_LENGTH = 8
    BATCH = 4
    HIDDEN_SIZE = 16

    # construct the ring self-attention layer under test and move it to the current device
    layer = TransformerSelfAttentionRing(16, 8, 8, 0.1)
    layer = layer.to(get_current_device())

    # local input chunk held by this rank: (sub_seq_len, batch, hidden_size)
    hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_current_device())
    # binary attention mask spanning the full sequence across all ranks
    attention_mask = torch.randint(low=0, high=2,
                                   size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to(get_current_device())
    out = layer(hidden_states, attention_mask)
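
The check above assumes an already-initialized sequence-parallel group, since gpc.get_world_size(ParallelMode.SEQUENCE) only works after the global context has been set up. Below is a minimal, hypothetical launcher sketch (not part of the original file) showing one way the check could be driven in a multi-process run; the CONFIG layout, world size, and port are assumptions and may differ between ColossalAI versions.

import torch.multiprocessing as mp

import colossalai

# assumed config layout for enabling sequence parallelism; keys may differ by version
CONFIG = dict(parallel=dict(tensor=dict(size=2, mode='sequence')))


def run_check(rank, world_size, port):
    # initialize the global context for this rank on a single node
    colossalai.launch(config=CONFIG,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    check_selfattention()


if __name__ == '__main__':
    world_size = 2
    mp.spawn(run_check, args=(world_size, 29500), nprocs=world_size)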