add example of self-supervised SimCLR training - V2 (#50)

* add example of self-supervised SimCLR training

* simclr v2, replace nvidia dali dataloader

* updated

* sync to latest code writing style

* sync to latest code writing style and modify README

* detail README & standardize dataset path
This commit is contained in:
Xin Zhang
2021-12-21 08:07:18 +08:00
committed by GitHub
parent 8f02a88db2
commit 648f806315
19 changed files with 688 additions and 0 deletions

View File

@@ -0,0 +1,72 @@
import torch
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from models.simclr import SimCLR
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torchvision import transforms
log_name = 'cifar-simclr'
epoch = 800
fea_flag = True
tsne_flag = True
plot_flag = True
if fea_flag:
path = f'ckpt/{log_name}/epoch{epoch}-tp0-pp0.pt'
net = SimCLR('resnet18').cuda()
print(net.load_state_dict(torch.load(path)['model']))
transform_eval = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
])
train_dataset = CIFAR10(root='./dataset', train=True, transform=transform_eval)
train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=False, num_workers=4)
test_dataset = CIFAR10(root='./dataset', train=False, transform=transform_eval)
test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=4)
def feature_extractor(model, loader):
model.eval()
all_fea = []
all_targets = []
for img, target in loader:
img = img.cuda()
fea = model.backbone(img)
all_fea.append(fea.detach().cpu())
all_targets.append(target)
all_fea = torch.cat(all_fea)
all_targets = torch.cat(all_targets)
return all_fea.numpy(), all_targets.numpy()
if tsne_flag:
train_fea, train_targets = feature_extractor(net, train_dataloader)
train_embedded = TSNE(n_components=2).fit_transform(train_fea)
test_fea, test_targets = feature_extractor(net, test_dataloader)
test_embedded = TSNE(n_components=2).fit_transform(test_fea)
np.savez('results/embedding.npz', train_embedded=train_embedded, train_targets=train_targets, test_embedded=test_embedded, test_targets=test_targets)
if plot_flag:
npz = np.load('embedding.npz')
train_embedded = npz['train_embedded']
train_targets = npz['train_targets']
test_embedded = npz['test_embedded']
test_targets = npz['test_targets']
plt.figure(figsize=(16,16))
for i in range(len(np.unique(train_targets))):
plt.scatter(train_embedded[train_targets==i,0], train_embedded[train_targets==i,1], label=i)
plt.title('train')
plt.legend()
plt.savefig('results/train_tsne.png')
plt.figure(figsize=(16,16))
for i in range(len(np.unique(test_targets))):
plt.scatter(test_embedded[test_targets==i,0], test_embedded[test_targets==i,1], label=i)
plt.title('test')
plt.legend()
plt.savefig('results/test_tsne.png')