update all

This commit is contained in:
leewlving 2023-12-06 20:36:49 +08:00
parent 59f2cad553
commit cf67f37cd3
3 changed files with 62 additions and 42 deletions

View File

@ -1,16 +1,20 @@
import sys import sys
sys.path.append('./GanInverter') import os
sys.path.append('./pixel2style2pixel')
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from torchvision import models from torchvision import models
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from data.dataset import get_dataset,get_adv_dataset from data.dataset import get_dataset,get_adv_dataset
from utils import get_model,set_requires_grad from utils import get_model,set_requires_grad,unnormalize
from GanInverter.inference.two_stage_inference import TwoStageInference # from GanInverter.inference.two_stage_inference import TwoStageInference
from GanInverter.models.stylegan2.model import Generator # from GanInverter.models.stylegan2.model import Generator
# from models import GanAttack # from models import GanAttack
import model from argparse import Namespace
# from pixel2style2pixel.scripts.align_all_parallel import align_face
from pixel2style2pixel.models.psp import pSp
# from pixel2style2pixel.models.stylegan2.model import Generator
from model import GanAttack,CLIPLoss,VggLoss,get_prompt from model import GanAttack,CLIPLoss,VggLoss,get_prompt
import torch.nn.functional as F import torch.nn.functional as F
import time import time
@ -22,39 +26,50 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def get_stylegan_generator(cfg): # def get_stylegan_generator(cfg):
# ensure_checkpoint_exists(ckpt_path) # # ensure_checkpoint_exists(ckpt_path)
ckpt_path=cfg.paths.stylegan # ckpt_path=cfg.paths.stylegan
g_ema = Generator(1024, 512, 8) # g_ema = Generator(1024, 512, 8)
g_ema.load_state_dict(torch.load(ckpt_path)["g_ema"], strict=False) # g_ema.load_state_dict(torch.load(ckpt_path)["g_ema"], strict=False)
g_ema.eval() # g_ema.eval()
g_ema = g_ema.cuda() # g_ema = g_ema.cuda()
mean_latent = g_ema.mean_latent(4096) # mean_latent = g_ema.mean_latent(4096)
return g_ema, mean_latent # return g_ema, mean_latent
def get_stylegan_inverter(cfg): def get_stylegan_inverter(cfg):
# ensure_checkpoint_exists(ckpt_path) # ensure_checkpoint_exists(ckpt_path)
path=cfg.paths.inverter_cfg path=cfg.paths.inverter_cfg
inverter=TwoStageInference(opts=path) ckpt = torch.load(path, map_location='cuda:0')
opts = ckpt['opts']
opts['checkpoint_path'] = path
if 'learn_in_w' not in opts:
opts['learn_in_w'] = False
if 'output_size' not in opts:
opts['output_size'] = 1024
return inverter.inverse net = pSp(Namespace(**opts))
net.eval()
net.cuda()
return net
@hydra.main(version_base=None, config_path="./config", config_name="config") @hydra.main(version_base=None, config_path="./config", config_name="config")
def main(cfg: DictConfig) -> None: def main(cfg: DictConfig) -> None:
model=get_model(cfg) model=get_model(cfg)
train_dataloader,test_dataloader,train_dataset,test_dataset=get_adv_dataset(cfg) train_dataloader,test_dataloader,train_dataset,test_dataset=get_dataset(cfg)
classifier=get_model(cfg) classifier=get_model(cfg)
classifier.load_state_dict(torch.load('{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset)))
classifier.eval() classifier.eval()
g_ema, _=get_stylegan_generator(cfg) # g_ema, _=get_stylegan_generator(cfg)
prompt=get_prompt(cfg.prompt)
net=get_stylegan_inverter(cfg)
model=GanAttack(net,prompt).to(device)
inverter=get_stylegan_inverter(cfg)
model=GanAttack(g_ema,inverter,images_resize=cfg.optim.images_resize,prompt=cfg.prompt).to(device)
prompt=get_prompt(cfg)
num_epochs = cfg.optim.num_epochs num_epochs = cfg.optim.num_epochs
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
@ -72,14 +87,15 @@ def main(cfg: DictConfig) -> None:
running_loss = 0 running_loss = 0
running_corrects = 0 running_corrects = 0
for i, (inputs,img_path, labels) in enumerate(train_dataloader): for i, (inputs, labels) in enumerate(train_dataloader):
inputs = inputs.to(device) inputs = inputs.to(device)
labels = labels.to(device) labels = labels.to(device)
_, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path) codes = net.encoder(inputs)
# _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path)
optimizer.zero_grad() optimizer.zero_grad()
adv_refine_images,generated_img,adv_latent_codes=model(inputs,img_path) generated_img,adv_latent_codes=model(inputs)
loss_vgg=vgg_loss(inputs,generated_img) loss_vgg=vgg_loss(inputs,generated_img)
loss_l1=F.l1_loss(clean_latent_codes,adv_latent_codes) loss_l1=F.l1_loss(codes,adv_latent_codes)
loss_clip=clip_loss(generated_img,prompt) loss_clip=clip_loss(generated_img,prompt)
_, preds = torch.max(classifier(generated_img), 1) _, preds = torch.max(classifier(generated_img), 1)
@ -102,11 +118,11 @@ def main(cfg: DictConfig) -> None:
running_loss = 0. #test_dataloader running_loss = 0. #test_dataloader
running_corrects = 0 running_corrects = 0
for i, (inputs,img_path, labels) in enumerate(test_dataloader): for i, (inputs, labels) in enumerate(test_dataloader):
inputs = inputs.to(device) inputs = inputs.to(device)
labels = labels.to(device) labels = labels.to(device)
adv_refine_images,generated_img,adv_latent_codes=model(inputs,img_path) generated_img,adv_latent_codes=model(inputs)
outputs = classifier(generated_img) outputs = classifier(generated_img)
_, preds = torch.max(outputs, 1) _, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels) loss = criterion(outputs, labels)

