Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-10-22 15:26:57 +00:00.
			
		
		
		
	
		
			
				
	
	
		
			34 lines
		
	
	
		
			1.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			34 lines
		
	
	
		
			1.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ProcessGroup
 | |
| import torch
 | |
| import pytest
 | |
| from common_utils import tensor_equal
 | |
| import colossalai
 | |
| from colossalai.utils import free_port
 | |
| 
 | |
| 
 | |
@pytest.mark.skip
def test_multiinheritance():
    """Check that ``ColoParameter`` acts as both a ``ColoTensor`` and a ``torch.nn.Parameter``.

    Launches a single-process ColossalAI context (rank 0, world size 1, NCCL
    backend), constructs a ``ColoParameter`` with ``None`` as its first
    argument and ``requires_grad=True``, then verifies that:

    * ``dist_spec.placement.value`` equals ``'r'``;
    * the object is an instance of both ``ColoTensor`` and ``torch.nn.Parameter``;
    * ``copy.deepcopy`` returns a ``ColoParameter`` with equal ``.data`` and the
      same ``requires_grad`` flag;
    * ``str()`` of the object mentions ``ColoParameter``;
    * ``torch.clone`` on it returns a ``ColoTensor``.
    """
    import copy

    colossalai.launch(
        config={},
        rank=0,
        world_size=1,
        host='localhost',
        port=free_port(),
        backend='nccl',
    )

    param = ColoParameter(None, requires_grad=True)
    assert param.dist_spec.placement.value == 'r'
    for parent_cls in (ColoTensor, torch.nn.Parameter):
        assert isinstance(param, parent_cls)

    # __deepcopy__ overload: the copy must still be a ColoParameter with matching state.
    duplicate = copy.deepcopy(param)
    assert isinstance(duplicate, ColoParameter)
    assert tensor_equal(param.data, duplicate.data)
    assert param.requires_grad == duplicate.requires_grad

    # __repr__ overload: the printed form should name the subclass.
    text = str(param)
    assert 'ColoParameter' in text

    # __torch_function__: a torch op applied to the parameter should yield a ColoTensor.
    cloned = torch.clone(param)
    assert isinstance(cloned, ColoTensor)
| 
 | |
| 
 | |
if __name__ == '__main__':
    # The @pytest.mark.skip marker only affects pytest collection; running this
    # file directly calls the test function unconditionally.
    test_multiinheritance()