GanAttack/GanInverter/inference/pti_infer.py

137 lines
6.0 KiB
Python

from tqdm import tqdm
from criteria.lpips.lpips import LPIPS
import math
from models.stylegan2.model import Generator
import torch
import torch.optim as optim
import torch.nn.functional as F
from utils.train_utils import load_train_checkpoint
from inference.inference import BaseInference
class Space_Regulizer:
    """Locality regularizer for PTI generator fine-tuning.

    Penalizes the difference between images produced by the generator being
    tuned (``new_G``) and by the reference generator (``original_G``) on
    latent codes placed at a fixed distance from the pivot code.
    """

    def __init__(self, opts, original_G, lpips_net):
        """
        Args:
            opts: options namespace; this class reads ``pti_regulizer_alpha``,
                ``pti_regulizer_l2_lambda``, ``pti_regulizer_lpips_lambda`` and
                ``pti_latent_ball_num_of_samples``.
            original_G: reference generator; must provide ``mean_latent(n)``,
                ``w_sample(n)`` and be callable as
                ``G([codes], input_is_latent=True, noise=None, randomize_noise=False)``.
            lpips_net: callable ``lpips_net(img_a, img_b)`` returning a tensor.
        """
        self.opts = opts
        self.device = 'cuda'
        self.original_G = original_G
        self.morphing_regulizer_alpha = opts.pti_regulizer_alpha
        self.lpips_loss = lpips_net
        # The mean latent is a constant; computing it without autograd avoids
        # building a graph over 100k mapping passes that .detach() then drops.
        with torch.no_grad():
            self.w_mean = original_G.mean_latent(100000).detach()

    def get_morphed_w_code(self, new_w_code, fixed_w, eps=1e-8):
        """Return the point at L2 distance ``alpha`` from ``fixed_w`` towards ``new_w_code``.

        ``alpha`` is ``self.morphing_regulizer_alpha``; the norm is taken over
        the whole difference tensor. If both codes coincide the direction is
        undefined and ``fixed_w`` (broadcast against ``new_w_code``) is returned
        instead of NaNs.
        """
        interpolation_direction = new_w_code - fixed_w
        interpolation_direction_norm = torch.norm(interpolation_direction, p=2)
        # clamp keeps the division finite when new_w_code == fixed_w
        direction_to_move = (self.morphing_regulizer_alpha * interpolation_direction
                             / interpolation_direction_norm.clamp_min(eps))
        return fixed_w + direction_to_move

    def get_image_from_ws(self, w_codes, G):
        """Render each code in ``w_codes`` with ``G`` and concatenate along dim 0."""
        # Styles are passed as a one-element list, matching how the decoder is
        # called elsewhere in this module.
        return torch.cat([G([w_code], input_is_latent=True, noise=None, randomize_noise=False)[0]
                          for w_code in w_codes])

    def ball_holder_loss_lazy(self, new_G, num_of_sampled_latents, w_batch):
        """Average L2/LPIPS gap between ``new_G`` and ``original_G`` around ``w_batch``.

        Args:
            new_G: generator being fine-tuned (gradients flow through it).
            num_of_sampled_latents: number of latents requested from ``w_sample``.
            w_batch: pivot latent code the samples are morphed towards.

        Returns:
            The weighted loss averaged over the samples, or ``0.0`` when no
            latent was sampled.
        """
        with torch.no_grad():
            # Sampled codes are fixed targets; no gradient into original_G.
            w_samples = self.original_G.w_sample(num_of_sampled_latents)
            # Truncation towards the mean latent with psi = 0.5.
            w_samples = 0.5 * w_samples + 0.5 * self.w_mean
        territory_indicator_ws = [self.get_morphed_w_code(w_code.unsqueeze(0), w_batch)
                                  for w_code in w_samples]
        if not territory_indicator_ws:
            return 0.0

        loss = 0.0
        for w_code in territory_indicator_ws:
            new_img, _ = new_G([w_code], input_is_latent=True, noise=None, randomize_noise=False)
            with torch.no_grad():
                old_img, _ = self.original_G([w_code], input_is_latent=True, noise=None,
                                             randomize_noise=False)
            if self.opts.pti_regulizer_l2_lambda > 0:
                l2_loss_val = torch.nn.functional.mse_loss(old_img, new_img, reduction='mean')
                loss += l2_loss_val * self.opts.pti_regulizer_l2_lambda
            if self.opts.pti_regulizer_lpips_lambda > 0:
                loss_lpips = self.lpips_loss(old_img, new_img)
                loss_lpips = torch.mean(torch.squeeze(loss_lpips))
                loss += loss_lpips * self.opts.pti_regulizer_lpips_lambda
        return loss / len(territory_indicator_ws)

    def space_regulizer_loss(self, new_G, w_batch):
        """Regularization loss using ``opts.pti_latent_ball_num_of_samples`` samples."""
        return self.ball_holder_loss_lazy(new_G, self.opts.pti_latent_ball_num_of_samples, w_batch)