View File

@ -11,7 +11,7 @@ classifier:
paths: paths:
gender_dataset: ./dataset/CelebA_HQ_face_gender_dataset gender_dataset: ./dataset/CelebA_HQ_face_gender_dataset
identity_dataset: ./dataset/CelebA_HQ_facial_identity_dataset identity_dataset: ./dataset/CelebA_HQ_facial_identity_dataset
inverter_cfg: secret inverter_cfg: checkpoint/psp_ffhq_encode.pt
classifier: checkpoint classifier: checkpoint
stylegan: checkpoint/stylegan2-ffhq-config-f.pt stylegan: checkpoint/stylegan2-ffhq-config-f.pt
adv_embedding: pretrained_models adv_embedding: pretrained_models

View File

@ -20,13 +20,13 @@ def get_prompt(cfg):
# class GanAttack(nn.Module): # class GanAttack(nn.Module):
class GanAttack(nn.Module): class GanAttack(nn.Module):
def __init__(self, stylegan_generator, inverter,images_resize,prompt): def __init__(self, net,prompt):
super().__init__() super().__init__()
self.generator = stylegan_generator self.net= net
self.generator.eval() # self.generator.eval()
self.inverter=inverter # self.inverter=inverter
self.images_resize=images_resize # self.images_resize=images_resize
self.prompt=prompt self.prompt=prompt
text_len=self.prompt.shape[0] text_len=self.prompt.shape[0]
self.mlp=nn.Sequential( self.mlp=nn.Sequential(
@ -36,18 +36,22 @@ class GanAttack(nn.Module):
) )
def forward(self, img,img_path): def forward(self, img):
_, _, _, refine_images, latent_codes, _=self.inverter(img,self.images_resize,img_path) codes = self.net.encoder(img)
x=latent_codes codes = codes + self.net.latent_avg.repeat(codes.shape[0], 1, 1)
# result_images, result_latent = self.net.decoder([codes], input_is_latent=True, randomize_noise=False, return_latents=False)
# result_images = self.net.face_pool(result_images)
# _, _, _, refine_images, latent_codes, _=self.inverter(img,self.images_resize,img_path)
x=codes
batch_size=img.shape[0] batch_size=img.shape[0]
prompt=self.prompt.repeat(batch_size).to(device) prompt=self.prompt.repeat(batch_size).to(device)
x_prompt=torch.cat([latent_codes,prompt],dim=1) x_prompt=torch.cat([codes,prompt],dim=1)
x_prompt=self.mlp(x_prompt) x_prompt=self.mlp(x_prompt)
x=x_prompt+x x=x_prompt+x
im,_=self.generator(x,input_is_latent=True, randomize_noise=False) im,_=self.net.decoder([x], input_is_latent=True, randomize_noise=False, return_latents=False)
result_images = self.net.face_pool(im)
return result_images,x
return refine_images,im,x
class CLIPLoss(torch.nn.Module): class CLIPLoss(torch.nn.Module):
def __init__(self): def __init__(self):