# GanAttack/utils.py

import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torchvision import models
import numpy as np
import PIL
def get_model(config, device='cuda'):
    """Build a torchvision classifier whose output layer matches the dataset.

    Args:
        config: Configuration object; reads ``config.dataset`` and
            ``config.classifier.model``. ``'gender_dataset'`` yields a
            2-way output layer, any other dataset a 307-way one.
        device: Device the model is moved to (default ``'cuda'``, as before).

    Returns:
        The model with its final linear layer replaced, moved to ``device``.

    Raises:
        ValueError: If ``config.classifier.model`` is not one of
            ``'resnet18'``, ``'resnet101'``, ``'mnasnet'``, ``'densenet121'``.
    """
    num_class = 2 if config.dataset == 'gender_dataset' else 307
    arch = config.classifier.model

    if arch == 'resnet18':
        # Randomly initialised backbone; replace the final fully connected layer.
        model = models.resnet18(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_class)
    elif arch == 'resnet101':
        model = models.resnet101(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_class)
    elif arch == 'mnasnet':
        # Pretrained backbone; only the last layer of ``classifier`` is replaced.
        model = models.mnasnet1_0(pretrained=True)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_class)
    elif arch == 'densenet121':
        model = models.densenet121(pretrained=True)
        # DenseNet's head is ``classifier``, not ``fc``. Assigning ``model.fc``
        # (as earlier code did) adds an unused module and leaves the original
        # 1000-way head in place. NOTE: checkpoints saved with that earlier
        # code contain ``fc.*`` keys and a 1000-way classifier, so they will
        # not load strictly into this model.
        model.classifier = nn.Linear(model.classifier.in_features, num_class)
    else:
        raise ValueError(
            f"Unsupported classifier model {arch!r}; expected one of "
            "'resnet18', 'resnet101', 'mnasnet', 'densenet121'"
        )

    return model.to(device)
def unnormalize(image):
    """Undo the (x - 0.5) / 0.5 normalization and clamp to [0, 1] on the CPU.

    Computes ``image * 0.5 + 0.5`` per channel and clips the result to [0, 1].

    Args:
        image: Tensor of shape (B, 3, H, W). The per-channel statistics are
            viewed as (1, 3, 1, 1) and applied in place to a copy, so the
            input must already have that 4-D, 3-channel shape.

    Returns:
        A new CPU tensor detached from the autograd graph. The input tensor
        is never modified.
    """
    mean = torch.tensor([0.5, 0.5, 0.5]).view(-1, 3, 1, 1).float()
    std = torch.tensor([0.5, 0.5, 0.5]).view(-1, 3, 1, 1).float()
    # .detach().cpu() shares storage with the input when it is already on the
    # CPU, so clone() is required before the in-place ops below.
    image = image.detach().cpu().clone()
    image *= std
    image += mean
    return image.clamp_(0, 1)
def normalize(image):
    """Apply per-channel (x - 0.5) / 0.5, i.e. map [0, 1] values to [-1, 1].

    Args:
        image: Tensor of shape (B, 3, H, W). The statistics are created on the
            same device as ``image``, so both CPU and CUDA inputs work.

    Returns:
        A new tensor on ``image``'s device; the input is not modified. Because
        only ``clone`` and in-place arithmetic on the clone are used, gradients
        still flow back to ``image``.
    """
    mean = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device=image.device).view(-1, 3, 1, 1)
    std = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device=image.device).view(-1, 3, 1, 1)
    image = image.clone()
    image -= mean
    image /= std
    return image
def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of one or more networks.

    Passing ``requires_grad=False`` freezes the networks so that no gradients
    are computed for their weights.

    Args:
        nets: A single network, or a list/tuple of networks. ``None`` entries
            (and ``nets=None``) are skipped.
        requires_grad: Value assigned to each parameter's ``requires_grad``.
    """
    if not isinstance(nets, (list, tuple)):
        nets = [nets]
    for net in nets:
        if net is None:
            continue
        for param in net.parameters():
            param.requires_grad = requires_grad
def preprocess(images, channel_order='RGB'):
    """Convert a BHWC image batch in [0, 255] to float32 BCHW in [-1, 1].

    Args:
        images: Array-like of shape [batch_size, height, width, channel]
            with values in [0, 255] (typically ``numpy.uint8``).
        channel_order: Channel order of ``images``. With ``'BGR'`` the
            channel axis is reversed to produce RGB; any other value leaves
            the channels as they are (i.e. they are treated as RGB).

    Returns:
        ``numpy.float32`` array of shape [batch_size, channel, height, width],
        values mapped linearly from [0, 255] to [-1, 1].
    """
    min_val, max_val = -1.0, 1.0
    images = np.asarray(images)
    if channel_order == 'BGR':
        images = images[:, :, :, ::-1]
    images = images / 255.0 * (max_val - min_val) + min_val
    return images.astype(np.float32).transpose(0, 3, 1, 2)
def postprocess(images):
    """Turn a generator output tensor into a displayable uint8 image batch.

    Input: tensor with values in [-1, 1], RGB, BCHW layout.
    Output: ``numpy.uint8`` array in [0, 255], BGR, BHWC layout.
    """
    array = images.detach().cpu().numpy()
    # Map [-1, 1] -> [0, 255]; adding 0.5 before the truncating uint8 cast
    # rounds to the nearest integer.
    scaled = np.clip((array + 1.) * 255. / 2. + 0.5, 0, 255)
    bhwc = scaled.astype(np.uint8).transpose(0, 2, 3, 1)
    # Reorder channels RGB -> BGR.
    return bhwc[..., [2, 1, 0]]
def Lanczos_resizing(image_target, resizing_tuple=(256, 256), device='cuda'):
    """Resize a batch of [-1, 1] RGB images with PIL's Lanczos filter.

    Each image is quantised to uint8 (rounded to nearest), resized with
    ``PIL.Image.LANCZOS`` and mapped back to float32 [-1, 1] via
    ``preprocess``.

    Args:
        image_target: Tensor of shape (B, C, H, W), RGB, values in [-1, 1].
            It may require grad; the result is detached from the graph
            because the resize goes through PIL.
        resizing_tuple: Target size forwarded to ``PIL.Image.resize``, which
            interprets it as ``(width, height)``.
        device: Device of the returned tensor (default ``'cuda'``, as before).

    Returns:
        Float32 tensor of shape (B, C, height, width) in [-1, 1] on ``device``.
    """
    # detach() lets tensors that require grad (e.g. generator output) be
    # converted with .numpy(); the arithmetic below allocates new arrays, so
    # the caller's tensor is never modified.
    pixels = image_target.detach().cpu().numpy()
    pixels = (pixels + 1.) * 255. / 2.
    pixels = np.clip(pixels + 0.5, 0, 255).astype(np.uint8)
    pixels = pixels.transpose(0, 2, 3, 1)  # BCHW -> BHWC, as PIL expects HWC

    resized = []
    for frame in pixels:
        img = Image.fromarray(frame)  # PIL, 0~255, uint8, RGB, HWC
        img = img.resize(resizing_tuple, PIL.Image.LANCZOS)
        resized.append(preprocess(np.array(img)[np.newaxis, :]))

    # Concatenate on the host and transfer once instead of once per image.
    return torch.from_numpy(np.concatenate(resized, axis=0)).to(device)