# GanAttack/GanInverter/scripts/test.py
import os
import sys
sys.path.append('.')
sys.path.append('..')
import torch
import tqdm
from PIL import Image
from torch.utils.data import DataLoader
from datasets.inference_dataset import InversionDataset
from inference import TwoStageInference
from options.test_options import TestOptions
import torchvision.transforms as transforms
from criteria.lpips.lpips import LPIPS
def main():
    """Run two-stage GAN inversion over a dataset directory (or a single
    image file) and report reconstruction metrics: per-image MSE and PSNR,
    plus dataset-averaged LPIPS, all computed at 256x256 resolution.
    """
    opts = TestOptions().parse()
    if opts.checkpoint_path is None:
        # No explicit checkpoint given: fall back to auto-resuming the latest one.
        opts.auto_resume = True
    inversion = TwoStageInference(opts)
    lpips_cri = LPIPS(net_type='alex').cuda().eval()
    # Round-trip [-1, 1] floats through uint8 so the metrics match what a
    # saved 8-bit image would actually contain.
    float2uint2float = lambda x: (((x + 1) / 2 * 255.).clamp(min=0, max=255).to(torch.uint8).float().div(255.) - 0.5) / 0.5
    if opts.output_resolution is not None and len(opts.output_resolution) == 1:
        # BUG FIX: duplicate the single value as (H, W). The original built a
        # pair of length-1 lists instead of a pair of ints.
        opts.output_resolution = (opts.output_resolution[0], opts.output_resolution[0])
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    transform_no_resize = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    if os.path.isdir(opts.test_dataset_path):
        dataset = InversionDataset(root=opts.test_dataset_path, transform=transform,
                                   transform_no_resize=transform_no_resize)
        dataloader = DataLoader(dataset,
                                batch_size=opts.test_batch_size,
                                shuffle=False,
                                num_workers=int(opts.test_workers),
                                drop_last=False)
    else:
        # Single-image mode: wrap the image into a one-element iterable that
        # mimics the dataloader's (resized, paths, full-res) batch tuples.
        img = Image.open(opts.test_dataset_path)
        img = img.convert('RGB')
        img_aug = transform(img)
        img_aug_no_resize = transform_no_resize(img)
        dataloader = [(img_aug[None], [opts.test_dataset_path], img_aug_no_resize[None])]
    # lpips_sum accumulates (batch mean * batch size) so the final division by
    # `count` yields a correct per-image average even with a ragged last batch.
    lpips_sum, count = 0., 0
    mse = torch.zeros([0]).cuda()
    psnr = torch.zeros([0]).cuda()
    for input_batch in tqdm.tqdm(dataloader):
        # Inversion
        images_resize, img_paths, images = input_batch
        images_resize, images = images_resize.cuda(), images.cuda()
        count += len(img_paths)
        emb_images, emb_codes, emb_info, refine_images, refine_codes, refine_info = \
            inversion.inverse(images, images_resize, img_paths)
        # Prefer the second-stage (refined) output when available.
        if refine_images is not None:
            images_inv, codes = refine_images, refine_codes
        else:
            images_inv, codes = emb_images, emb_codes
        # Evaluation
        images_inv = float2uint2float(images_inv)
        images_inv_resize = transforms.Resize((256, 256), antialias=True)(images_inv)
        batch_mse, batch_psnr = calculate_mse_and_psnr(images_inv_resize, images_resize)
        batch_lpips = lpips_cri(images_inv_resize, images_resize)
        mse = torch.cat([mse, batch_mse])
        psnr = torch.cat([psnr, batch_psnr])
        lpips_sum += len(img_paths) * batch_lpips.item()
        print(f'Batch result: MSE {batch_mse.mean().item()}, PSNR {batch_psnr.mean().item()}')
    print('MSE ', mse.mean().item())
    print('PSNR:', psnr.mean().item())
    print('LPIPS:', lpips_sum / count)
def calculate_mse_and_psnr(img1, img2):
    """Compute per-image MSE and PSNR for batches of images in [-1, 1].

    Args:
        img1: tensor of shape (N, C, H, W), values in [-1, 1].
        img2: tensor of the same shape as ``img1``.

    Returns:
        (mse, psnr): two tensors of shape (N,). MSE is averaged over the
        channel and spatial dimensions of each image.
    """
    mse = ((img1 - img2) ** 2).mean(dim=[1, 2, 3])
    # BUG FIX: PSNR = 10 * log10(R^2 / MSE) where R is the data range.
    # For images in [-1, 1], R = 2 so R^2 = 4; the original used R instead
    # of R^2, under-reporting PSNR by ~3 dB.
    psnr = 10 * torch.log10(4 / mse)
    return mse, psnr
# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()