139 lines
4.4 KiB
Python
139 lines
4.4 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from PIL import Image
|
|
from torchvision import models
|
|
import numpy as np
|
|
|
|
def get_model(config, device='cuda'):
    """Build a torchvision classifier whose output layer matches the dataset.

    Args:
        config: object exposing ``config.dataset`` (str) and
            ``config.classifier.model`` (one of ``'resnet18'``,
            ``'resnet101'``, ``'mnasnet'``, ``'densenet121'``).
        device: device the model is moved to; defaults to ``'cuda'``.

    Returns:
        The ``nn.Module`` with its final linear layer replaced so that it
        outputs ``num_class`` logits, already moved to ``device``.

    Raises:
        ValueError: if ``config.classifier.model`` is not a supported name.
    """
    # 'gender_dataset' is binary; every other dataset gets 307 classes.
    num_class = 2 if config.dataset == 'gender_dataset' else 307
    arch = config.classifier.model

    # NOTE(review): newer torchvision releases deprecate `pretrained=` in
    # favour of `weights=`; kept here for compatibility with the original.
    if arch == 'resnet18':
        model = models.resnet18(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_class)
    elif arch == 'resnet101':
        model = models.resnet101(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_class)
    elif arch == 'mnasnet':
        model = models.mnasnet1_0(pretrained=True)
        # The last element of MNASNet's `classifier` Sequential is the Linear head.
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_class)
    elif arch == 'densenet121':
        model = models.densenet121(pretrained=True)
        # DenseNet's head is `classifier`, not `fc`; assigning `fc` would
        # leave the original 1000-way head in use.
        model.classifier = nn.Linear(model.classifier.in_features, num_class)
    else:
        raise ValueError(f"Unsupported classifier model: {arch!r}")

    return model.to(device)
|
|
|
|
def unnormalize(image):
    """Undo a mean=0.5 / std=0.5 normalization and clamp to [0, 1] on the CPU.

    Computes ``image * 0.5 + 0.5`` per channel, then clamps the result to
    the range [0, 1].

    Args:
        image: tensor whose channel dimension (size 3) broadcasts against
            ``(1, 3, 1, 1)`` — presumably a ``(B, 3, H, W)`` batch; confirm
            against callers.

    Returns:
        A new, detached CPU tensor. ``image`` itself is not modified.
    """
    mean = torch.tensor([0.5, 0.5, 0.5]).view(-1, 3, 1, 1).float()
    std = torch.tensor([0.5, 0.5, 0.5]).view(-1, 3, 1, 1).float()

    # When `image` already lives on the CPU, `.detach().cpu()` returns a view
    # of the caller's storage; clone so the in-place ops below don't mutate it.
    image = image.detach().cpu().clone()
    image *= std
    image += mean
    image.clamp_(0, 1)
    return image
|
|
|
|
def normalize(image):
    """Normalize a tensor with per-channel mean=0.5 and std=0.5.

    Computes ``(image - 0.5) / 0.5``, which maps [0, 1] to [-1, 1].

    Args:
        image: tensor whose channel dimension (size 3) broadcasts against
            ``(1, 3, 1, 1)`` — presumably a ``(B, 3, H, W)`` batch.

    Returns:
        A new tensor on the same device as ``image``; the input is not
        modified.
    """
    # Build the constants on the input's device instead of hard-coding CUDA,
    # so CPU tensors work too (CUDA inputs behave exactly as before).
    mean = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32,
                        device=image.device).view(-1, 3, 1, 1)
    std = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32,
                       device=image.device).view(-1, 3, 1, 1)

    image = image.clone()  # keep the caller's tensor untouched
    image -= mean
    image /= std
    return image
