update all
This commit is contained in:
parent
59f2cad553
commit
cf67f37cd3
74
GanAttack.py
74
GanAttack.py
|
|
@ -1,16 +1,20 @@
|
|||
import sys
|
||||
sys.path.append('./GanInverter')
|
||||
import os
|
||||
sys.path.append('./pixel2style2pixel')
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torchvision import models
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from data.dataset import get_dataset,get_adv_dataset
|
||||
from utils import get_model,set_requires_grad
|
||||
from GanInverter.inference.two_stage_inference import TwoStageInference
|
||||
from GanInverter.models.stylegan2.model import Generator
|
||||
from utils import get_model,set_requires_grad,unnormalize
|
||||
# from GanInverter.inference.two_stage_inference import TwoStageInference
|
||||
# from GanInverter.models.stylegan2.model import Generator
|
||||
# from models import GanAttack
|
||||
import model
|
||||
from argparse import Namespace
|
||||
# from pixel2style2pixel.scripts.align_all_parallel import align_face
|
||||
from pixel2style2pixel.models.psp import pSp
|
||||
# from pixel2style2pixel.models.stylegan2.model import Generator
|
||||
from model import GanAttack,CLIPLoss,VggLoss,get_prompt
|
||||
import torch.nn.functional as F
|
||||
import time
|
||||
|
|
@ -22,39 +26,50 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|||
|
||||
|
||||
|
||||
def get_stylegan_generator(cfg):
|
||||
# def get_stylegan_generator(cfg):
|
||||
|
||||
# ensure_checkpoint_exists(ckpt_path)
|
||||
ckpt_path=cfg.paths.stylegan
|
||||
g_ema = Generator(1024, 512, 8)
|
||||
g_ema.load_state_dict(torch.load(ckpt_path)["g_ema"], strict=False)
|
||||
g_ema.eval()
|
||||
g_ema = g_ema.cuda()
|
||||
mean_latent = g_ema.mean_latent(4096)
|
||||
# # ensure_checkpoint_exists(ckpt_path)
|
||||
# ckpt_path=cfg.paths.stylegan
|
||||
# g_ema = Generator(1024, 512, 8)
|
||||
# g_ema.load_state_dict(torch.load(ckpt_path)["g_ema"], strict=False)
|
||||
# g_ema.eval()
|
||||
# g_ema = g_ema.cuda()
|
||||
# mean_latent = g_ema.mean_latent(4096)
|
||||
|
||||
return g_ema, mean_latent
|
||||
# return g_ema, mean_latent
|
||||
|
||||
def get_stylegan_inverter(cfg):
|
||||
|
||||
# ensure_checkpoint_exists(ckpt_path)
|
||||
path=cfg.paths.inverter_cfg
|
||||
inverter=TwoStageInference(opts=path)
|
||||
|
||||
return inverter.inverse
|
||||
ckpt = torch.load(path, map_location='cuda:0')
|
||||
opts = ckpt['opts']
|
||||
opts['checkpoint_path'] = path
|
||||
if 'learn_in_w' not in opts:
|
||||
opts['learn_in_w'] = False
|
||||
if 'output_size' not in opts:
|
||||
opts['output_size'] = 1024
|
||||
|
||||
net = pSp(Namespace(**opts))
|
||||
net.eval()
|
||||
net.cuda()
|
||||
return net
|
||||
|
||||
@hydra.main(version_base=None, config_path="./config", config_name="config")
|
||||
def main(cfg: DictConfig) -> None:
|
||||
model=get_model(cfg)
|
||||
train_dataloader,test_dataloader,train_dataset,test_dataset=get_adv_dataset(cfg)
|
||||
train_dataloader,test_dataloader,train_dataset,test_dataset=get_dataset(cfg)
|
||||
|
||||
classifier=get_model(cfg)
|
||||
classifier.load_state_dict(torch.load('{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset)))
|
||||
classifier.eval()
|
||||
g_ema, _=get_stylegan_generator(cfg)
|
||||
# g_ema, _=get_stylegan_generator(cfg)
|
||||
prompt=get_prompt(cfg.prompt)
|
||||
|
||||
net=get_stylegan_inverter(cfg)
|
||||
model=GanAttack(net,prompt).to(device)
|
||||
|
||||
inverter=get_stylegan_inverter(cfg)
|
||||
model=GanAttack(g_ema,inverter,images_resize=cfg.optim.images_resize,prompt=cfg.prompt).to(device)
|
||||
|
||||
prompt=get_prompt(cfg)
|
||||
|
||||
|
||||
num_epochs = cfg.optim.num_epochs
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
|
@ -72,14 +87,15 @@ def main(cfg: DictConfig) -> None:
|
|||
running_loss = 0
|
||||
running_corrects = 0
|
||||
|
||||
for i, (inputs,img_path, labels) in enumerate(train_dataloader):
|
||||
for i, (inputs, labels) in enumerate(train_dataloader):
|
||||
inputs = inputs.to(device)
|
||||
labels = labels.to(device)
|
||||
_, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path)
|
||||
codes = net.encoder(inputs)
|
||||
# _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path)
|
||||
optimizer.zero_grad()
|
||||
adv_refine_images,generated_img,adv_latent_codes=model(inputs,img_path)
|
||||
generated_img,adv_latent_codes=model(inputs)
|
||||
loss_vgg=vgg_loss(inputs,generated_img)
|
||||
loss_l1=F.l1_loss(clean_latent_codes,adv_latent_codes)
|
||||
loss_l1=F.l1_loss(codes,adv_latent_codes)
|
||||
loss_clip=clip_loss(generated_img,prompt)
|
||||
|
||||
_, preds = torch.max(classifier(generated_img), 1)
|
||||
|
|
@ -102,11 +118,11 @@ def main(cfg: DictConfig) -> None:
|
|||
running_loss = 0. #test_dataloader
|
||||
running_corrects = 0
|
||||
|
||||
for i, (inputs,img_path, labels) in enumerate(test_dataloader):
|
||||
for i, (inputs, labels) in enumerate(test_dataloader):
|
||||
inputs = inputs.to(device)
|
||||
labels = labels.to(device)
|
||||
|
||||
adv_refine_images,generated_img,adv_latent_codes=model(inputs,img_path)
|
||||
generated_img,adv_latent_codes=model(inputs)
|
||||
outputs = classifier(generated_img)
|
||||
_, preds = torch.max(outputs, 1)
|
||||
loss = criterion(outputs, labels)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ classifier:
|
|||
paths:
|
||||
gender_dataset: ./dataset/CelebA_HQ_face_gender_dataset
|
||||
identity_dataset: ./dataset/CelebA_HQ_facial_identity_dataset
|
||||
inverter_cfg: secret
|
||||
inverter_cfg: checkpoint/psp_ffhq_encode.pt
|
||||
classifier: checkpoint
|
||||
stylegan: checkpoint/stylegan2-ffhq-config-f.pt
|
||||
adv_embedding: pretrained_models
|
||||
|
|
|
|||
28
model.py
28
model.py
|
|
@ -20,13 +20,13 @@ def get_prompt(cfg):
|
|||
# class GanAttack(nn.Module):
|
||||
|
||||
class GanAttack(nn.Module):
|
||||
def __init__(self, stylegan_generator, inverter,images_resize,prompt):
|
||||
def __init__(self, net,prompt):
|
||||
super().__init__()
|
||||
|
||||
self.generator = stylegan_generator
|
||||
self.generator.eval()
|
||||
self.inverter=inverter
|
||||
self.images_resize=images_resize
|
||||
self.net= net
|
||||
# self.generator.eval()
|
||||
# self.inverter=inverter
|
||||
# self.images_resize=images_resize
|
||||
self.prompt=prompt
|
||||
text_len=self.prompt.shape[0]
|
||||
self.mlp=nn.Sequential(
|
||||
|
|
@ -36,18 +36,22 @@ class GanAttack(nn.Module):
|
|||
)
|
||||
|
||||
|
||||
def forward(self, img,img_path):
|
||||
_, _, _, refine_images, latent_codes, _=self.inverter(img,self.images_resize,img_path)
|
||||
x=latent_codes
|
||||
def forward(self, img):
|
||||
codes = self.net.encoder(img)
|
||||
codes = codes + self.net.latent_avg.repeat(codes.shape[0], 1, 1)
|
||||
# result_images, result_latent = self.net.decoder([codes], input_is_latent=True, randomize_noise=False, return_latents=False)
|
||||
# result_images = self.net.face_pool(result_images)
|
||||
# _, _, _, refine_images, latent_codes, _=self.inverter(img,self.images_resize,img_path)
|
||||
x=codes
|
||||
batch_size=img.shape[0]
|
||||
prompt=self.prompt.repeat(batch_size).to(device)
|
||||
x_prompt=torch.cat([latent_codes,prompt],dim=1)
|
||||
x_prompt=torch.cat([codes,prompt],dim=1)
|
||||
x_prompt=self.mlp(x_prompt)
|
||||
x=x_prompt+x
|
||||
im,_=self.generator(x,input_is_latent=True, randomize_noise=False)
|
||||
im,_=self.net.decoder([x], input_is_latent=True, randomize_noise=False, return_latents=False)
|
||||
result_images = self.net.face_pool(im)
|
||||
|
||||
|
||||
return refine_images,im,x
|
||||
return result_images,x
|
||||
|
||||
class CLIPLoss(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue