mirror of
				https://github.com/hpcaitech/ColossalAI.git
				synced 2025-10-30 21:39:05 +00:00 
			
		
		
		
	* [misc] update pre-commit
	* [misc] run pre-commit
	* [misc] remove useless configuration files
	* [misc] ignore cuda for clang-format
		
			
				
	
	
		
			48 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			48 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import argparse

import torch
import torchvision
import torchvision.transforms as transforms

# Evaluate a ResNet-18 checkpoint stored as ``<checkpoint>/model_<epoch>.pth``
# on the CIFAR-10 test split and print its top-1 accuracy (in percent).

# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--epoch", type=int, default=80, help="evaluate the checkpoint saved at this epoch")
parser.add_argument("-c", "--checkpoint", type=str, default="./checkpoint", help="checkpoint directory")
args = parser.parse_args()

# Use the GPU when one is available; otherwise fall back to CPU so the script
# still runs on machines without CUDA (behavior on GPU machines is unchanged).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ==============================
# Prepare Test Dataset
# ==============================
# CIFAR-10 test split. download is not requested here, so the data must
# already be present under ./data/ (e.g. fetched by the training script).
test_dataset = torchvision.datasets.CIFAR10(root="./data/", train=False, transform=transforms.ToTensor())

# Data loader; sample order does not affect accuracy, so no shuffling.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)

# ==============================
# Load Model
# ==============================
model = torchvision.models.resnet18(num_classes=10).to(device)
# map_location="cpu" makes loading independent of the device the checkpoint was
# saved from (e.g. a different GPU index during distributed training);
# load_state_dict then copies the weights onto the model's own device.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
state_dict = torch.load(f"{args.checkpoint}/model_{args.epoch}.pth", map_location="cpu")
model.load_state_dict(state_dict)

# ==============================
# Run Evaluation
# ==============================
# eval() switches layers such as BatchNorm to inference behavior.
model.eval()

with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Predicted class for each sample = index of its largest logit.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print("Accuracy of the model on the test images: {} %".format(100 * correct / total))