mirror of
				https://github.com/hpcaitech/ColossalAI.git
				synced 2025-10-25 10:06:27 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			114 lines
		
	
	
		
			3.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			114 lines
		
	
	
		
			3.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/env python
 | |
| # -*- encoding: utf-8 -*-
 | |
| 
 | |
| import copy
 | |
| 
 | |
| import colossalai
 | |
| from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 | |
| import pytest
 | |
| import torch
 | |
| import torch.distributed as dist
 | |
| import torch.multiprocessing as mp
 | |
| import torch.nn as nn
 | |
| from colossalai.logging import disable_existing_loggers
 | |
| from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port
 | |
| from torch.nn.parallel import DistributedDataParallel as DDP
 | |
| from torch.nn.utils import clip_grad_norm_
 | |
| from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
 | |
| from functools import partial
 | |
| from colossalai.testing import parameterize, rerun_if_address_is_in_use
 | |
| 
 | |
| 
 | |
def checkpoint_wrapper(module, enable=True):
    """Optionally route ``module.forward`` through ``checkpoint``.

    When ``enable`` is truthy, the instance attribute ``module.forward`` is
    replaced by ``partial(checkpoint, <original forward>, False)``, so each call
    becomes ``checkpoint(original_forward, False, *args)``. The ``False`` is the
    second positional argument of ``colossalai.utils.checkpoint`` (presumably the
    activation-offload flag -- TODO confirm).

    Returns the same module object (mutated in place when enabled).
    """
    if not enable:
        return module
    original_forward = module.forward
    module.forward = partial(checkpoint, original_forward, False)
    return module
 | |
| 
 | |
| 
 | |
class Net(nn.Module):
    """Tiny 5 -> 1 network used as the model under test.

    ``fc1`` and ``fc2`` are each applied twice in ``forward`` (shared weights),
    so the call chain is fc1 -> fc2 -> fc1 -> fc2 -> fc3.

    Args:
        checkpoint: if truthy, ``fc1`` is wrapped with ``checkpoint_wrapper`` so
            its forward goes through ColossalAI's ``checkpoint``. Inside
            ``__init__`` this argument shadows the module-level ``checkpoint``.
    """

    def __init__(self, checkpoint=False) -> None:
        super().__init__()
        # Registration order (fc1, fc2, fc3) fixes the order of parameters(),
        # which matters when two copies of this model are zipped together.
        self.fc1 = nn.Linear(5, 5)
        self.fc2 = nn.Linear(5, 5)
        self.fc3 = nn.Linear(5, 1)
        if checkpoint:
            self.fc1 = checkpoint_wrapper(self.fc1)
        # Plain list, not nn.ModuleList: the submodules are already registered
        # through the attributes above; this only records the call sequence.
        first, second, head = self.fc1, self.fc2, self.fc3
        self.layers = [first, second, first, second, head]

    def forward(self, x):
        out = x
        for step in self.layers:
            out = step(out)
        return out
 | |
| 
 | |
| 
 | |
def run_step(model, optimizer, x, enable_autocast=False, norm_type=2.0):
    """Perform one training iteration on ``model``.

    Steps: switch to train mode, zero grads, forward under
    ``torch.cuda.amp.autocast(enabled=enable_autocast)``, use the sum of the
    output as the loss, backpropagate, clip gradients via ``clip_grad`` with
    ``norm_type``, then take an optimizer step.
    """
    model.train()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        loss = model(x).sum()
    # Cast to fp32 outside the autocast region before starting backward.
    loss.float().backward()
    clip_grad(model, norm_type)
    optimizer.step()
 | |
| 
 | |
| 
 | |
def clip_grad(model, norm_type):
    """Clip the gradients of ``model`` to a total norm of at most 1.0.

    DDP-wrapped models use PyTorch's ``clip_grad_norm_``; every other model
    (presumably the ZeRO-sharded one -- TODO confirm) uses ColossalAI's
    ``clip_grad_norm_fp32``. ``norm_type`` is forwarded unchanged.
    """
    clip_fn = clip_grad_norm_ if isinstance(model, DDP) else clip_grad_norm_fp32
    clip_fn(model.parameters(), max_norm=1.0, norm_type=norm_type)
 | |
| 
 | |
| 
 | |
def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
    """Return whether two tensors are element-wise close (``torch.allclose``).

    With ``loose`` truthy, both ``atol`` and ``rtol`` are widened to 1e-3;
    otherwise ``torch.allclose``'s default tolerances apply.
    """
    tolerances = dict(atol=1e-3, rtol=1e-3) if loose else {}
    return torch.allclose(tensor_a, tensor_b, **tolerances)
 | |
