mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
@@ -9,15 +8,15 @@ import torchvision.transforms as transforms
|
||||
# Parse Arguments
|
||||
# ==============================
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-e', '--epoch', type=int, default=80, help="resume from the epoch's checkpoint")
|
||||
parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory")
|
||||
parser.add_argument("-e", "--epoch", type=int, default=80, help="resume from the epoch's checkpoint")
|
||||
parser.add_argument("-c", "--checkpoint", type=str, default="./checkpoint", help="checkpoint directory")
|
||||
args = parser.parse_args()
|
||||
|
||||
# ==============================
|
||||
# Prepare Test Dataset
|
||||
# ==============================
|
||||
# CIFAR-10 dataset
|
||||
test_dataset = torchvision.datasets.CIFAR10(root='./data/', train=False, transform=transforms.ToTensor())
|
||||
test_dataset = torchvision.datasets.CIFAR10(root="./data/", train=False, transform=transforms.ToTensor())
|
||||
|
||||
# Data loader
|
||||
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)
|
||||
@@ -26,7 +25,7 @@ test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128,
|
||||
# Load Model
|
||||
# ==============================
|
||||
model = torchvision.models.resnet18(num_classes=10).cuda()
|
||||
state_dict = torch.load(f'{args.checkpoint}/model_{args.epoch}.pth')
|
||||
state_dict = torch.load(f"{args.checkpoint}/model_{args.epoch}.pth")
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
# ==============================
|
||||
@@ -45,4 +44,4 @@ with torch.no_grad():
|
||||
total += labels.size(0)
|
||||
correct += (predicted == labels).sum().item()
|
||||
|
||||
print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))
|
||||
print("Accuracy of the model on the test images: {} %".format(100 * correct / total))
|
||||
|
Reference in New Issue
Block a user