update all

This commit is contained in:
leewlving 2023-12-06 20:36:49 +08:00
parent 59f2cad553
commit cf67f37cd3
3 changed files with 62 additions and 42 deletions

View File

@ -1,16 +1,20 @@
import sys
sys.path.append('./GanInverter')
import os
sys.path.append('./pixel2style2pixel')
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from omegaconf import DictConfig, OmegaConf
from data.dataset import get_dataset,get_adv_dataset
from utils import get_model,set_requires_grad
from GanInverter.inference.two_stage_inference import TwoStageInference
from GanInverter.models.stylegan2.model import Generator
from utils import get_model,set_requires_grad,unnormalize
# from GanInverter.inference.two_stage_inference import TwoStageInference
# from GanInverter.models.stylegan2.model import Generator
# from models import GanAttack
import model
from argparse import Namespace
# from pixel2style2pixel.scripts.align_all_parallel import align_face
from pixel2style2pixel.models.psp import pSp
# from pixel2style2pixel.models.stylegan2.model import Generator
from model import GanAttack,CLIPLoss,VggLoss,get_prompt
import torch.nn.functional as F
import time
@ -22,39 +26,50 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def get_stylegan_generator(cfg):
# def get_stylegan_generator(cfg):
# ensure_checkpoint_exists(ckpt_path)
ckpt_path=cfg.paths.stylegan
g_ema = Generator(1024, 512, 8)
g_ema.load_state_dict(torch.load(ckpt_path)["g_ema"], strict=False)
g_ema.eval()
g_ema = g_ema.cuda()
mean_latent = g_ema.mean_latent(4096)
# # ensure_checkpoint_exists(ckpt_path)
# ckpt_path=cfg.paths.stylegan
# g_ema = Generator(1024, 512, 8)
# g_ema.load_state_dict(torch.load(ckpt_path)["g_ema"], strict=False)
# g_ema.eval()
# g_ema = g_ema.cuda()
# mean_latent = g_ema.mean_latent(4096)
return g_ema, mean_latent
# return g_ema, mean_latent
def get_stylegan_inverter(cfg):
# ensure_checkpoint_exists(ckpt_path)
path=cfg.paths.inverter_cfg
inverter=TwoStageInference(opts=path)
ckpt = torch.load(path, map_location='cuda:0')
opts = ckpt['opts']
opts['checkpoint_path'] = path
if 'learn_in_w' not in opts:
opts['learn_in_w'] = False
if 'output_size' not in opts:
opts['output_size'] = 1024
return inverter.inverse
net = pSp(Namespace(**opts))
net.eval()
net.cuda()
return net
@hydra.main(version_base=None, config_path="./config", config_name="config")
def main(cfg: DictConfig) -> None:
model=get_model(cfg)
train_dataloader,test_dataloader,train_dataset,test_dataset=get_adv_dataset(cfg)
train_dataloader,test_dataloader,train_dataset,test_dataset=get_dataset(cfg)
classifier=get_model(cfg)
classifier.load_state_dict(torch.load('{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset)))
classifier.eval()
g_ema, _=get_stylegan_generator(cfg)
# g_ema, _=get_stylegan_generator(cfg)
prompt=get_prompt(cfg.prompt)
net=get_stylegan_inverter(cfg)
model=GanAttack(net,prompt).to(device)
inverter=get_stylegan_inverter(cfg)
model=GanAttack(g_ema,inverter,images_resize=cfg.optim.images_resize,prompt=cfg.prompt).to(device)
prompt=get_prompt(cfg)
num_epochs = cfg.optim.num_epochs
criterion = nn.CrossEntropyLoss()
@ -72,14 +87,15 @@ def main(cfg: DictConfig) -> None:
running_loss = 0
running_corrects = 0
for i, (inputs,img_path, labels) in enumerate(train_dataloader):
for i, (inputs, labels) in enumerate(train_dataloader):
inputs = inputs.to(device)
labels = labels.to(device)
_, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path)
codes = net.encoder(inputs)
# _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path)
optimizer.zero_grad()
adv_refine_images,generated_img,adv_latent_codes=model(inputs,img_path)
generated_img,adv_latent_codes=model(inputs)
loss_vgg=vgg_loss(inputs,generated_img)
loss_l1=F.l1_loss(clean_latent_codes,adv_latent_codes)
loss_l1=F.l1_loss(codes,adv_latent_codes)
loss_clip=clip_loss(generated_img,prompt)
_, preds = torch.max(classifier(generated_img), 1)
@ -102,11 +118,11 @@ def main(cfg: DictConfig) -> None:
running_loss = 0. #test_dataloader
running_corrects = 0
for i, (inputs,img_path, labels) in enumerate(test_dataloader):
for i, (inputs, labels) in enumerate(test_dataloader):
inputs = inputs.to(device)
labels = labels.to(device)
adv_refine_images,generated_img,adv_latent_codes=model(inputs,img_path)
generated_img,adv_latent_codes=model(inputs)
outputs = classifier(generated_img)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)

View File

@ -11,7 +11,7 @@ classifier:
paths:
gender_dataset: ./dataset/CelebA_HQ_face_gender_dataset
identity_dataset: ./dataset/CelebA_HQ_facial_identity_dataset
inverter_cfg: secret
inverter_cfg: checkpoint/psp_ffhq_encode.pt
classifier: checkpoint
stylegan: checkpoint/stylegan2-ffhq-config-f.pt
adv_embedding: pretrained_models

View File

@ -20,13 +20,13 @@ def get_prompt(cfg):
# class GanAttack(nn.Module):
class GanAttack(nn.Module):
def __init__(self, stylegan_generator, inverter,images_resize,prompt):
def __init__(self, net,prompt):
super().__init__()
self.generator = stylegan_generator
self.generator.eval()
self.inverter=inverter
self.images_resize=images_resize
self.net= net
# self.generator.eval()
# self.inverter=inverter
# self.images_resize=images_resize
self.prompt=prompt
text_len=self.prompt.shape[0]
self.mlp=nn.Sequential(
@ -36,18 +36,22 @@ class GanAttack(nn.Module):
)
def forward(self, img,img_path):
_, _, _, refine_images, latent_codes, _=self.inverter(img,self.images_resize,img_path)
x=latent_codes
def forward(self, img):
codes = self.net.encoder(img)
codes = codes + self.net.latent_avg.repeat(codes.shape[0], 1, 1)
# result_images, result_latent = self.net.decoder([codes], input_is_latent=True, randomize_noise=False, return_latents=False)
# result_images = self.net.face_pool(result_images)
# _, _, _, refine_images, latent_codes, _=self.inverter(img,self.images_resize,img_path)
x=codes
batch_size=img.shape[0]
prompt=self.prompt.repeat(batch_size).to(device)
x_prompt=torch.cat([latent_codes,prompt],dim=1)
x_prompt=torch.cat([codes,prompt],dim=1)
x_prompt=self.mlp(x_prompt)
x=x_prompt+x
im,_=self.generator(x,input_is_latent=True, randomize_noise=False)
im,_=self.net.decoder([x], input_is_latent=True, randomize_noise=False, return_latents=False)
result_images = self.net.face_pool(im)
return refine_images,im,x
return result_images,x
class CLIPLoss(torch.nn.Module):
def __init__(self):