mirror of
				https://github.com/hpcaitech/ColossalAI.git
				synced 2025-10-25 10:06:27 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			73 lines
		
	
	
		
			2.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			73 lines
		
	
	
		
			2.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/env python
 | |
| # -*- encoding: utf-8 -*-
 | |
| 
 | |
| from functools import partial
 | |
| 
 | |
| import colossalai
 | |
| import pytest
 | |
| import torch
 | |
| import torch.multiprocessing as mp
 | |
| from colossalai.testing import parameterize, rerun_on_exception
 | |
| from colossalai.utils import free_port
 | |
| from colossalai.zero.init_ctx import ZeroInitContext
 | |
| from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 | |
| from colossalai.zero.sharded_model import ShardedModelV2
 | |
| from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 | |
| from colossalai.zero.sharded_model.utils import col_model_deepcopy
 | |
| from tests.components_to_test.registry import non_distributed_component_funcs
 | |
| from torch.nn.parallel import DistributedDataParallel as DDP
 | |
| 
 | |
| from common import CONFIG, check_grads_padding, run_fwd_bwd
 | |
| 
 | |
| 
 | |
@parameterize("enable_autocast", [True])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_model_test(enable_autocast, shard_strategy_class):
    """Run a sharded model and a DDP fp16 reference side by side and compare grads.

    For each registered test model:
      * a sharded copy is built inside ``ZeroInitContext`` (``shard_param=True``)
        and wrapped in ``ShardedModelV2`` with ``use_memory_tracer=True``;
      * a plain copy is built, cast with ``.half()``, passed to
        ``col_model_deepcopy(sharded, reference)`` (presumably copying the
        sharded weights into it), moved to CUDA and wrapped in ``DDP``;
      * both run ``run_fwd_bwd`` on the first six dataloader batches, and after
        each batch ``check_grads_padding(..., loose=True)`` compares them.

    Args:
        enable_autocast: forwarded to ``run_fwd_bwd`` for both models.
        shard_strategy_class: strategy type; ONE instance is created here and
            shared by every model built in this call.
    """
    strategy = shard_strategy_class()
    for name in ('repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module'):
        build_components = non_distributed_component_funcs.get_callable(name)
        model_builder, train_dataloader, _, _, criterion = build_components()

        # Sharded model under test; rm_torch_payload_on_the_fly is disabled here.
        with ZeroInitContext(target_device=torch.cuda.current_device(),
                             shard_strategy=strategy,
                             shard_param=True,
                             rm_torch_payload_on_the_fly=False):
            sharded = model_builder(checkpoint=True)
        sharded = ShardedModelV2(sharded, strategy, use_memory_tracer=True)

        # Reference model: populated from the sharded one while still on CPU,
        # then moved to the GPU and wrapped in DDP.
        reference = model_builder(checkpoint=True).half()
        col_model_deepcopy(sharded, reference)
        reference = DDP(reference.cuda())

        for step, (data, label) in enumerate(train_dataloader):
            # Process steps 0..5 (six batches), then stop.
            if step > 5:
                break

            data = cast_tensor_to_fp16(data).cuda()
            label = label.cuda()
            # Reference first, then the sharded model -- same order as before.
            for candidate in (reference, sharded):
                run_fwd_bwd(candidate, data, label, criterion, enable_autocast)

            check_grads_padding(reference, sharded, loose=True)
 | |
| 
 | |
| 
 | |
def run_dist(rank, world_size, port):
    """Per-process entry point handed to ``mp.spawn``.

    Launches ColossalAI on ``localhost:port`` with the NCCL backend using the
    shared ``CONFIG``, then runs the parameterized model comparison.

    Args:
        rank: process rank (supplied positionally by ``mp.spawn``).
        world_size: total number of spawned processes.
        port: TCP port used for the rendezvous on localhost.
    """
    colossalai.launch(
        config=CONFIG,
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    run_model_test()
 | |
| 
 | |
| 
 | |
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_shard_model_v2(world_size):
    """Spawn ``world_size`` processes, each executing ``run_dist``.

    ``free_port()`` is evaluated on every call, so a rerun triggered by
    ``rerun_on_exception`` (configured for "Address already in use" failures
    in spawned processes) picks a new port.
    """
    worker = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(worker, nprocs=world_size)
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Direct invocation (outside pytest): run the two-process configuration.
    test_shard_model_v2(world_size=2)
 |