GanAttack/utils.py

83 lines
2.2 KiB
Python

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
def get_model(config):
    """Build the classifier selected by ``config`` and move it to the GPU.

    Parameters:
        config -- object exposing ``config.dataset`` and
                  ``config.classifier.model``; the latter must be one of
                  'resnet18', 'resnet101', 'mnasnet' or 'densenet121'

    Returns:
        the torchvision model with its final linear layer replaced by one
        producing ``num_class`` outputs, placed on the 'cuda' device

    Raises:
        ValueError -- if ``config.classifier.model`` is not a supported name
    """
    # Two outputs for the gender dataset, 307 for every other dataset.
    num_class = 2 if config.dataset == 'gender_dataset' else 307
    arch = config.classifier.model

    if arch == 'resnet18':
        # ResNets are built without pretrained weights; the head lives in `fc`.
        model = models.resnet18(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_class)
    elif arch == 'resnet101':
        model = models.resnet101(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_class)
    elif arch == 'mnasnet':
        # MNASNet's classifier is a Sequential; index 1 is its linear layer.
        model = models.mnasnet1_0(pretrained=True)
        num_features = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_features, num_class)
    elif arch == 'densenet121':
        model = models.densenet121(pretrained=True)
        num_features = model.classifier.in_features
        # DenseNet's head is `classifier`, not `fc`. Assigning to `fc` would
        # only attach an unused layer and leave the original output size.
        model.classifier = nn.Linear(num_features, num_class)
    else:
        raise ValueError(
            "Unsupported classifier model %r; expected one of 'resnet18', "
            "'resnet101', 'mnasnet', 'densenet121'" % (arch,)
        )

    # Transfer execution to GPU
    model = model.to('cuda')
    return model
def unnormalize(image):
    """Undo a mean=0.5 / std=0.5 normalization and clamp the result to [0, 1].

    Parameters:
        image (tensor) -- 3-channel image(s), laid out as (3, H, W) or
                          (N, 3, H, W)

    Returns:
        a new, detached CPU tensor holding ``image * 0.5 + 0.5`` limited to
        the range [0, 1]; the input tensor is left unmodified
    """
    # Per-channel statistics shaped (3, 1, 1) broadcast over both CHW and NCHW.
    mean = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1).float()
    std = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1).float()
    # Force a copy: detach().cpu() alone aliases the input when it already
    # lives on the CPU, so the in-place ops below would corrupt the caller's
    # tensor.
    image = image.detach().to('cpu', copy=True)
    image.mul_(std).add_(mean)
    image.clamp_(0, 1)
    return image
def normalize(image):
    """Normalize 3-channel image(s) with mean=0.5 and std=0.5 per channel.

    Parameters:
        image (tensor) -- 3-channel image(s), laid out as (3, H, W) or
                          (N, 3, H, W), on any device

    Returns:
        a new tensor on the same device as ``image`` holding
        ``(image - 0.5) / 0.5``; the input tensor is left unmodified
    """
    # Build the statistics on the input's own device instead of forcing CUDA,
    # so CPU tensors and tensors on a non-default GPU work as well.
    mean = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32,
                        device=image.device).view(3, 1, 1)
    std = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32,
                       device=image.device).view(3, 1, 1)
    # Work on a clone so the caller's tensor is not modified in place.
    image = image.clone()
    image -= mean
    image /= std
    return image
def set_requires_grad(nets, requires_grad=False):
    """Set requires_grad on every parameter of the given networks.

    Passing requires_grad=False avoids unnecessary gradient computations
    for those networks.

    Parameters:
        nets (network list) -- a single network, or a list/tuple of networks;
                               None entries are skipped
        requires_grad (bool) -- whether the networks require gradients or not
    """
    # Wrap a lone network so the loop below handles both call styles. Tuples
    # are treated as collections too, not as a single network.
    if not isinstance(nets, (list, tuple)):
        nets = [nets]
    for net in nets:
        if net is None:
            continue
        for param in net.parameters():
            param.requires_grad = requires_grad