83 lines
2.2 KiB
Python
83 lines
2.2 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torchvision import models
|
|
|
|
def get_model(config, device='cuda'):
    """Build a classification network whose output layer matches the dataset.

    Args:
        config: Configuration object. ``config.dataset`` picks the number of
            output classes (2 for ``'gender_dataset'``, 307 for anything
            else) and ``config.classifier.model`` picks the backbone: one of
            ``'resnet18'``, ``'resnet101'``, ``'mnasnet'`` or
            ``'densenet121'``.
        device: Device the model is moved to before returning. Defaults to
            ``'cuda'``, matching the previously hard-coded behaviour.

    Returns:
        The torchvision model with its final linear layer replaced by an
        ``nn.Linear`` producing ``num_class`` outputs, placed on ``device``.

    Raises:
        ValueError: If ``config.classifier.model`` is not one of the
            supported backbones (previously this surfaced as an
            ``UnboundLocalError``).
    """
    num_class = 2 if config.dataset == 'gender_dataset' else 307
    arch = config.classifier.model

    # NOTE(review): ``pretrained=`` is deprecated in newer torchvision
    # releases in favour of ``weights=``; left unchanged until the project's
    # torchvision version is confirmed.
    if arch == 'resnet18':
        # Randomly initialised ResNet-18; swap its ``fc`` head.
        model = models.resnet18(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_class)
    elif arch == 'resnet101':
        # Randomly initialised ResNet-101; swap its ``fc`` head.
        model = models.resnet101(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_class)
    elif arch == 'mnasnet':
        # Pretrained MNASNet 1.0. Its ``classifier`` is indexable and the
        # final Linear layer lives at index 1, so only that layer is swapped.
        model = models.mnasnet1_0(pretrained=True)
        model.classifier[1] = nn.Linear(
            model.classifier[1].in_features, num_class
        )
    elif arch == 'densenet121':
        # Pretrained DenseNet-121. DenseNet's head is ``classifier`` (it has
        # no ``fc``); assigning ``model.fc`` would add an unused layer and
        # leave the original head in place.
        model = models.densenet121(pretrained=True)
        model.classifier = nn.Linear(model.classifier.in_features, num_class)
    else:
        raise ValueError(f"Unsupported classifier model: {arch!r}")

    return model.to(device)
|
|
|
|
def unnormalize(image, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Undo per-channel normalization and clamp the result to [0, 1].

    Computes ``image * std + mean`` channel-wise (the inverse of
    ``normalize`` with the same statistics) on a detached CPU copy.

    Args:
        image: Tensor shaped (C, H, W) or (N, C, H, W) with C matching the
            length of ``mean``/``std``. It is never modified.
        mean: Per-channel means that were subtracted. Defaults to 0.5 each.
        std: Per-channel standard deviations that were divided by. Defaults
            to 0.5 each.

    Returns:
        A new detached CPU tensor with the same shape and dtype as ``image``,
        clamped to [0, 1].
    """
    # Shape (C, 1, 1) broadcasts against both a single image and a batch.
    mean_t = torch.tensor(mean, dtype=torch.float32).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=torch.float32).view(-1, 1, 1)

    # For a CPU input, detach().cpu() shares storage with the caller's
    # tensor, so clone before the in-place ops to avoid mutating it.
    out = image.detach().cpu().clone()
    out.mul_(std_t).add_(mean_t).clamp_(0, 1)
    return out
|
|
|
|
def normalize(image, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Normalize an image tensor channel-wise as ``(image - mean) / std``.

    Args:
        image: Tensor shaped (C, H, W) or (N, C, H, W) with C matching the
            length of ``mean``/``std``. It is never modified; gradients flow
            back to it through the returned tensor.
        mean: Per-channel means to subtract. Defaults to 0.5 each.
        std: Per-channel standard deviations to divide by. Defaults to 0.5
            each.

    Returns:
        A new tensor on the same device and with the same shape as ``image``.
    """
    # Build the statistics on the input's own device (previously hard-coded
    # to CUDA, which broke CPU inputs). Shape (C, 1, 1) broadcasts against
    # both a single image and a batch.
    mean_t = torch.tensor(mean, dtype=torch.float32, device=image.device).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=torch.float32, device=image.device).view(-1, 1, 1)

    # clone() is differentiable, so the in-place ops below keep the autograd
    # graph to ``image`` while leaving the caller's tensor untouched.
    out = image.clone()
    out.sub_(mean_t).div_(std_t)
    return out
|
|
|
|
def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of one or more networks.

    Passing ``requires_grad=False`` (the default) freezes the networks so no
    gradients are computed for their parameters.

    Parameters:
        nets (module, or list/tuple of modules) -- the network(s) to update;
            ``None`` entries are skipped.
        requires_grad (bool) -- whether the networks' parameters should
            require gradients.
    """
    # Accept a single network as well as a list or tuple of them; previously
    # a tuple was treated as one "network" and failed on ``.parameters()``.
    if not isinstance(nets, (list, tuple)):
        nets = [nets]
    for net in nets:
        if net is None:
            continue
        for param in net.parameters():
            param.requires_grad = requires_grad