|
|
|
|
def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of one or more networks.

    Calling it with ``requires_grad=False`` freezes the networks, so no
    gradients are computed for them.

    Parameters:
        nets (nn.Module | list | tuple) -- a single network, or a list/tuple
            of networks; ``None`` entries are skipped
        requires_grad (bool) -- whether the networks' parameters require gradients
    """
    # Accept tuples as well as lists; anything else is treated as one network.
    if not isinstance(nets, (list, tuple)):
        nets = [nets]
    for net in nets:
        if net is None:
            continue
        for param in net.parameters():
            param.requires_grad = requires_grad
|
|
|
|
def preprocess(images, channel_order='RGB'):
    """Convert a batch of [0, 255] images to float32 in [-1, 1], channels first.

    Args:
        images: numpy array of shape ``[batch_size, height, width, channel]``
            with values in [0, 255] (presumably dtype ``numpy.uint8``), where
            ``channel`` is 3 for colour or 1 for grayscale images.
        channel_order: ``'RGB'`` or ``'BGR'``. A 3-channel ``'BGR'`` input is
            flipped to RGB; any other combination leaves channels as given.

    Returns:
        ``numpy.float32`` array of shape ``[batch_size, channel, height, width]``;
        inputs in [0, 255] map linearly to [-1, 1].
    """
    min_val, max_val = -1.0, 1.0

    # Only colour images have a channel order to fix; grayscale passes through.
    if images.shape[-1] == 3 and channel_order == 'BGR':
        images = images[:, :, :, ::-1]

    images = images / 255.0 * (max_val - min_val) + min_val
    # BHWC -> BCHW
    return images.astype(np.float32).transpose(0, 3, 1, 2)
|
|
|
|
def postprocess(images):
    """Turn a [-1, 1] RGB ``(B, C, H, W)`` tensor into a uint8 BGR ``(B, H, W, C)`` array.

    Values are rescaled with ``(x + 1) * 255 / 2``, shifted by 0.5, clipped
    to [0, 255] and truncated to uint8, which rounds to the nearest integer.
    The channel reordering indexes channels 2, 1, 0, so ``C`` must be at
    least 3.
    """
    arr = images.detach().cpu().numpy()
    scaled = (arr + 1.) * 255. / 2.
    as_uint8 = np.clip(scaled + 0.5, 0, 255).astype(np.uint8)

    channels_last = as_uint8.transpose(0, 2, 3, 1)
    rgb_to_bgr = [2, 1, 0]
    # Fancy indexing copies, so the result is not a negative-stride view.
    return channels_last[:, :, :, rgb_to_bgr]
|
|
|
|
def Lanczos_resizing(image_target, resizing_tuple=(256, 256), device='cuda'):
    """Resize a batch of images with PIL's Lanczos filter.

    Each image is quantized to uint8 in [0, 255], resized with
    ``Image.resize(..., Image.LANCZOS)``, and converted back with
    ``preprocess``.

    Args:
        image_target: RGB tensor of shape ``(B, 3, H, W)`` with values in
            [-1, 1].
        resizing_tuple: target size forwarded to ``PIL.Image.Image.resize``,
            which interprets it as ``(width, height)``.
        device: device of the returned tensor; defaults to ``'cuda'`` as in
            the original implementation.

    Returns:
        float32 RGB tensor ``(B, 3, height, width)`` in [-1, 1] on ``device``.
    """
    # detach() so tensors that require grad can be converted to numpy;
    # the arithmetic below always allocates new arrays, so no clone is needed.
    images = image_target.detach().cpu().numpy()
    images = (images + 1.) * 255. / 2.
    images = np.clip(images + 0.5, 0, 255).astype(np.uint8)
    images = images.transpose(0, 2, 3, 1)  # BCHW -> BHWC for PIL

    resized = []
    for img in images:
        pil_img = Image.fromarray(img)  # PIL, 0~255, uint8, RGB, HWC
        # Use `Image.LANCZOS`: only `Image` is imported from PIL, so the name
        # `PIL` is not bound in this module.
        arr = np.array(pil_img.resize(resizing_tuple, Image.LANCZOS))
        resized.append(torch.from_numpy(preprocess(arr[np.newaxis, :])))
    return torch.cat(resized, dim=0).to(device)