| 
 | |
| 
 | |
def check_grads(model, zero_model, loose=False):
    """Assert this rank's gradient shard of ``zero_model`` matches ``model``.

    Parameters are paired in ``parameters()`` order. For each pair, the full
    reference gradient is flattened and split with ``torch.chunk`` into one
    piece per rank, and this rank's piece is compared with the gradient of the
    matching ``zero_model`` parameter after its trailing ``zero_shard_padding``
    elements are dropped (presumably that gradient is this rank's flat shard --
    TODO confirm). ``torch.chunk`` can return fewer pieces than requested for
    small tensors; ranks without a piece skip that parameter.

    Args:
        model: reference (unsharded) model.
        zero_model: sharded model whose parameters carry ``zero_shard_padding``.
        loose: forwarded to ``allclose`` to use the wider 1e-3 tolerances.

    Raises:
        AssertionError: on a dtype or value mismatch.
    """
    rank = dist.get_rank()
    # Split by the actual number of ranks rather than a hard-coded 4, so the
    # comparison stays valid if the test is launched with another world size.
    # Assumes sharding spans the default process group -- TODO confirm.
    world_size = dist.get_world_size()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        # Copied before the skip check so a missing zero grad fails on every rank.
        zero_grad = zero_p.grad.clone().to(p.device)
        chunks = torch.flatten(p.grad).chunk(world_size)
        if rank >= len(chunks):
            continue
        grad = chunks[rank]
        padding = zero_p.zero_shard_padding
        if padding > 0:
            zero_grad = zero_grad[:-padding]
        assert grad.dtype == zero_grad.dtype
        assert allclose(grad, zero_grad, loose=loose)
 | |
| 
 | |
| 
 | |
def check_params(model, zero_model, loose=False):
    """Assert this rank's parameter shard of ``zero_model`` matches ``model``.

    Parameters are paired in ``parameters()`` order. For each pair, the full
    reference parameter is flattened and split with ``torch.chunk`` into one
    piece per rank, and this rank's piece is compared with the matching
    ``zero_model`` parameter after its trailing ``zero_shard_padding`` elements
    are dropped (presumably that tensor is this rank's flat shard -- TODO
    confirm). Ranks that receive no piece from ``torch.chunk`` skip the pair.

    Args:
        model: reference (unsharded) model.
        zero_model: sharded model whose parameters carry ``zero_shard_padding``.
        loose: forwarded to ``allclose`` to use the wider 1e-3 tolerances.

    Raises:
        AssertionError: on a dtype or value mismatch.
    """
    rank = dist.get_rank()
    # Split by the actual number of ranks rather than a hard-coded 4.
    # Assumes sharding spans the default process group -- TODO confirm.
    world_size = dist.get_world_size()
    for ref_param, zero_param in zip(model.parameters(), zero_model.parameters()):
        padding = zero_param.zero_shard_padding
        zero_shard = zero_param.clone().to(ref_param.device)
        chunks = torch.flatten(ref_param).chunk(world_size)
        if rank >= len(chunks):
            continue
        ref_shard = chunks[rank]
        if padding > 0:
            zero_shard = zero_shard[:-padding]
        assert ref_shard.dtype == zero_shard.dtype
        assert allclose(ref_shard, zero_shard, loose=loose)
 | |
| 
 | |
| 
 | |
def run_dist(rank, world_size, port):
    """Per-process entry point spawned by ``test_zero_clip_grad``.

    Disables pre-existing loggers, then calls ``colossalai.launch`` with an
    empty config, the NCCL backend and rendezvous at ``localhost:port``.

    NOTE(review): as written this only initializes the distributed
    environment; it calls none of ``Net``, ``run_step``, ``check_grads`` or
    ``check_params``, so the test exercises no gradient clipping.
    TODO: add the ZeRO-vs-DDP comparison this module's helpers are built for.
    """
    disable_existing_loggers()
    launch_kwargs = dict(rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    colossalai.launch(config={}, **launch_kwargs)
 | |
| 
 | |
| 
 | |
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_clip_grad():
    """Spawn 4 processes, each running ``run_dist`` against a fresh free port.

    The workers launch with the NCCL backend, so this presumably needs 4 CUDA
    devices -- TODO confirm CI requirements.
    """
    nprocs = 4
    mp.spawn(partial(run_dist, world_size=nprocs, port=free_port()), nprocs=nprocs)
 | |
| 
 | |
| 
 | |
if __name__ == '__main__':
    # Allow running this distributed test directly, without pytest.
    test_zero_clip_grad()
 |