diff --git a/GanAttack.py b/GanAttack.py
index 290e024..700c760 100644
--- a/GanAttack.py
+++ b/GanAttack.py
@@ -3,11 +3,12 @@ import torch.nn as nn
 import torch.optim as optim
 from torchvision import models
 from omegaconf import DictConfig, OmegaConf
-from data.dataset import get_dataset
-from utils import get_model
+from data.dataset import get_dataset, get_adv_dataset
+from utils import get_model, set_requires_grad
 from model.GanInverter.models.stylegan2.model import Generator
 from model.GanInverter.inference.two_stage_inference import TwoStageInference
-from model import GanAttack
+from model import GanAttack, CLIPLoss, VggLoss, get_prompt
+import torch.nn.functional as F
 import time
 import hydra
@@ -38,17 +39,25 @@ def get_stylegan_inverter(cfg):
 @hydra.main(version_base=None, config_path="./config", config_name="config")
 def main(cfg: DictConfig) -> None:
-    model=get_model(cfg)
-    train_dataloader,test_dataloader,train_dataset,test_dataset=get_dataset(cfg)
+    train_dataloader, test_dataloader, train_dataset, test_dataset = get_adv_dataset(cfg)
     classifier=get_model(cfg)
+    classifier.eval()
     g_ema, _=get_stylegan_generator(cfg)
     inverter=get_stylegan_inverter(cfg)
     model=GanAttack(g_ema,inverter,images_resize=cfg.optim.images_resize,prompt=cfg.prompt).to(device)
+    prompt = get_prompt(cfg)
+
     num_epochs = cfg.optim.num_epochs
     criterion = nn.CrossEntropyLoss()
-    optimizer = optim.SGD(model.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum)
+    max_loss = nn.MarginRankingLoss(0.1)
+    clip_loss = CLIPLoss().to(device)
+    vgg_loss = VggLoss().to(device)
+
+    # only the latent-mapping MLP is trained; generator and inverter stay frozen
+    set_requires_grad(model.mlp.parameters())
+    optimizer = optim.SGD(model.mlp.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum)
 
     start_time = time.time()
 
     for epoch in range(num_epochs):
@@ -57,15 +66,19 @@ def main(cfg: DictConfig) -> None:
         running_loss = 0
         running_corrects = 0
 
-        for i, (inputs, labels) in enumerate(train_dataloader):
+        for i, (inputs, img_path, labels) in enumerate(train_dataloader):
             inputs = inputs.to(device)
             labels = labels.to(device)
-
+            # the clean inversion only provides regression targets, so skip gradient tracking
+            with torch.no_grad():
+                _, _, _, clean_refine_images, clean_latent_codes, _ = inverter(inputs, img_path)
             optimizer.zero_grad()
 
-            outputs = model(inputs)
-            _, preds = torch.max(outputs, 1)
-
-            loss = criterion(outputs, labels)
+            adv_refine_images, generated_img, adv_latent_codes = model(inputs, img_path)
+            loss_vgg = vgg_loss(inputs, generated_img)
+            loss_l1 = F.l1_loss(clean_latent_codes, adv_latent_codes)
+            loss_clip = clip_loss(generated_img, prompt)
+
+            outputs = classifier(generated_img)
+            _, preds = torch.max(outputs, 1)
+            loss_ce = criterion(outputs, labels)
+            # MarginRankingLoss expects a ±1 target; with +1 this is the hinge max(0, loss_ce - 1 + 0.1)
+            loss_classifier = max_loss(torch.ones_like(loss_ce), loss_ce, torch.ones_like(loss_ce))
+            loss = loss_vgg + cfg.optim.alpha * loss_l1 + cfg.optim.beta * loss_clip + cfg.optim.delta * loss_classifier
             loss.backward()
             optimizer.step()
@@ -78,27 +91,28 @@ def main(cfg: DictConfig) -> None:
         model.eval()
-
+        print('Evaluating!')
         with torch.no_grad():
-            running_loss = 0.
             running_corrects = 0
 
-            for inputs, labels in test_dataloader:
+            for i, (inputs, img_path, labels) in enumerate(test_dataloader):
                 inputs = inputs.to(device)
                 labels = labels.to(device)
 
-                outputs = model(inputs)
+                adv_refine_images, generated_img, adv_latent_codes = model(inputs, img_path)
+                outputs = classifier(generated_img)
                 _, preds = torch.max(outputs, 1)
-                loss = criterion(outputs, labels)
 
-                running_loss += loss.item() * inputs.size(0)
                 running_corrects += torch.sum(preds == labels.data)
 
-            epoch_loss = running_loss / len(test_dataset)
             epoch_acc = running_corrects / len(test_dataset) * 100.
-            print('[Test #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))
+            print('[Test #{}] Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_acc, time.time() - start_time))
 
-    save_path = '{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset)
+    save_path = '{}/stylegan_{}_{}_{}.pth'.format(cfg.paths.adv_embedding, cfg.classifier.model, cfg.dataset, cfg.prompt)
     torch.save(model.state_dict(), save_path)
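A note on the classifier term: `nn.MarginRankingLoss` expects a ±1 ranking target. With `x1 = 1`, `x2 = loss_ce`, and a `+1` target, the term reduces to the plain hinge `max(0, loss_ce - 0.9)`; a quick numeric check:

```python
import torch
import torch.nn as nn

# MarginRankingLoss(m)(x1, x2, y) = mean(max(0, -y*(x1 - x2) + m)),
# so with x1 = 1, x2 = ce, y = +1 it becomes max(0, ce - 1 + 0.1).
max_loss = nn.MarginRankingLoss(margin=0.1)
for ce in (0.5, 0.9, 1.5):
    ce_t = torch.tensor([ce])
    hinge = max_loss(torch.ones_like(ce_t), ce_t, torch.ones_like(ce_t))
    print(f"ce={ce:.1f} -> hinge={hinge.item():.1f}")  # 0.0, 0.0, 0.6
```

If the intent is instead a capped adversarial term that pushes the classifier loss up, a `-1` target gives `max(0, 1.1 - loss_ce)`; the patch as written is ambiguous between the two readings.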
diff --git a/config/config.yaml b/config/config.yaml
index 5a7f645..4f39e69 100644
--- a/config/config.yaml
+++ b/config/config.yaml
@@ -4,6 +4,7 @@ classifier:
   lr: 0.01
   momentum: 0.9
   num_epochs: 200
+  num_workers: 4
 
 paths:
@@ -12,6 +13,7 @@ paths:
   inverter_cfg: secret
   classifier: checkpoint/
   stylegan: pretrained_models/stylegan2-ffhq-config-f.pt
+  adv_embedding: pretrained_models
 
 prompt: red lipstick
 
 # available attributes
@@ -23,4 +25,7 @@ optim:
   batch_size: 8
   num_epochs: 200
   num_workers : 4
-  images_resize: 256
\ No newline at end of file
+  images_resize: 256
+  alpha: 0.1
+  beta: 1
+  delta: 1
\ No newline at end of file
diff --git a/data/dataset.py b/data/dataset.py
index fa09732..375a211 100644
--- a/data/dataset.py
+++ b/data/dataset.py
@@ -1,7 +1,10 @@
 import torch
 import torchvision
 from torchvision import datasets, models, transforms
+from torch.utils.data import Dataset
+from PIL import Image
 import torch.nn as nn
+import pathlib
 import os
 
 transforms_train = transforms.Compose([
@@ -17,6 +20,35 @@ transforms_test = transforms.Compose([
     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
 ])
 
+class ImageDataset(Dataset):
+    def __init__(self, data_path, mode, transform=None):
+        self.path = data_path
+        data_dir = pathlib.Path(data_path)
+        self.mode = mode
+        self.transform = transform
+        split = 'train' if self.mode == 'train' else 'test'
+        self.image_path = [str(path) for path in data_dir.glob(f"{split}/*/*")]
+
+        # class names always come from train/ so both splits share one index mapping
+        label_names = sorted(item.name for item in data_dir.glob("train/*/"))
+        label_to_index = {name: index for index, name in enumerate(label_names)}
+        self.image_label = [label_to_index[pathlib.Path(path).parent.name] for path in self.image_path]
+
+    def __getitem__(self, index):
+        # glob already returns paths rooted at data_path, so open them directly
+        img = Image.open(self.image_path[index]).convert('RGB')
+        if self.transform is not None:
+            img = self.transform(img)
+        # 0-dim long tensor so a batch collates to shape [B], as CrossEntropyLoss expects
+        label = torch.tensor(self.image_label[index], dtype=torch.long)
+        return img, self.image_path[index], label
+
+    def __len__(self):
+        return len(self.image_path)
 
 def get_dataset(config):
     if config.dataset == 'gender_dataset':
@@ -26,6 +58,23 @@ def get_dataset(config):
     train_dataset = datasets.ImageFolder(os.path.join(path, 'train'), transforms_train)
     test_dataset = datasets.ImageFolder(os.path.join(path, 'test'), transforms_test)
-    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.optim.batch_size, shuffle=True, num_workers=config.optim.num_workers)
-    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=config.optim.batch_size, shuffle=False, num_workers=config.optim.num_workers)
-    return train_dataloader,test_dataloader,train_dataset,test_dataset
\ No newline at end of file
+    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.classifier.batch_size, shuffle=True, num_workers=config.optim.num_workers)
+    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=config.classifier.batch_size, shuffle=False, num_workers=config.optim.num_workers)
+    return train_dataloader, test_dataloader, train_dataset, test_dataset
+
+def get_adv_dataset(config):
+    if config.dataset == 'gender_dataset':
+        path = config.paths.gender_dataset
+    else:
+        path = config.paths.identity_dataset
+    train_dataset = ImageDataset(path, 'train', transforms_train)
+    test_dataset = ImageDataset(path, 'test', transforms_test)
+
+    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.optim.batch_size, shuffle=True, num_workers=config.optim.num_workers)
+    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=config.optim.batch_size, shuffle=False, num_workers=config.optim.num_workers)
+
+    return train_dataloader, test_dataloader, train_dataset, test_dataset
\ No newline at end of file
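Since the training loop now unpacks `(inputs, img_path, labels)` triples, it is worth smoke-testing `ImageDataset` in isolation. A minimal sketch, assuming the `ImageFolder`-style tree the class globs over (`<root>/{train,test}/<class>/<image>`); the root path here is hypothetical:

```python
from torch.utils.data import DataLoader
from data.dataset import ImageDataset, transforms_train

root = "datasets/gender_dataset"  # hypothetical; use config.paths.gender_dataset in practice

dataset = ImageDataset(root, 'train', transforms_train)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

images, paths, labels = next(iter(loader))
print(images.shape)  # [4, 3, H, W], with H and W fixed by transforms_train
print(labels.shape)  # [4]: the 0-dim labels collate to a flat batch
print(paths[:2])     # per-sample paths, forwarded to the GAN inverter
```

The default collate keeps the path strings as a plain list alongside the stacked tensors, which is what lets the training loop hand them straight to `inverter(inputs, img_path)`.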
diff --git a/model.py b/model.py
index 7c0a21f..b61e262 100644
--- a/model.py
+++ b/model.py
@@ -2,8 +2,10 @@ import torch
 import torchvision
 from torchvision import datasets, models, transforms
 import torch.nn as nn
+import torch.nn.functional as F
 import os
 import clip
+from utils import normalize
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -45,4 +47,39 @@ class GanAttack(nn.Module):
         im,_=self.generator(x,input_is_latent=True, randomize_noise=False)
-        return img,refine_images,im,x
\ No newline at end of file
+        return refine_images, im, x
+
+class CLIPLoss(torch.nn.Module):
+    def __init__(self):
+        super(CLIPLoss, self).__init__()
+        self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
+        self.model.eval()
+        self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))
+
+    def forward(self, image, text):
+        image = normalize(image)
+        image = self.face_pool(image)
+        # logits_per_image is cosine similarity scaled by CLIP's logit scale (~100);
+        # undo the scale and average over the batch so the loss is a scalar
+        similarity = 1 - self.model(image, text)[0] / 100
+        return similarity.mean()
+
+
+class VggLoss(torch.nn.Module):
+    def __init__(self):
+        super(VggLoss, self).__init__()
+        self.model = models.vgg11(pretrained=True)
+        # drop the classification head, keeping the convolutional features
+        self.model.classifier = nn.Sequential()
+        self.model.eval()
+        # freeze the backbone instead of using no_grad, so gradients still
+        # flow back to the generated image
+        for param in self.model.parameters():
+            param.requires_grad = False
+
+    def forward(self, image1, image2):
+        feature1 = torch.flatten(self.model(image1), start_dim=1)
+        feature2 = torch.flatten(self.model(image2), start_dim=1)
+        # high similarity means perceptually close images, so invert it for a loss
+        return 1 - F.cosine_similarity(feature1, feature2, dim=1).mean()
\ No newline at end of file
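Both loss modules can be exercised before wiring them into the training loop. A minimal sketch, assuming a CUDA device, generator outputs in StyleGAN's [-1, 1] range, and that `utils.normalize` (not shown in this diff) maps images to CLIP's expected statistics; `get_prompt` presumably wraps `clip.tokenize`, which is called directly here:

```python
import clip
import torch
from model import CLIPLoss, VggLoss

device = torch.device("cuda:0")
clip_loss = CLIPLoss().to(device)
vgg_loss = VggLoss().to(device)

# dummy tensors standing in for the clean input and the generator output
real = torch.rand(2, 3, 256, 256, device=device) * 2 - 1
fake = (torch.rand(2, 3, 256, 256, device=device) * 2 - 1).requires_grad_()

text = clip.tokenize(["red lipstick"]).to(device)

loss = vgg_loss(real, fake) + clip_loss(fake, text)
loss.backward()
print(fake.grad.abs().sum() > 0)  # gradients reach the generated image
```

Freezing the VGG weights in `__init__` rather than wrapping the forward pass in `no_grad` is what makes this work: the backbone stays fixed, but the perceptual term still backpropagates into the generated image.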