# GanAttack/utils.py
#
# 139 lines
# 4.4 KiB
# Python

import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torchvision import models
import numpy as np
def get_model(config, device='cuda'):
    """Build the classifier selected by ``config`` and move it to ``device``.

    The output size is 2 classes when ``config.dataset == 'gender_dataset'``
    and 307 classes for any other dataset. The final classification layer of
    the chosen torchvision backbone is replaced by a fresh ``nn.Linear`` with
    that many outputs.

    Args:
        config: Object exposing ``config.dataset`` and
            ``config.classifier.model``. The latter must be one of
            ``'resnet18'``, ``'resnet101'``, ``'mnasnet'``, ``'densenet121'``.
        device: Device the model is moved to. Defaults to ``'cuda'``, which
            was the previously hard-coded target.

    Returns:
        The ``nn.Module`` on ``device``.

    Raises:
        ValueError: If ``config.classifier.model`` is not a supported name.
    """
    num_class = 2 if config.dataset == 'gender_dataset' else 307
    arch = config.classifier.model

    if arch == 'resnet18':
        # Randomly initialised ResNet-18; its head takes 512 features.
        model = models.resnet18(pretrained=False)
        model.fc = nn.Linear(512, num_class)
    elif arch == 'resnet101':
        # Randomly initialised ResNet-101; its head takes 2048 features.
        model = models.resnet101(pretrained=False)
        model.fc = nn.Linear(2048, num_class)
    elif arch == 'mnasnet':
        # Pretrained MNASNet 1.0; the Linear head sits at classifier[1].
        model = models.mnasnet1_0(pretrained=True)
        num_features = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_features, num_class)
    elif arch == 'densenet121':
        # Pretrained DenseNet-121. Its head is ``classifier`` (there is no
        # ``fc``); assigning ``model.fc`` would leave the original 1000-way
        # head in the forward pass.
        model = models.densenet121(pretrained=True)
        num_features = model.classifier.in_features
        model.classifier = nn.Linear(num_features, num_class)
    else:
        raise ValueError(f"Unsupported classifier model: {arch!r}")

    return model.to(device)
def unnormalize(image):
    """Map a normalised image tensor from [-1, 1] back to [0, 1] on the CPU.

    Computes ``image * 0.5 + 0.5`` per channel and clamps the result to [0, 1].

    Args:
        image: 4-D ``(B, 3, H, W)`` tensor. The ``(1, 3, 1, 1)`` statistics
            and the in-place ops require this layout.

    Returns:
        A new detached CPU tensor. ``image`` itself is never modified.
    """
    mean = torch.tensor([0.5, 0.5, 0.5]).view(-1, 3, 1, 1).float()
    std = torch.tensor([0.5, 0.5, 0.5]).view(-1, 3, 1, 1).float()
    # For a tensor already on the CPU, .detach().cpu() returns a view sharing
    # the caller's storage. Force a copy so the in-place ops below cannot
    # overwrite the input (or the data of a requires_grad leaf).
    image = image.detach().to('cpu', copy=True)
    image *= std
    image += mean
    return image.clamp_(0, 1)
def normalize(image):
    """Map an image tensor from [0, 1] to [-1, 1] using mean = std = 0.5.

    Args:
        image: Tensor broadcastable against ``(1, 3, 1, 1)`` per-channel
            statistics. This presumably means a ``(B, 3, H, W)`` batch
            (TODO confirm callers). It may live on any device.

    Returns:
        A new tensor ``(image - 0.5) / 0.5`` on ``image``'s device. The input
        is not modified and its autograd history is preserved.
    """
    # Build the statistics on the input's device instead of hard-coding
    # .cuda(), so CPU tensors and non-default GPUs work as well.
    mean = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32,
                        device=image.device).view(-1, 3, 1, 1)
    std = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32,
                       device=image.device).view(-1, 3, 1, 1)
    return (image - mean) / std
def set_requires_grad(nets, requires_grad=False):
    """Toggle ``requires_grad`` on every parameter of one or more networks.

    Typically used with ``requires_grad=False`` to freeze networks and skip
    unnecessary gradient computation.

    Parameters:
        nets -- a single network, or a ``list`` of networks; ``None`` entries
                are skipped. Any non-list argument is treated as one network.
        requires_grad (bool) -- value assigned to each parameter's
                ``requires_grad`` flag.
    """
    candidates = nets if isinstance(nets, list) else [nets]
    for module in (m for m in candidates if m is not None):
        for weight in module.parameters():
            weight.requires_grad = requires_grad
