new updated from remote
This commit is contained in:
parent
a6080d42a7
commit
63def0a779
17
GanAttack.py
17
GanAttack.py
|
|
@ -18,6 +18,7 @@ from pixel2style2pixel.models.psp import pSp
|
||||||
from model import GanAttack,CLIPLoss,VggLoss,get_prompt
|
from model import GanAttack,CLIPLoss,VggLoss,get_prompt
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import time
|
import time
|
||||||
|
# from torchsummary import summary
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
|
|
||||||
|
|
@ -64,7 +65,7 @@ def main(cfg: DictConfig) -> None:
|
||||||
classifier.load_state_dict(torch.load('{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset)))
|
classifier.load_state_dict(torch.load('{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset)))
|
||||||
classifier.eval()
|
classifier.eval()
|
||||||
# g_ema, _=get_stylegan_generator(cfg)
|
# g_ema, _=get_stylegan_generator(cfg)
|
||||||
prompt=get_prompt(cfg.prompt)
|
prompt=get_prompt(cfg)
|
||||||
|
|
||||||
net=get_stylegan_inverter(cfg)
|
net=get_stylegan_inverter(cfg)
|
||||||
model=GanAttack(net,prompt).to(device)
|
model=GanAttack(net,prompt).to(device)
|
||||||
|
|
@ -75,9 +76,11 @@ def main(cfg: DictConfig) -> None:
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
max_loss=nn.MarginRankingLoss(0.1)
|
max_loss=nn.MarginRankingLoss(0.1)
|
||||||
clip_loss=CLIPLoss().to(device)
|
clip_loss=CLIPLoss().to(device)
|
||||||
vgg_loss=VggLoss().to(device)
|
# vgg_loss=VggLoss().to(device)
|
||||||
|
# summary(model, input_size = (3, 256, 256), batch_size = 5)
|
||||||
set_requires_grad(model.mlp.parameters())
|
# set_requires_grad(model.mlp.parameters())
|
||||||
|
for p in (model.mlp.parameters()):
|
||||||
|
p.requires_grad =True
|
||||||
optimizer = optim.SGD(model.mlp.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum)
|
optimizer = optim.SGD(model.mlp.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
|
@ -95,7 +98,7 @@ def main(cfg: DictConfig) -> None:
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
generated_img,adv_latent_codes=model(inputs)
|
generated_img,adv_latent_codes=model(inputs)
|
||||||
loss_vgg=vgg_loss(inputs,generated_img)
|
loss_vgg=vgg_loss(inputs,generated_img)
|
||||||
loss_l1=F.l1_loss(codes,adv_latent_codes)
|
# loss_l1=F.l1_loss(codes,adv_latent_codes)
|
||||||
loss_clip=clip_loss(generated_img,prompt)
|
loss_clip=clip_loss(generated_img,prompt)
|
||||||
outputs = classifier(generated_img)
|
outputs = classifier(generated_img)
|
||||||
preds=criterion(outputs,labels)
|
preds=criterion(outputs,labels)
|
||||||
|
|
@ -103,7 +106,9 @@ def main(cfg: DictConfig) -> None:
|
||||||
# _, preds = torch.max(classifier(generated_img), 1)
|
# _, preds = torch.max(classifier(generated_img), 1)
|
||||||
# loss_classifier=max_loss(torch.ones_like(criterion(outputs, labels)),criterion(outputs, labels),criterion(outputs, labels))
|
# loss_classifier=max_loss(torch.ones_like(criterion(outputs, labels)),criterion(outputs, labels),criterion(outputs, labels))
|
||||||
|
|
||||||
loss=loss_vgg+cfg.optim.alpha*loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
|
# loss=loss_vgg+cfg.optim.alpha*loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
|
||||||
|
# loss=loss_vgg+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
|
||||||
|
loss=loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,17 @@
|
||||||
dataset: gender_dataset
|
dataset: identity_dataset
|
||||||
classifier:
|
classifier:
|
||||||
model: resnet18
|
model: resnet18
|
||||||
lr: 0.01
|
lr: 0.01
|
||||||
momentum: 0.9
|
momentum: 0.9
|
||||||
num_epochs: 25
|
num_epochs: 20
|
||||||
num_workers : 4
|
num_workers : 4
|
||||||
batch_size: 64
|
batch_size: 64
|
||||||
|
|
||||||
|
|
||||||
paths:
|
paths:
|
||||||
gender_dataset: ./dataset/CelebA_HQ_face_gender_dataset
|
gender_dataset: ./dataset/CelebA_HQ_face_gender_dataset
|
||||||
identity_dataset: ./dataset/CelebA_HQ_facial_identity_dataset
|
identity_dataset: /dataset/face_identity/CelebA_HQ_facial_identity_dataset
|
||||||
inverter_cfg: checkpoint/psp_ffhq_encode.pt
|
inverter_cfg: /dataset/face_identity/psp_ffhq_encode.pt
|
||||||
classifier: checkpoint
|
classifier: checkpoint
|
||||||
stylegan: checkpoint/stylegan2-ffhq-config-f.pt
|
stylegan: checkpoint/stylegan2-ffhq-config-f.pt
|
||||||
adv_embedding: pretrained_models
|
adv_embedding: pretrained_models
|
||||||
|
|
@ -23,7 +23,7 @@ prompt: red lipstick
|
||||||
# 'Arched_Eyebrows', 'Bangs', 'Wearing_Earrings', 'Bags_Under_Eyes', 'Receding_Hairline', 'Pale_Skin']
|
# 'Arched_Eyebrows', 'Bangs', 'Wearing_Earrings', 'Bags_Under_Eyes', 'Receding_Hairline', 'Pale_Skin']
|
||||||
|
|
||||||
optim:
|
optim:
|
||||||
batch_size: 8
|
batch_size: 1
|
||||||
num_epochs: 200
|
num_epochs: 200
|
||||||
num_workers : 4
|
num_workers : 4
|
||||||
images_resize: 256
|
images_resize: 256
|
||||||
|
|
|
||||||
18
model.py
18
model.py
|
|
@ -11,7 +11,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
|
||||||
def get_prompt(cfg):
|
def get_prompt(cfg):
|
||||||
model, preprocess = clip.load("ViT-B/32", device=device)
|
model, preprocess = clip.load("RN50", device=device)
|
||||||
prompt=cfg.prompt
|
prompt=cfg.prompt
|
||||||
text=clip.tokenize(prompt).to(device)
|
text=clip.tokenize(prompt).to(device)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|
@ -28,9 +28,9 @@ class GanAttack(nn.Module):
|
||||||
# self.inverter=inverter
|
# self.inverter=inverter
|
||||||
# self.images_resize=images_resize
|
# self.images_resize=images_resize
|
||||||
self.prompt=prompt
|
self.prompt=prompt
|
||||||
text_len=self.prompt.shape[0]
|
text_len=self.prompt.shape[1]
|
||||||
self.mlp=nn.Sequential(
|
self.mlp=nn.Sequential(
|
||||||
nn.Linear(text_len+512, 4096),
|
nn.Linear(1024+512, 4096),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
nn.Linear(4096, 512)
|
nn.Linear(4096, 512)
|
||||||
)
|
)
|
||||||
|
|
@ -44,11 +44,15 @@ class GanAttack(nn.Module):
|
||||||
# _, _, _, refine_images, latent_codes, _=self.inverter(img,self.images_resize,img_path)
|
# _, _, _, refine_images, latent_codes, _=self.inverter(img,self.images_resize,img_path)
|
||||||
x=codes
|
x=codes
|
||||||
batch_size=img.shape[0]
|
batch_size=img.shape[0]
|
||||||
prompt=self.prompt.repeat(batch_size).to(device)
|
# print(img.shape)
|
||||||
x_prompt=torch.cat([codes,prompt],dim=1)
|
# print(self.prompt.shape)
|
||||||
|
# print(codes.shape)
|
||||||
|
prompt=self.prompt.repeat(batch_size,18,1).to(device)
|
||||||
|
print(prompt.shape)
|
||||||
|
x_prompt=torch.cat([codes,prompt],dim=2)
|
||||||
x_prompt=self.mlp(x_prompt)
|
x_prompt=self.mlp(x_prompt)
|
||||||
x=x_prompt+x
|
x=x_prompt+x
|
||||||
im,_=self.net.decoder([x], input_is_latent=True, randomize_noise=False, return_latents=False)
|
im,_=self.net.decoder(x, input_is_latent=True, randomize_noise=False, return_latents=False)
|
||||||
result_images = self.net.face_pool(im)
|
result_images = self.net.face_pool(im)
|
||||||
|
|
||||||
return result_images,x
|
return result_images,x
|
||||||
|
|
@ -56,7 +60,7 @@ class GanAttack(nn.Module):
|
||||||
class CLIPLoss(torch.nn.Module):
|
class CLIPLoss(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(CLIPLoss, self).__init__()
|
super(CLIPLoss, self).__init__()
|
||||||
self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
|
self.model, self.preprocess = clip.load("RN50", device="cuda")
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))
|
self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))
|
||||||
# self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda").view(1,3,1,1)
|
# self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda").view(1,3,1,1)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,36 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from torchvision import models
|
||||||
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
from data.dataset import get_dataset,get_adv_dataset
|
||||||
|
from utils import get_model,set_requires_grad,unnormalize
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import clip
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import time
|
||||||
|
import hydra
|
||||||
|
from scipy import io as spio
|
||||||
|
|
||||||
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
@hydra.main(version_base=None, config_path="./config", config_name="config")
def save_prompt(cfg):
    """Encode the configured text prompt with CLIP and store it on disk.

    Reads ``cfg.prompt``, tokenizes it, runs it through the text encoder of
    the CLIP "RN50" model and writes the resulting array to ``prompt.mat``
    (relative path) under the key ``'prompt'``.

    Args:
        cfg: Hydra config object; only ``cfg.prompt`` is read here.
    """
    # clip.load also returns an image preprocessing transform; it is not
    # needed for text encoding, so it is discarded.
    clip_model, _ = clip.load("RN50", device=device)
    tokens = clip.tokenize(cfg.prompt).to(device)

    # Inference only: no gradients are needed for the text embedding.
    with torch.no_grad():
        text_features = clip_model.encode_text(tokens)

    # savemat works on NumPy arrays, so move the tensor to host memory first.
    spio.savemat("prompt.mat", {'prompt': text_features.cpu().numpy()})
|
||||||
|
|
||||||
|
def get_prompt(cfg):
    """Load the prompt array stored in ``prompt.mat`` as a tensor.

    Reads the ``'prompt'`` entry from ``prompt.mat`` (relative path) and
    returns it as a torch tensor moved to the module-level ``device``.

    Args:
        cfg: Accepted but not read by this function.

    Returns:
        torch.Tensor built from the stored array, placed on ``device``.
    """
    stored = spio.loadmat("prompt.mat")
    prompt_array = stored['prompt']
    prompt_tensor = torch.from_numpy(prompt_array)
    return prompt_tensor.to(device)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on save_prompt
    # is what supplies the cfg argument.
    save_prompt()
|
||||||
Loading…
Reference in New Issue