update all

This commit is contained in:
parent 59f2cad553
commit cf67f37cd3

74  GanAttack.py
@@ -1,16 +1,20 @@
 import sys
-sys.path.append('./GanInverter')
+import os
+sys.path.append('./pixel2style2pixel')
 import torch
 import torch.nn as nn
 import torch.optim as optim
 from torchvision import models
 from omegaconf import DictConfig, OmegaConf
 from data.dataset import get_dataset,get_adv_dataset
-from utils import get_model,set_requires_grad
+from utils import get_model,set_requires_grad,unnormalize
-from GanInverter.inference.two_stage_inference import TwoStageInference
+# from GanInverter.inference.two_stage_inference import TwoStageInference
-from GanInverter.models.stylegan2.model import Generator
+# from GanInverter.models.stylegan2.model import Generator
 # from models import GanAttack
-import model
+from argparse import Namespace
+# from pixel2style2pixel.scripts.align_all_parallel import align_face
+from pixel2style2pixel.models.psp import pSp
+# from pixel2style2pixel.models.stylegan2.model import Generator
 from model import GanAttack,CLIPLoss,VggLoss,get_prompt
 import torch.nn.functional as F
 import time
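Note: the new imports assume the pixel2style2pixel repository is checked out at ./pixel2style2pixel next to GanAttack.py. A quick sanity-check sketch under that assumption (the repo URL and error message are illustrative, not part of this commit):

    import os, sys
    sys.path.append('./pixel2style2pixel')
    if not os.path.isdir('./pixel2style2pixel/models'):
        raise FileNotFoundError('clone https://github.com/eladrich/pixel2style2pixel next to GanAttack.py')
    from pixel2style2pixel.models.psp import pSp  # fails if the repo layout differs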
@@ -22,39 +26,50 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-def get_stylegan_generator(cfg):
-    # ensure_checkpoint_exists(ckpt_path)
-    ckpt_path=cfg.paths.stylegan
-    g_ema = Generator(1024, 512, 8)
-    g_ema.load_state_dict(torch.load(ckpt_path)["g_ema"], strict=False)
-    g_ema.eval()
-    g_ema = g_ema.cuda()
-    mean_latent = g_ema.mean_latent(4096)
-    return g_ema, mean_latent
+# def get_stylegan_generator(cfg):
+#     # ensure_checkpoint_exists(ckpt_path)
+#     ckpt_path=cfg.paths.stylegan
+#     g_ema = Generator(1024, 512, 8)
+#     g_ema.load_state_dict(torch.load(ckpt_path)["g_ema"], strict=False)
+#     g_ema.eval()
+#     g_ema = g_ema.cuda()
+#     mean_latent = g_ema.mean_latent(4096)
+#     return g_ema, mean_latent
 
 def get_stylegan_inverter(cfg):
     # ensure_checkpoint_exists(ckpt_path)
     path=cfg.paths.inverter_cfg
-    inverter=TwoStageInference(opts=path)
-    return inverter.inverse
+    ckpt = torch.load(path, map_location='cuda:0')
+    opts = ckpt['opts']
+    opts['checkpoint_path'] = path
+    if 'learn_in_w' not in opts:
+        opts['learn_in_w'] = False
+    if 'output_size' not in opts:
+        opts['output_size'] = 1024
+    net = pSp(Namespace(**opts))
+    net.eval()
+    net.cuda()
+    return net
 
 @hydra.main(version_base=None, config_path="./config", config_name="config")
 def main(cfg: DictConfig) -> None:
     model=get_model(cfg)
-    train_dataloader,test_dataloader,train_dataset,test_dataset=get_adv_dataset(cfg)
+    train_dataloader,test_dataloader,train_dataset,test_dataset=get_dataset(cfg)
 
     classifier=get_model(cfg)
+    classifier.load_state_dict(torch.load('{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset)))
     classifier.eval()
-    g_ema, _=get_stylegan_generator(cfg)
+    # g_ema, _=get_stylegan_generator(cfg)
+    prompt=get_prompt(cfg.prompt)
+    net=get_stylegan_inverter(cfg)
+    model=GanAttack(net,prompt).to(device)
 
-    inverter=get_stylegan_inverter(cfg)
-    model=GanAttack(g_ema,inverter,images_resize=cfg.optim.images_resize,prompt=cfg.prompt).to(device)
 
-    prompt=get_prompt(cfg)
 
     num_epochs = cfg.optim.num_epochs
     criterion = nn.CrossEntropyLoss()
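Note: get_stylegan_inverter now returns the full pSp network (encoder, decoder, latent_avg, face_pool) instead of the old TwoStageInference callable. A minimal inversion round-trip sketch under that assumption, with attribute names as used by the pixel2style2pixel pSp model loaded above:

    net = get_stylegan_inverter(cfg)
    with torch.no_grad():
        x = inputs.to(device)                                        # normalized face crops
        codes = net.encoder(x)                                       # W+ codes from the pSp encoder
        codes = codes + net.latent_avg.repeat(codes.shape[0], 1, 1)  # pSp adds the average latent back
        images, _ = net.decoder([codes], input_is_latent=True,
                                randomize_noise=False, return_latents=False)
        images = net.face_pool(images)                               # pool the 1024x1024 output down to 256x256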
@@ -72,14 +87,15 @@ def main(cfg: DictConfig) -> None:
         running_loss = 0
         running_corrects = 0
 
-        for i, (inputs,img_path, labels) in enumerate(train_dataloader):
+        for i, (inputs, labels) in enumerate(train_dataloader):
             inputs = inputs.to(device)
             labels = labels.to(device)
-            _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path)
+            codes = net.encoder(inputs)
+            # _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path)
             optimizer.zero_grad()
-            adv_refine_images,generated_img,adv_latent_codes=model(inputs,img_path)
+            generated_img,adv_latent_codes=model(inputs)
             loss_vgg=vgg_loss(inputs,generated_img)
-            loss_l1=F.l1_loss(clean_latent_codes,adv_latent_codes)
+            loss_l1=F.l1_loss(codes,adv_latent_codes)
             loss_clip=clip_loss(generated_img,prompt)
 
             _, preds = torch.max(classifier(generated_img), 1)
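Note: the weighted combination of these terms sits below this hunk and is not part of the commit. A hedged sketch of how such an objective is typically assembled; the cfg.optim.lambda_* weight names and the sign of the classifier term are assumptions, not taken from this code:

    loss_cls = criterion(classifier(generated_img), labels)     # classification term on the generated image
    loss = (cfg.optim.lambda_vgg  * loss_vgg                     # hypothetical weights: perceptual similarity
            + cfg.optim.lambda_l1  * loss_l1                     # stay close to the clean W+ codes
            + cfg.optim.lambda_clip * loss_clip                  # match the text prompt via CLIP
            - cfg.optim.lambda_adv  * loss_cls)                  # minus if the attack maximizes classifier loss
    loss.backward()
    optimizer.step()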
@@ -102,11 +118,11 @@ def main(cfg: DictConfig) -> None:
         running_loss = 0. #test_dataloader
         running_corrects = 0
 
-        for i, (inputs,img_path, labels) in enumerate(test_dataloader):
+        for i, (inputs, labels) in enumerate(test_dataloader):
             inputs = inputs.to(device)
             labels = labels.to(device)
 
-            adv_refine_images,generated_img,adv_latent_codes=model(inputs,img_path)
+            generated_img,adv_latent_codes=model(inputs)
             outputs = classifier(generated_img)
             _, preds = torch.max(outputs, 1)
             loss = criterion(outputs, labels)
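Note: the per-batch aggregation of these values falls outside the hunk. A sketch of the usual pattern, assuming running_loss and running_corrects are summed over the whole test set:

    running_loss += loss.item() * inputs.size(0)
    running_corrects += torch.sum(preds == labels.data)
    # after the loop
    epoch_loss = running_loss / len(test_dataset)
    epoch_acc = running_corrects.double() / len(test_dataset)
    print('test loss: {:.4f}, acc: {:.4f}'.format(epoch_loss, epoch_acc))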
@@ -11,7 +11,7 @@ classifier:
 paths:
   gender_dataset: ./dataset/CelebA_HQ_face_gender_dataset
   identity_dataset: ./dataset/CelebA_HQ_facial_identity_dataset
-  inverter_cfg: secret
+  inverter_cfg: checkpoint/psp_ffhq_encode.pt
   classifier: checkpoint
   stylegan: checkpoint/stylegan2-ffhq-config-f.pt
   adv_embedding: pretrained_models
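Note: inverter_cfg now points directly at the pSp checkpoint that get_stylegan_inverter torch.load()s. A small check that the config resolves as expected; the ./config/config.yaml file name is an assumption (in the script Hydra supplies cfg via @hydra.main):

    import os
    from omegaconf import OmegaConf
    cfg = OmegaConf.load('./config/config.yaml')
    print(cfg.paths.inverter_cfg)   # -> checkpoint/psp_ffhq_encode.pt
    assert os.path.isfile(cfg.paths.inverter_cfg), 'download psp_ffhq_encode.pt into ./checkpoint first'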
28  model.py
@@ -20,13 +20,13 @@ def get_prompt(cfg):
 # class GanAttack(nn.Module):
 
 class GanAttack(nn.Module):
-    def __init__(self, stylegan_generator, inverter,images_resize,prompt):
+    def __init__(self, net,prompt):
         super().__init__()
 
-        self.generator = stylegan_generator
-        self.generator.eval()
-        self.inverter=inverter
-        self.images_resize=images_resize
+        self.net= net
+        # self.generator.eval()
+        # self.inverter=inverter
+        # self.images_resize=images_resize
         self.prompt=prompt
         text_len=self.prompt.shape[0]
         self.mlp=nn.Sequential(
@@ -36,18 +36,22 @@ class GanAttack(nn.Module):
         )
 
-    def forward(self, img,img_path):
-        _, _, _, refine_images, latent_codes, _=self.inverter(img,self.images_resize,img_path)
-        x=latent_codes
+    def forward(self, img):
+        codes = self.net.encoder(img)
+        codes = codes + self.net.latent_avg.repeat(codes.shape[0], 1, 1)
+        # result_images, result_latent = self.net.decoder([codes], input_is_latent=True, randomize_noise=False, return_latents=False)
+        # result_images = self.net.face_pool(result_images)
+        # _, _, _, refine_images, latent_codes, _=self.inverter(img,self.images_resize,img_path)
+        x=codes
         batch_size=img.shape[0]
         prompt=self.prompt.repeat(batch_size).to(device)
-        x_prompt=torch.cat([latent_codes,prompt],dim=1)
+        x_prompt=torch.cat([codes,prompt],dim=1)
         x_prompt=self.mlp(x_prompt)
         x=x_prompt+x
-        im,_=self.generator(x,input_is_latent=True, randomize_noise=False)
-        return refine_images,im,x
+        im,_=self.net.decoder([x], input_is_latent=True, randomize_noise=False, return_latents=False)
+        result_images = self.net.face_pool(im)
+
+        return result_images,x
 
 class CLIPLoss(torch.nn.Module):
     def __init__(self):
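Note: after this change GanAttack needs neither image paths nor a separate generator/inverter pair; the pSp network supplies encoder, decoder, latent_avg, and face_pool. A hedged usage sketch mirroring how main() in GanAttack.py wires things up (tensor shapes and the CLIP embedding produced by get_prompt are assumptions):

    net = get_stylegan_inverter(cfg)        # pSp network (encoder + StyleGAN2 decoder)
    prompt = get_prompt(cfg.prompt)         # text prompt embedding
    attack = GanAttack(net, prompt).to(device)

    generated_img, adv_codes = attack(inputs)   # inputs: normalized face batch on `device`
    # generated_img are face_pool'd decoder outputs; adv_codes the perturbed W+ codes
    loss_clip = clip_loss(generated_img, prompt)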