class PTIInference(BaseInference):
    """Per-image decoder fine-tuning (PTI) on top of precomputed latent codes.

    For each input image a fresh StyleGAN2 decoder is optimized so that the
    given ``emb_codes`` reconstruct the image under an LPIPS + MSE objective.
    The tuned weights are returned so ``edit`` can render edited codes with them.
    """

    def __init__(self, opts):
        super().__init__()
        self.opts = opts
        self.device = 'cuda'
        self.opts.device = self.device
        self.opts.n_styles = int(math.log(opts.resolution, 2)) * 2 - 2
        # initial loss
        self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval()
        self.checkpoint = load_train_checkpoint(self.opts)
        # Resolve the starting decoder weights once, so inverse() does not
        # re-read the StyleGAN weight file from disk for every image.
        if self.checkpoint is not None:
            self._decoder_init_state = self.checkpoint['decoder']
        else:
            decoder_checkpoint = torch.load(opts.stylegan_weights, map_location='cpu')
            self._decoder_init_state = decoder_checkpoint['g_ema']
        if opts.pti_use_regularization:
            # The untouched reference decoder is only needed by the regularizer.
            origin_decoder = self._build_decoder()
            self.space_regulizer = Space_Regulizer(opts, origin_decoder, self.lpips_loss)

    def _build_decoder(self):
        """Return a new Generator on ``self.device`` loaded with the starting weights."""
        decoder = Generator(self.opts.resolution, 512, 8).to(self.device)
        decoder.load_state_dict(self._decoder_init_state, strict=True)
        return decoder

    def inverse(self, images, images_resize, image_paths, emb_codes, emb_images, emb_info):
        """Fine-tune a fresh decoder so that ``emb_codes`` reconstructs ``images``.

        Args:
            images: target image batch; only batch size 1 is supported.
            images_resize, image_paths, emb_images, emb_info: not used here
                (presumably kept for the shared ``BaseInference`` interface).
            emb_codes: latent codes fed to the decoder with ``input_is_latent=True``.

        Returns:
            ``(images, emb_codes, pti_info)``: images rendered by the tuned
            decoder, the unchanged input codes, and ``[{'generator': state_dict}]``.
        """
        assert images.shape[0] == 1, 'PTI is only supported for batchsize 1.'
        decoder = self._build_decoder()
        decoder.train()
        # Only decoder weights are optimized; detaching keeps the repeated
        # backward() calls from reaching whatever graph produced emb_codes.
        pivot_codes = emb_codes.detach()
        optimizer = optim.Adam(decoder.parameters(), lr=self.opts.pti_lr)
        pbar = tqdm(range(self.opts.pti_step))
        for i in pbar:
            gen_images, _ = decoder([pivot_codes], input_is_latent=True, randomize_noise=False)
            # calculate loss
            loss_lpips = self.lpips_loss(gen_images, images)
            loss_mse = F.mse_loss(gen_images, images)
            loss = self.opts.pti_lpips_lambda * loss_lpips + self.opts.pti_l2_lambda * loss_mse
            # TODO: enabling the locality regularization may cause errors
            # if self.opts.pti_use_regularization and i % self.opts.pti_locality_regularization_interval == 0:
            #     ball_holder_loss_val = self.space_regulizer.space_regulizer_loss(decoder, pivot_codes)
            #     loss += self.opts.pti_regulizer_lambda * ball_holder_loss_val
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pbar.set_description(f"loss: {loss.item():.4f}; lr: {self.opts.pti_lr:.4f};")
        with torch.no_grad():
            images, _ = decoder([pivot_codes], input_is_latent=True, randomize_noise=False)
        pti_info = [{'generator': decoder.state_dict()}]
        return images, emb_codes, pti_info

    def edit(self, images, images_resize, image_paths, emb_codes, emb_images, emb_info, editor):
        """Run ``inverse`` and render ``editor.edit_code(codes)`` with the tuned decoder.

        Returns:
            ``(images, edit_images, codes, edit_codes, refine_info)`` where
            ``edit_codes`` are the codes produced by the editor (mirroring how
            ``inverse`` returns its input codes) and ``refine_info`` is the
            ``{'generator': state_dict}`` dict from ``inverse``.
        """
        images, codes, refine_info = self.inverse(images, images_resize, image_paths, emb_codes,
                                                  emb_images, emb_info)
        refine_info = refine_info[0]
        with torch.no_grad():
            decoder = Generator(self.opts.resolution, 512, 8).to(self.device)
            decoder.train()
            decoder.load_state_dict(refine_info['generator'], strict=True)
            edit_codes = editor.edit_code(codes)
            # Discard the generator's second output so the editor's codes are
            # returned (it was previously assigned over edit_codes).
            edit_images, _ = decoder([edit_codes], input_is_latent=True, randomize_noise=False)
        return images, edit_images, codes, edit_codes, refine_info