add others
This commit is contained in:
parent
4f6194e2af
commit
dfd67482a1
50
GanAttack.py
50
GanAttack.py
|
|
@ -3,11 +3,12 @@ import torch.nn as nn
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from torchvision import models
|
from torchvision import models
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from data.dataset import get_dataset
|
from data.dataset import get_dataset,get_adv_dataset
|
||||||
from utils import get_model
|
from utils import get_model,set_requires_grad
|
||||||
from model.GanInverter.models.stylegan2.model import Generator
|
from model.GanInverter.models.stylegan2.model import Generator
|
||||||
from model.GanInverter.inference.two_stage_inference import TwoStageInference
|
from model.GanInverter.inference.two_stage_inference import TwoStageInference
|
||||||
from model import GanAttack
|
from model import GanAttack,CLIPLoss,VggLoss,get_prompt
|
||||||
|
import torch.nn.functional as F
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
|
|
@ -38,17 +39,25 @@ def get_stylegan_inverter(cfg):
|
||||||
@hydra.main(version_base=None, config_path="./config", config_name="config")
|
@hydra.main(version_base=None, config_path="./config", config_name="config")
|
||||||
def main(cfg: DictConfig) -> None:
|
def main(cfg: DictConfig) -> None:
|
||||||
model=get_model(cfg)
|
model=get_model(cfg)
|
||||||
train_dataloader,test_dataloader,train_dataset,test_dataset=get_dataset(cfg)
|
train_dataloader,test_dataloader,train_dataset,test_dataset=get_adv_dataset(cfg)
|
||||||
|
|
||||||
classifier=get_model(cfg)
|
classifier=get_model(cfg)
|
||||||
|
classifier.eval()
|
||||||
g_ema, _=get_stylegan_generator(cfg)
|
g_ema, _=get_stylegan_generator(cfg)
|
||||||
|
|
||||||
inverter=get_stylegan_inverter(cfg)
|
inverter=get_stylegan_inverter(cfg)
|
||||||
model=GanAttack(g_ema,inverter,images_resize=cfg.optim.images_resize,prompt=cfg.prompt).to(device)
|
model=GanAttack(g_ema,inverter,images_resize=cfg.optim.images_resize,prompt=cfg.prompt).to(device)
|
||||||
|
|
||||||
|
prompt=get_prompt(cfg)
|
||||||
|
|
||||||
num_epochs = cfg.optim.num_epochs
|
num_epochs = cfg.optim.num_epochs
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
optimizer = optim.SGD(model.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum)
|
max_loss=nn.MarginRankingLoss(0.1)
|
||||||
|
clip_loss=CLIPLoss().to(device)
|
||||||
|
vgg_loss=VggLoss().to(device)
|
||||||
|
|
||||||
|
set_requires_grad(model.mlp.parameters())
|
||||||
|
optimizer = optim.SGD(model.mlp.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
|
|
@ -57,15 +66,19 @@ def main(cfg: DictConfig) -> None:
|
||||||
running_loss = 0
|
running_loss = 0
|
||||||
running_corrects = 0
|
running_corrects = 0
|
||||||
|
|
||||||
for i, (inputs, labels) in enumerate(train_dataloader):
|
for i, (inputs,img_path, labels) in enumerate(train_dataloader):
|
||||||
inputs = inputs.to(device)
|
inputs = inputs.to(device)
|
||||||
labels = labels.to(device)
|
labels = labels.to(device)
|
||||||
|
_, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path)
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
outputs = model(inputs)
|
adv_refine_images,generated_img,adv_latent_codes=model(inputs,img_path)
|
||||||
_, preds = torch.max(outputs, 1)
|
loss_vgg=vgg_loss(inputs,generated_img)
|
||||||
|
loss_l1=F.l1_loss(clean_latent_codes,adv_latent_codes)
|
||||||
|
loss_clip=clip_loss(generated_img,prompt)
|
||||||
|
|
||||||
loss = criterion(outputs, labels)
|
_, preds = torch.max(classifier(generated_img), 1)
|
||||||
|
loss_classifier=max_loss(torch.ones_like(criterion(outputs, labels)),criterion(outputs, labels),criterion(outputs, labels))
|
||||||
|
loss=loss_vgg+cfg.optim.alpha*loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
|
@ -78,27 +91,28 @@ def main(cfg: DictConfig) -> None:
|
||||||
|
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
|
print('Evaluating!')
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
running_loss = 0.
|
running_loss = 0. #test_dataloader
|
||||||
running_corrects = 0
|
running_corrects = 0
|
||||||
|
|
||||||
for inputs, labels in test_dataloader:
|
for i, (inputs,img_path, labels) in enumerate(test_dataloader):
|
||||||
inputs = inputs.to(device)
|
inputs = inputs.to(device)
|
||||||
labels = labels.to(device)
|
labels = labels.to(device)
|
||||||
|
|
||||||
outputs = model(inputs)
|
adv_refine_images,generated_img,adv_latent_codes=model(inputs,img_path)
|
||||||
|
outputs = classifier(generated_img)
|
||||||
_, preds = torch.max(outputs, 1)
|
_, preds = torch.max(outputs, 1)
|
||||||
loss = criterion(outputs, labels)
|
loss = criterion(outputs, labels)
|
||||||
|
|
||||||
running_loss += loss.item() * inputs.size(0)
|
# running_loss += loss.item() * inputs.size(0)
|
||||||
running_corrects += torch.sum(preds == labels.data)
|
running_corrects += torch.sum(preds == labels.data)
|
||||||
|
|
||||||
epoch_loss = running_loss / len(test_dataset)
|
# epoch_loss = running_loss / len(test_dataset)
|
||||||
epoch_acc = running_corrects / len(test_dataset) * 100.
|
epoch_acc = running_corrects / len(test_dataset) * 100.
|
||||||
print('[Test #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))
|
print('[Test #{}] Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_acc, time.time() - start_time))
|
||||||
|
|
||||||
save_path = '{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset)
|
save_path = '{}/stylegan_{}_{}_{}.pth'.format(cfg.paths.pretrained_models, cfg.classifier.model, cfg.dataset,cfg.prompt)
|
||||||
torch.save(model.state_dict(), save_path)
|
torch.save(model.state_dict(), save_path)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ classifier:
|
||||||
lr: 0.01
|
lr: 0.01
|
||||||
momentum: 0.9
|
momentum: 0.9
|
||||||
num_epochs: 200
|
num_epochs: 200
|
||||||
|
num_workers : 4
|
||||||
|
|
||||||
|
|
||||||
paths:
|
paths:
|
||||||
|
|
@ -12,6 +13,7 @@ paths:
|
||||||
inverter_cfg: secret
|
inverter_cfg: secret
|
||||||
classifier: checkpoint/
|
classifier: checkpoint/
|
||||||
stylegan: pretrained_models/stylegan2-ffhq-config-f.pt
|
stylegan: pretrained_models/stylegan2-ffhq-config-f.pt
|
||||||
|
adv_embedding: pretrained_models
|
||||||
|
|
||||||
prompt: red lipstick
|
prompt: red lipstick
|
||||||
# available attributes
|
# available attributes
|
||||||
|
|
@ -24,3 +26,6 @@ optim:
|
||||||
num_epochs: 200
|
num_epochs: 200
|
||||||
num_workers : 4
|
num_workers : 4
|
||||||
images_resize: 256
|
images_resize: 256
|
||||||
|
alpha: 0.1
|
||||||
|
beta: 1
|
||||||
|
delta: 1
|
||||||
|
|
@ -1,7 +1,10 @@
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
from torchvision import datasets, models, transforms
|
from torchvision import datasets, models, transforms
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from PIL import Image
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import pathlib
|
||||||
import os
|
import os
|
||||||
|
|
||||||
transforms_train = transforms.Compose([
|
transforms_train = transforms.Compose([
|
||||||
|
|
@ -17,6 +20,35 @@ transforms_test = transforms.Compose([
|
||||||
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
|
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
|
||||||
])
|
])
|
||||||
|
|
||||||
|
class ImageDataset(Dataset):
|
||||||
|
def __init__(self, data_path, mode, transform=None):
|
||||||
|
self.path=data_path
|
||||||
|
data_dir=pathlib.Path(data_path)
|
||||||
|
self.mode=mode
|
||||||
|
self.transform=transform
|
||||||
|
if self.mode == 'train':
|
||||||
|
self.image_path=list(data_dir.glob("train/*/*"))
|
||||||
|
self.image_path=[str(path) for path in self.image_path]
|
||||||
|
else:
|
||||||
|
self.image_path=list(data_dir.glob("test/*/*"))
|
||||||
|
self.image_path=[str(path) for path in self.image_path]
|
||||||
|
|
||||||
|
lable_names = sorted(item.name for item in data_dir.glob("train/*/"))
|
||||||
|
lable_to_index = dict((name, index) for index, name in enumerate(lable_names))
|
||||||
|
self.image_label=[lable_to_index[pathlib.Path(path).parent.name] for path in self.image_path]
|
||||||
|
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
img = Image.open(os.path.join(self.path, self.image_path[index]))
|
||||||
|
img = img.convert('RGB')
|
||||||
|
if self.transform is not None:
|
||||||
|
img = self.transform(img)
|
||||||
|
label = torch.LongTensor([self.image_label[index]])
|
||||||
|
image_path=self.image_path[index]
|
||||||
|
return img, image_path ,label
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.image_path)
|
||||||
|
|
||||||
def get_dataset(config):
|
def get_dataset(config):
|
||||||
if config.dataset == 'gender_dataset':
|
if config.dataset == 'gender_dataset':
|
||||||
|
|
@ -26,6 +58,23 @@ def get_dataset(config):
|
||||||
train_dataset = datasets.ImageFolder(os.path.join(path, 'train'), transforms_train)
|
train_dataset = datasets.ImageFolder(os.path.join(path, 'train'), transforms_train)
|
||||||
test_dataset = datasets.ImageFolder(os.path.join(path, 'test'), transforms_test)
|
test_dataset = datasets.ImageFolder(os.path.join(path, 'test'), transforms_test)
|
||||||
|
|
||||||
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.optim.batch_size, shuffle=True, num_workers=config.optim.num_workers)
|
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.classifier.batch_size, shuffle=True, num_workers=config.optim.num_workers)
|
||||||
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=config.optim.batch_size, shuffle=False, num_workers=config.optim.num_workers)
|
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=config.classifier.batch_size, shuffle=False, num_workers=config.optim.num_workers)
|
||||||
return train_dataloader,test_dataloader,train_dataset,test_dataset
|
return train_dataloader,test_dataloader,train_dataset,test_dataset
|
||||||
|
|
||||||
|
def get_adv_dataset(config):
|
||||||
|
if config.dataset == 'gender_dataset':
|
||||||
|
path=config.paths.gender_dataset
|
||||||
|
else:
|
||||||
|
path=config.paths.identity_dataset
|
||||||
|
train_dataset = ImageDataset(path,'train',transforms_train)
|
||||||
|
test_dataset= ImageDataset(path,'test',transforms_test)
|
||||||
|
|
||||||
|
train_dataloader= torch.utils.data.DataLoader(train_dataset, batch_size=config.optim.batch_size, shuffle=True, num_workers=config.optim.num_workers)
|
||||||
|
test_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.optim.batch_size, shuffle=True, num_workers=config.optim.num_workers)
|
||||||
|
|
||||||
|
return train_dataloader,test_dataloader,train_dataset,test_dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
39
model.py
39
model.py
|
|
@ -2,8 +2,10 @@ import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
from torchvision import datasets, models, transforms
|
from torchvision import datasets, models, transforms
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
import os
|
import os
|
||||||
import clip
|
import clip
|
||||||
|
from utils import normalize
|
||||||
|
|
||||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
|
@ -45,4 +47,39 @@ class GanAttack(nn.Module):
|
||||||
im,_=self.generator(x,input_is_latent=True, randomize_noise=False)
|
im,_=self.generator(x,input_is_latent=True, randomize_noise=False)
|
||||||
|
|
||||||
|
|
||||||
return img,refine_images,im,x
|
return refine_images,im,x
|
||||||
|
|
||||||
|
class CLIPLoss(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(CLIPLoss, self).__init__()
|
||||||
|
self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
|
||||||
|
self.model.eval()
|
||||||
|
self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))
|
||||||
|
# self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda").view(1,3,1,1)
|
||||||
|
# self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cuda").view(1,3,1,1)
|
||||||
|
|
||||||
|
def forward(self, image, text):
|
||||||
|
image=normalize(image)
|
||||||
|
image = self.face_pool(image)
|
||||||
|
similarity = 1 - self.model(image, text)[0]/ 100
|
||||||
|
return similarity
|
||||||
|
|
||||||
|
|
||||||
|
class VggLoss(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(VggLoss, self).__init__()
|
||||||
|
self.model=models.vgg11(pretrained=True)
|
||||||
|
self.model.features=nn.Sequential()
|
||||||
|
|
||||||
|
# self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda").view(1,3,1,1)
|
||||||
|
# self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cuda").view(1,3,1,1)
|
||||||
|
|
||||||
|
def forward(self, image1, image2):
|
||||||
|
# image=normalize(image)
|
||||||
|
with torch.no_grad:
|
||||||
|
feature1=self.model(image1)
|
||||||
|
feature2=self.model(image2)
|
||||||
|
feature1=torch.flatten(feature1)
|
||||||
|
feature2=torch.flatten(feature2)
|
||||||
|
similarity = F.cosine_similarity(feature1,feature2)
|
||||||
|
return similarity
|
||||||
Loading…
Reference in New Issue