def preprocess(images, channel_order='RGB'):
    """Convert a batch of raw images into the float32 layout the models use.

    Args:
        images: NumPy array of shape ``(batch_size, height, width, channel)``
            with values in [0, 255], normally dtype ``numpy.uint8``
            (``channel`` is 3 for colour or 1 for grayscale).
        channel_order: ``'BGR'`` reverses the channel axis so the result is
            RGB. Any other value (default ``'RGB'``) leaves channels as-is.

    Returns:
        ``numpy.float32`` array of shape ``(batch_size, channel, height,
        width)``, linearly rescaled from [0, 255] to [-1, 1].
    """
    # input : numpy, np.uint8, 0~255, RGB, BHWC
    # output : numpy, np.float32, -1~1, RGB, BCHW
    low, high = -1.0, 1.0
    if channel_order == 'BGR':
        images = images[:, :, :, ::-1]
    rescaled = images / 255.0 * (high - low) + low
    return np.transpose(rescaled.astype(np.float32), (0, 3, 1, 2))
def postprocess(images):
    """Convert a model-space image tensor into a uint8 BGR NumPy batch.

    Values are rescaled from [-1, 1] to [0, 255]. Adding 0.5 before clipping
    and the uint8 cast rounds to the nearest integer. The channel axis is then
    moved last, and channels are taken in the order 2, 1, 0 (RGB -> BGR).
    """
    # input : tensor, -1~1, RGB, BCHW
    # output : np.uint8, 0~255, BGR, BHWC
    host = images.detach().cpu().numpy()
    pixel_range = (host + 1.) * 255. / 2.
    as_bytes = np.clip(pixel_range + 0.5, 0, 255).astype(np.uint8)
    channels_last = np.transpose(as_bytes, (0, 2, 3, 1))
    rgb_to_bgr = [2, 1, 0]
    return channels_last[:, :, :, rgb_to_bgr]
def Lanczos_resizing(image_target, resizing_tuple=(256,256), device='cuda'):
    """Resize a batch of images with Pillow's Lanczos filter.

    Args:
        image_target: ``(B, C, H, W)`` RGB tensor with values in [-1, 1].
            It is detached before conversion, so no gradient flows through
            the resize.
        resizing_tuple: Target size handed to ``PIL.Image.resize``, which
            interprets it as ``(width, height)``.
        device: Device of the returned tensor. Defaults to ``'cuda'``, which
            was the previously hard-coded target.

    Returns:
        ``(B, C, height, width)`` float32 RGB tensor in [-1, 1] on ``device``.
    """
    # input : -1~1, RGB, BCHW, Tensor
    # output : -1~1, RGB, BCHW, Tensor
    # detach() is required: .numpy() raises on a tensor that requires grad.
    batch = image_target.detach().cpu().numpy()
    # [-1, 1] -> [0, 255]; +0.5 before the uint8 cast rounds to nearest.
    batch = (batch + 1.) * 255. / 2.
    batch = np.clip(batch + 0.5, 0, 255).astype(np.uint8)
    batch = batch.transpose(0, 2, 3, 1)  # BCHW -> BHWC, the layout PIL takes

    if batch.shape[0] == 0:
        # Nothing to resize (torch.cat would fail on an empty list).
        width, height = resizing_tuple
        return torch.empty((0, batch.shape[3], height, width),
                           dtype=torch.float32, device=device)

    # PIL, 0~255, uint8, RGB, HWC per frame. Image.LANCZOS is used because
    # only ``Image`` (not the ``PIL`` package) is imported in this module.
    resized = [
        np.array(Image.fromarray(frame).resize(resizing_tuple, Image.LANCZOS))
        for frame in batch
    ]
    # preprocess() maps uint8 BHWC [0, 255] back to float32 BCHW [-1, 1];
    # running it once on the stacked batch is equivalent to per-frame calls.
    return torch.from_numpy(preprocess(np.stack(resized))).to(device).contiguous()