mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-10-28 20:30:42 +00:00
add example of self-supervised SimCLR training - V2 (#50)
* add example of self-supervised SimCLR training * simclr v2, replace nvidia dali dataloader * updated * sync to latest code writing style * sync to latest code writing style and modify README * detail README & standardize dataset path
This commit is contained in:
36
examples/simclr_cifar10_data_parallel/models/simclr.py
Normal file
36
examples/simclr_cifar10_data_parallel/models/simclr.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .Backbone import backbone
|
||||
|
||||
class projection_MLP(nn.Module):
|
||||
def __init__(self, in_dim, out_dim=256):
|
||||
super().__init__()
|
||||
hidden_dim = in_dim
|
||||
self.layer1 = nn.Sequential(
|
||||
nn.Linear(in_dim, hidden_dim),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.layer2 = nn.Linear(hidden_dim, out_dim)
|
||||
def forward(self, x):
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
return x
|
||||
|
||||
class SimCLR(nn.Module):
|
||||
|
||||
def __init__(self, model='resnet18', **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.backbone = backbone(model, **kwargs)
|
||||
self.projector = projection_MLP(self.backbone.output_dim)
|
||||
self.encoder = nn.Sequential(
|
||||
self.backbone,
|
||||
self.projector
|
||||
)
|
||||
|
||||
def forward(self, x1, x2):
|
||||
|
||||
z1 = self.encoder(x1)
|
||||
z2 = self.encoder(x2)
|
||||
return z1, z2
|
||||
Reference in New Issue
Block a user