import os
import sys

sys.path.append('.')
sys.path.append('..')

import torch
import torchvision.transforms as transforms
import tqdm
from PIL import Image
from torch.utils.data import DataLoader

from criteria.lpips.lpips import LPIPS
from datasets.inference_dataset import InversionDataset
from inference import TwoStageInference
from options.test_options import TestOptions


def main():
    """Run two-stage GAN inversion over a directory (or a single image) and
    report reconstruction metrics — MSE, PSNR, and LPIPS — at 256x256.

    Reads all configuration from ``TestOptions``; prints per-batch and final
    aggregate metrics to stdout. Requires a CUDA device.
    """
    opts = TestOptions().parse()
    if opts.checkpoint_path is None:
        # No explicit checkpoint given: let the inference module resume from
        # its latest saved checkpoint.
        opts.auto_resume = True

    inversion = TwoStageInference(opts)
    lpips_cri = LPIPS(net_type='alex').cuda().eval()

    def float2uint2float(x):
        # Round-trip a [-1, 1] float tensor through uint8 quantization so the
        # metrics reflect what a saved 8-bit image would actually score.
        return (((x + 1) / 2 * 255.).clamp(min=0, max=255).to(torch.uint8).float().div(255.) - 0.5) / 0.5

    if opts.output_resolution is not None and len(opts.output_resolution) == 1:
        # A single value means a square output. Bug fix: duplicate the scalar
        # entry, not the one-element sequence itself (the original produced a
        # tuple of two one-element lists).
        opts.output_resolution = (opts.output_resolution[0], opts.output_resolution[0])

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    transform_no_resize = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    if os.path.isdir(opts.test_dataset_path):
        dataset = InversionDataset(root=opts.test_dataset_path, transform=transform,
                                   transform_no_resize=transform_no_resize)
        dataloader = DataLoader(dataset,
                                batch_size=opts.test_batch_size,
                                shuffle=False,
                                num_workers=int(opts.test_workers),
                                drop_last=False)
    else:
        # Single-image mode: build one batch of size 1 shaped like a
        # DataLoader batch (resized tensor, paths, full-resolution tensor).
        img = Image.open(opts.test_dataset_path)
        img = img.convert('RGB')
        img_aug = transform(img)
        img_aug_no_resize = transform_no_resize(img)
        dataloader = [(img_aug[None], [opts.test_dataset_path], img_aug_no_resize[None])]

    # Running totals. LPIPS is accumulated as a sample-weighted scalar sum;
    # MSE/PSNR are concatenated per-image so the final mean is exact.
    lpips_total, count = 0., 0
    mse = torch.zeros([0]).cuda()
    psnr = torch.zeros([0]).cuda()
    for input_batch in tqdm.tqdm(dataloader):
        # Inversion
        images_resize, img_paths, images = input_batch
        images_resize, images = images_resize.cuda(), images.cuda()
        count += len(img_paths)
        emb_images, emb_codes, emb_info, refine_images, refine_codes, refine_info = \
            inversion.inverse(images, images_resize, img_paths)

        # Prefer the second (refinement) stage output when it was produced.
        images_inv = refine_images if refine_images is not None else emb_images

        # Evaluation: quantize, resize to the 256x256 metric resolution, score.
        images_inv = float2uint2float(images_inv)
        images_inv_resize = transforms.Resize((256, 256), antialias=True)(images_inv)
        batch_mse, batch_psnr = calculate_mse_and_psnr(images_inv_resize, images_resize)
        batch_lpips = lpips_cri(images_inv_resize, images_resize)

        mse = torch.cat([mse, batch_mse])
        psnr = torch.cat([psnr, batch_psnr])
        lpips_total += len(img_paths) * batch_lpips.item()
        print(f'Batch result: MSE {batch_mse.mean().item()}, PSNR {batch_psnr.mean().item()}')

    print('MSE ', mse.mean().item())
    print('PSNR:', psnr.mean().item())
    print('LPIPS:', lpips_total / count)


def calculate_mse_and_psnr(img1, img2):
    """Return per-image MSE and PSNR for two batched image tensors.

    Both inputs are (N, C, H, W); the error is averaged over the channel and
    spatial axes so each of the N images gets its own score.
    """
    squared_error = (img1 - img2) ** 2
    mse = squared_error.mean(dim=(1, 2, 3))
    # NOTE(review): standard PSNR uses MAX_I**2 (which would be 4 for a
    # [-1, 1] range); this file uses 2 — confirm intended before changing.
    psnr = 10 * torch.log10(2 / mse)
    return mse, psnr


# Script entry point: the garbled filler lines that followed the guard in the
# extracted source are removed so the module is valid Python again.
if __name__ == '__main__':
    main()