diff --git a/GanAttack.py b/GanAttack.py
index e2f5944..7a1f925 100644
--- a/GanAttack.py
+++ b/GanAttack.py
@@ -1,16 +1,20 @@
 import sys
-sys.path.append('./GanInverter')
+import os
+sys.path.append('./pixel2style2pixel')
 import torch
 import torch.nn as nn
 import torch.optim as optim
 from torchvision import models
 from omegaconf import DictConfig, OmegaConf
 from data.dataset import get_dataset,get_adv_dataset
-from utils import get_model,set_requires_grad
-from GanInverter.inference.two_stage_inference import TwoStageInference
-from GanInverter.models.stylegan2.model import Generator
+from utils import get_model,set_requires_grad,unnormalize
+# from GanInverter.inference.two_stage_inference import TwoStageInference
+# from GanInverter.models.stylegan2.model import Generator
 # from models import GanAttack
-import model
+from argparse import Namespace
+# from pixel2style2pixel.scripts.align_all_parallel import align_face
+from pixel2style2pixel.models.psp import pSp
+# from pixel2style2pixel.models.stylegan2.model import Generator
 from model import GanAttack,CLIPLoss,VggLoss,get_prompt
 import torch.nn.functional as F
 import time
@@ -22,39 +26,50 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-def get_stylegan_generator(cfg):
+# def get_stylegan_generator(cfg):
 
-    # ensure_checkpoint_exists(ckpt_path)
-    ckpt_path=cfg.paths.stylegan
-    g_ema = Generator(1024, 512, 8)
-    g_ema.load_state_dict(torch.load(ckpt_path)["g_ema"], strict=False)
-    g_ema.eval()
-    g_ema = g_ema.cuda()
-    mean_latent = g_ema.mean_latent(4096)
+#     # ensure_checkpoint_exists(ckpt_path)
+#     ckpt_path=cfg.paths.stylegan
+#     g_ema = Generator(1024, 512, 8)
+#     g_ema.load_state_dict(torch.load(ckpt_path)["g_ema"], strict=False)
+#     g_ema.eval()
+#     g_ema = g_ema.cuda()
+#     mean_latent = g_ema.mean_latent(4096)
 
-    return g_ema, mean_latent
+#     return g_ema, mean_latent
 
 def get_stylegan_inverter(cfg):
     # ensure_checkpoint_exists(ckpt_path)
     path=cfg.paths.inverter_cfg
-    inverter=TwoStageInference(opts=path)
-
-    return inverter.inverse
+    ckpt = torch.load(path, map_location='cuda:0')
+    opts = ckpt['opts']
+    opts['checkpoint_path'] = path
+    if 'learn_in_w' not in opts:
+        opts['learn_in_w'] = False
+    if 'output_size' not in opts:
+        opts['output_size'] = 1024
+
+    net = pSp(Namespace(**opts))
+    net.eval()
+    net.cuda()
+    return net
 
 @hydra.main(version_base=None, config_path="./config", config_name="config")
 def main(cfg: DictConfig) -> None:
     model=get_model(cfg)
-    train_dataloader,test_dataloader,train_dataset,test_dataset=get_adv_dataset(cfg)
+    train_dataloader,test_dataloader,train_dataset,test_dataset=get_dataset(cfg)
     classifier=get_model(cfg)
+    classifier.load_state_dict(torch.load('{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset)))
     classifier.eval()
-    g_ema, _=get_stylegan_generator(cfg)
+    # g_ema, _=get_stylegan_generator(cfg)
+    prompt=get_prompt(cfg.prompt)
+
+    net=get_stylegan_inverter(cfg)
+    model=GanAttack(net,prompt).to(device)
 
-    inverter=get_stylegan_inverter(cfg)
-    model=GanAttack(g_ema,inverter,images_resize=cfg.optim.images_resize,prompt=cfg.prompt).to(device)
-
-    prompt=get_prompt(cfg)
+
     num_epochs = cfg.optim.num_epochs
     criterion = nn.CrossEntropyLoss()
@@ -72,14 +87,15 @@ def main(cfg: DictConfig) -> None:
         running_loss = 0
         running_corrects = 0
-        for i, (inputs,img_path, labels) in enumerate(train_dataloader):
+        for i, (inputs, labels) in enumerate(train_dataloader):
             inputs = inputs.to(device)
             labels = labels.to(device)
-            _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path)
+            codes = net.encoder(inputs)
+            # _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path)
             optimizer.zero_grad()
-            adv_refine_images,generated_img,adv_latent_codes=model(inputs,img_path)
+            generated_img,adv_latent_codes=model(inputs)
             loss_vgg=vgg_loss(inputs,generated_img)
-            loss_l1=F.l1_loss(clean_latent_codes,adv_latent_codes)
+            loss_l1=F.l1_loss(codes,adv_latent_codes)
             loss_clip=clip_loss(generated_img,prompt)
             _, preds = torch.max(classifier(generated_img), 1)
@@ -102,11 +118,11 @@ def main(cfg: DictConfig) -> None:
         running_loss = 0.
         #test_dataloader
         running_corrects = 0
-        for i, (inputs,img_path, labels) in enumerate(test_dataloader):
+        for i, (inputs, labels) in enumerate(test_dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)
-           adv_refine_images,generated_img,adv_latent_codes=model(inputs,img_path)
+           generated_img,adv_latent_codes=model(inputs)
            outputs = classifier(generated_img)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
diff --git a/config/config.yaml b/config/config.yaml
index f86c768..bf93078 100644
--- a/config/config.yaml
+++ b/config/config.yaml
@@ -11,7 +11,7 @@ classifier:
 paths:
   gender_dataset: ./dataset/CelebA_HQ_face_gender_dataset
   identity_dataset: ./dataset/CelebA_HQ_facial_identity_dataset
-  inverter_cfg: secret
+  inverter_cfg: checkpoint/psp_ffhq_encode.pt
   classifier: checkpoint
   stylegan: checkpoint/stylegan2-ffhq-config-f.pt
   adv_embedding: pretrained_models
diff --git a/model.py b/model.py
index b61e262..e1015dc 100644
--- a/model.py
+++ b/model.py
@@ -20,13 +20,13 @@ def get_prompt(cfg):
 
 # class GanAttack(nn.Module):
 class GanAttack(nn.Module):
-    def __init__(self, stylegan_generator, inverter,images_resize,prompt):
+    def __init__(self, net,prompt):
         super().__init__()
-        self.generator = stylegan_generator
-        self.generator.eval()
-        self.inverter=inverter
-        self.images_resize=images_resize
+        self.net= net
+        # self.generator.eval()
+        # self.inverter=inverter
+        # self.images_resize=images_resize
         self.prompt=prompt
         text_len=self.prompt.shape[0]
         self.mlp=nn.Sequential(
@@ -36,18 +36,22 @@ class GanAttack(nn.Module):
         )
 
-    def forward(self, img,img_path):
-        _, _, _, refine_images, latent_codes, _=self.inverter(img,self.images_resize,img_path)
-        x=latent_codes
+    def forward(self, img):
+        codes = self.net.encoder(img)
+        codes = codes + self.net.latent_avg.repeat(codes.shape[0], 1, 1)
+        # result_images, result_latent = self.net.decoder([codes], input_is_latent=True, randomize_noise=False, return_latents=False)
+        # result_images = self.net.face_pool(result_images)
+        # _, _, _, refine_images, latent_codes, _=self.inverter(img,self.images_resize,img_path)
+        x=codes
         batch_size=img.shape[0]
         prompt=self.prompt.repeat(batch_size).to(device)
-        x_prompt=torch.cat([latent_codes,prompt],dim=1)
+        x_prompt=torch.cat([codes,prompt],dim=1)
         x_prompt=self.mlp(x_prompt)
         x=x_prompt+x
-        im,_=self.generator(x,input_is_latent=True, randomize_noise=False)
+        im,_=self.net.decoder([x], input_is_latent=True, randomize_noise=False, return_latents=False)
+        result_images = self.net.face_pool(im)
-
-        return refine_images,im,x
+        return result_images,x
 
 class CLIPLoss(torch.nn.Module):
     def __init__(self):
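Usage note (not part of the diff): this change swaps the two-stage GanInverter for the off-the-shelf pSp encoder/decoder from pixel2style2pixel. Below is a minimal smoke-test sketch of the encode/decode round trip that `GanAttack.forward` now wraps. It assumes the upstream repo layout and the public `psp_ffhq_encode.pt` checkpoint, which stores its training options under `ckpt['opts']` and exposes `net.encoder`, `net.decoder`, `net.face_pool`, and `net.latent_avg`; the batch size and input resolution are illustrative only.

```python
import sys
sys.path.append('./pixel2style2pixel')  # upstream repo checked out beside this script

from argparse import Namespace

import torch
from pixel2style2pixel.models.psp import pSp

# Load the pretrained pSp FFHQ encoder, mirroring get_stylegan_inverter() above.
ckpt_path = 'checkpoint/psp_ffhq_encode.pt'
ckpt = torch.load(ckpt_path, map_location='cuda:0')
opts = ckpt['opts']
opts['checkpoint_path'] = ckpt_path
opts.setdefault('learn_in_w', False)   # older checkpoints may lack these keys
opts.setdefault('output_size', 1024)

net = pSp(Namespace(**opts)).eval().cuda()

with torch.no_grad():
    # pSp expects 256x256 inputs normalized to [-1, 1]; a random batch suffices here.
    x = torch.randn(2, 3, 256, 256, device='cuda')
    codes = net.encoder(x)                                       # (B, 18, 512) W+ offsets
    codes = codes + net.latent_avg.repeat(codes.shape[0], 1, 1)  # re-center on the mean latent
    imgs, _ = net.decoder([codes], input_is_latent=True,
                          randomize_noise=False, return_latents=False)
    imgs = net.face_pool(imgs)                                   # 1024x1024 -> 256x256
    print(imgs.shape)  # expected: torch.Size([2, 3, 256, 256])
```

`GanAttack.forward` inserts its prompt-conditioned MLP perturbation between the two halves of this round trip: the perturbed W+ code, not the clean one, is fed to the decoder, while the clean `codes` from `net.encoder` serve as the L1 anchor in the training loop.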