# DHR inference: GAN inversion via Domain-specific Hybrid Refinement.
# Pipeline: coarse latent optimization -> domain-specific segmentation
# (face parsing + superpixels + per-region LPIPS thresholds) -> hybrid
# refinement that optimizes a generator feature offset (background) and
# the generator weights (face region) with two separate Adam optimizers.
import math
import os
import copy

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from skimage.segmentation import slic

from criteria.lpips.lpips import LPIPS
from inference.inference import BaseInference
from models.stylegan2.model import Generator
from utils.train_utils import load_train_checkpoint
from .optim_infer import OptimizerInference
import utils.facer.facer as facer


class DHRInference(BaseInference):
    """Domain-specific Hybrid Refinement (DHR) inversion.

    Wraps a coarse optimizer-based inversion (``OptimizerInference``) and
    refines its result by jointly tuning a StyleGAN2 feature offset and the
    decoder weights, guided by a face-parsing/superpixel segmentation mask.
    Only batch size 1 is supported (asserted in :meth:`inverse`).
    """

    def __init__(self, opts, **kwargs):
        """Build losses, load the checkpoint, and set up detector/parser.

        Args:
            opts: options namespace; must provide ``resolution`` plus the
                ``dhr_*`` hyperparameters read in :meth:`inverse`.
            **kwargs: accepted for interface compatibility; unused here.
        """
        super(DHRInference, self).__init__()
        self.opts = opts
        self.device = 'cuda'
        self.opts.device = self.device
        # Number of StyleGAN style inputs for this resolution
        # (e.g. resolution 1024 -> 18 styles).
        self.opts.n_styles = int(math.log(opts.resolution, 2)) * 2 - 2
        # initial loss
        self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval()
        self.checkpoint = load_train_checkpoint(self.opts)
        # coarse inversion
        self.coarse_inv = OptimizerInference(opts)
        # parsing
        self.detector = facer.face_detector('retinaface/mobilenet', device=self.device)
        # NOTE(review): f-string has no placeholder (plain string would do);
        # kept byte-identical since it is a model-name literal.
        self.parser_celeb = facer.face_parser(f'farl/celebm/448', device=self.device)

    def parsing(self, img: torch.Tensor) -> torch.Tensor:
        """Run face detection + CelebM parsing on a single image.

        Args:
            img: image tensor; assumed to be in [-1, 1] given the
                ``img / 2 + 0.5`` rescale below — TODO confirm with caller.

        Returns:
            Per-class soft segmentation map ``(num_classes, H, W)`` — the
            softmaxed logits of the highest-scoring detected face.
        """
        # Rescale from [-1, 1] to [0, 255] uint8 as expected by facer.
        img = ((img / 2 + 0.5) * 255.).to(torch.uint8)
        with torch.no_grad():
            faces = self.detector(img, threshold=0.1)
            scores = faces['scores']
            # Keep only the most confident detection (batch of one face).
            max_idx = scores.argmax()
            faces = {k: v[max_idx][None] for k, v in faces.items()}
            # deepcopy: the parser may mutate the faces dict in place —
            # presumably why a copy is passed; verify against facer API.
            faces_celeb = self.parser_celeb(img, copy.deepcopy(faces))
        # Soft class probabilities for the single face; stays on device.
        parsing_result = faces_celeb['seg']['logits'].softmax(dim=1)[0]
        return parsing_result

    def inverse(self, images, images_resize, image_name, emb_codes, emb_images, emb_info):
        """Invert one image with domain-specific hybrid refinement.

        Args:
            images: target image batch (batch size must be 1).
            images_resize: resized images, forwarded to the coarse inverter.
            image_name: list of image paths; ``image_name[0]`` is used both
                to read the image from disk and to key the on-disk caches.
            emb_codes: initial latent codes (W+) for the decoder.
            emb_images, emb_info: unused here; kept for interface parity
                with the other inference classes.

        Returns:
            Tuple ``(images, emb_codes, [refine_info])`` where ``images`` is
            the refined reconstruction and ``refine_info`` carries the tuned
            decoder weights, feature offset, masks and intermediate results.
        """
        assert images.shape[0] == 1, 'DHR is only supported for batchsize 1.'
        refine_info = dict()

        # initialize decoder and regularization decoder
        feature_idx = self.opts.dhr_feature_idx  # 11
        # Spatial resolution of the feature map at ``feature_idx``
        # (two generator layers per resolution level; idx 11 -> 128).
        res = [4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 128, 128, 256, 256, 512, 512, 1024, 1024][feature_idx]
        decoder = Generator(self.opts.resolution, 512, 8).to(self.device)
        decoder.train()
        if self.checkpoint is not None:
            # strict=False: checkpoint may omit buffers such as noise maps.
            decoder.load_state_dict(self.checkpoint['decoder'], strict=False)
        else:
            decoder_checkpoint = torch.load(self.opts.stylegan_weights, map_location='cpu')
            decoder.load_state_dict(decoder_checkpoint['g_ema'])

        # Per-parsing-class LPIPS-delta thresholds used to classify
        # superpixels as in-domain (refinable) vs out-of-domain.
        # domain-specific segmentation
        bg_score = 0.6  # no use
        fg_score = self.opts.dhr_theta1
        top_score = self.opts.dhr_theta2
        # Index layout follows the CelebM class order: 4 leading classes,
        # 10 facial-component classes, 5 trailing classes — TODO confirm.
        score_thr = [bg_score, bg_score, fg_score, bg_score] + [top_score] * 10 + [bg_score, bg_score, bg_score, bg_score, bg_score]

        # Domain-Specific Segmentation
        if os.path.exists(f'{self.opts.output_dir}/mask_refine_pt/{os.path.basename(image_name[0])[:-4]}.pt'):
            # load segmentation if exist
            m = torch.load(f'{self.opts.output_dir}/mask_refine_pt/{os.path.basename(image_name[0])[:-4]}.pt')
            m = m.cuda()
        else:
            # coarse inversion
            coarse_image, _, delta = self.coarse_inv.inverse(images, images_resize, image_name, return_lpips=True)
            refine_info['coarse_inv'] = coarse_image
            # face parsing
            parsing_result = self.parsing(images)
            # Background-ish classes (indices 0, 3 and the last three) —
            # presumably background/clothing/etc.; verify against CelebM.
            mask_bg = parsing_result[[0, 3, -1, -2, -3]].sum(dim=0)
            # Foreground = where background probability is below 0.5.
            mask_parsing = (mask_bg < 0.5).float()
            mask_parsing = mask_parsing[None, None]
            # Superpixel
            superpixel = slic(cv2.imread(image_name[0]), n_segments=200, compactness=30, sigma=1)
            mask_sp = torch.zeros((self.opts.resolution, self.opts.resolution))
            for sp_i in range(1, 1 + int(superpixel.max())):
                ds = []
                mask = torch.Tensor(superpixel == sp_i).float()[None, None].cuda()
                # Average each per-layer LPIPS delta map over the superpixel.
                for idx, d in enumerate(delta):
                    shape = d.shape[2]
                    m = torch.nn.functional.interpolate(mask, (shape, shape))
                    ds.append((d * m).sum() / m.sum())
                # Dominant parsing class within this superpixel selects
                # which threshold applies.
                parsing_idx = (mask[0] * parsing_result).sum(dim=[-1, -2]).argmax().item()
                # Superpixel is refinable if its summed coarse-inversion
                # error is below the class-specific threshold.
                mask_sp[superpixel == sp_i] = (sum(ds) < score_thr[parsing_idx]).float()
            # domain-specific segmentation result
            m = mask_sp[None, None].cuda() * mask_parsing
            refine_info['mask_refine_pt'] = m.clone()
            refine_info['mask_refine'] = (m[0, 0].cpu().numpy() * 255.).astype(np.uint8)
            refine_info['mask_superpixel'] = (mask_sp.clone().numpy() * 255.).astype(np.uint8)
            refine_info['mask_parsing'] = (mask_parsing[0, 0].cpu().numpy() * 255.).astype(np.uint8)
        # Downsample the full-resolution mask to the feature resolution;
        # keep the original-resolution copy for image-space losses.
        mask = torch.nn.AdaptiveAvgPool2d((res, res))(m)
        mask_ori = m.clone()

        # Hybrid Refinement Modulation
        if os.path.exists(
                f'{self.opts.output_dir}/weight/{os.path.basename(image_name[0])[:-4]}.pt') and os.path.exists(
                f'{self.opts.output_dir}/feature/{os.path.basename(image_name[0])[:-4]}.pt'):
            # load weight and feature if exist.
            weight = torch.load(f'{self.opts.output_dir}/weight/{os.path.basename(image_name[0])[:-4]}.pt',
                                map_location='cpu')
            decoder.load_state_dict(weight)
            offset = torch.load(f'{self.opts.output_dir}/feature/{os.path.basename(image_name[0])[:-4]}.pt')
        else:
            # initialize modulated feature and optimizer
            with torch.no_grad():
                gen_images, _, offset = decoder([emb_codes], feature_idx=feature_idx, mask=None,
                                                input_is_latent=True, randomize_noise=False,
                                                return_featuremap=True)
            # offset already lives on CUDA (decoder is on device), so
            # .cuda() is a no-op that keeps the Parameter a leaf tensor.
            offset = torch.nn.Parameter(offset.detach()).cuda()
            # Two optimizers: one for the feature offset (background fit),
            # one for the decoder weights (face fit).
            optimizer_f = optim.Adam([offset], lr=self.opts.dhr_feature_lr)
            optimizer = optim.Adam(decoder.parameters(), lr=self.opts.dhr_weight_lr)
            for i in range(self.opts.dhr_feature_step):
                gen_images, _ = decoder([emb_codes], feature_idx=feature_idx, offset=offset, mask=mask,
                                        input_is_latent=True, randomize_noise=False)
                # calculate loss
                # keep_res=True: LPIPS returns per-layer spatial maps here
                # rather than a scalar — assumed from the iteration below.
                loss_lpips = self.lpips_loss(gen_images, images, keep_res=True)
                # NOTE(review): loss_mse is never used below (dead compute).
                loss_mse = (F.mse_loss(gen_images, images, reduction='none')).mean()
                # Split each LPIPS map into face (masked) and background terms.
                lpips_face, lpips_bg = [], []
                for idx, lpips in enumerate(loss_lpips):
                    shape = lpips.shape[2]
                    m = torch.nn.functional.interpolate(mask, (shape, shape))
                    lpips_face.append((lpips * m).sum() / m.sum())
                    lpips_bg.append((lpips * (1 - m)).sum() / (1 - m).sum())
                # NOTE(review): this total is reassigned but never used
                # afterwards (the commented-out combined loss below was its
                # consumer); kept byte-identical.
                loss_lpips = torch.stack([l.mean() for l in loss_lpips]).sum()
                # loss = self.opts.dhr_lpips_lambda * loss_lpips + self.opts.dhr_l2_lambda * loss_mse
                # --- background: update the feature offset only ---
                loss_lpips_bg = torch.stack(lpips_bg).sum()
                loss_mse_bg = ((F.mse_loss(gen_images, images, reduction='none') * (1 - mask_ori)).sum() / (
                        1 - mask_ori).sum())
                loss_bg = self.opts.dhr_lpips_lambda * loss_lpips_bg + self.opts.dhr_l2_lambda * loss_mse_bg
                optimizer_f.zero_grad()
                # retain_graph: the same forward graph is reused by
                # loss_face.backward() below.
                loss_bg.backward(retain_graph=True)
                optimizer_f.step()
                # --- face: update decoder weights (fewer steps) ---
                if i < self.opts.dhr_weight_step:
                    loss_lpips_face = torch.stack(lpips_face).sum()
                    loss_mse_face = (
                            (F.mse_loss(gen_images, images, reduction='none') * mask_ori).sum() / mask_ori.sum())
                    loss_face = self.opts.dhr_lpips_lambda * loss_lpips_face + self.opts.dhr_l2_lambda * loss_mse_face
                    optimizer.zero_grad()
                    loss_face.backward()
                    optimizer.step()
        # Stash tuned state so edit() can rebuild the refined generator.
        refine_info['weight'] = decoder.state_dict()
        refine_info['feature'] = offset.clone()
        refine_info['mask'] = mask.clone()
        # Final refined reconstruction with tuned weights + offset.
        images, result_latent = decoder([emb_codes], feature_idx=feature_idx, offset=offset, mask=mask,
                                        input_is_latent=True, randomize_noise=False)
        return images, emb_codes, [refine_info]

    def edit(self, images, images_resize, image_paths, emb_codes, emb_images, emb_info, editor):
        """Invert, then apply a latent edit with the refined generator.

        Args:
            images, images_resize, image_paths, emb_codes, emb_images,
                emb_info: forwarded to :meth:`inverse`.
            editor: object exposing ``edit_code(codes)`` that transforms the
                inverted latent codes.

        Returns:
            ``(images, edit_images, codes, edit_codes, refine_info)`` — the
            refined reconstruction, the edited render, the latents before and
            after editing, and the refinement state dict.
        """
        images, codes, refine_info = self.inverse(images, images_resize, image_paths, emb_codes, emb_images,
                                                  emb_info)
        refine_info = refine_info[0]
        with torch.no_grad():
            # Rebuild a decoder carrying the per-image tuned weights.
            decoder = Generator(self.opts.resolution, 512, 8).to(self.device)
            decoder.train()
            decoder.load_state_dict(refine_info['weight'], strict=True)
            edit_codes = editor.edit_code(codes)
            edit_images, edit_codes = decoder([edit_codes], feature_idx=self.opts.dhr_feature_idx,
                                              offset=refine_info['feature'], mask=refine_info['mask'],
                                              input_is_latent=True, randomize_noise=False)
        return images, edit_images, codes, edit_codes, refine_info