add others

This commit is contained in:
Li Wenyun 2023-12-03 17:50:47 +08:00
parent 4f6194e2af
commit dfd67482a1
4 changed files with 129 additions and 24 deletions

View File

@ -3,11 +3,12 @@ import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from torchvision import models from torchvision import models
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from data.dataset import get_dataset from data.dataset import get_dataset,get_adv_dataset
from utils import get_model from utils import get_model,set_requires_grad
from model.GanInverter.models.stylegan2.model import Generator from model.GanInverter.models.stylegan2.model import Generator
from model.GanInverter.inference.two_stage_inference import TwoStageInference from model.GanInverter.inference.two_stage_inference import TwoStageInference
from model import GanAttack from model import GanAttack,CLIPLoss,VggLoss,get_prompt
import torch.nn.functional as F
import time import time
import hydra import hydra
@ -38,17 +39,25 @@ def get_stylegan_inverter(cfg):
@hydra.main(version_base=None, config_path="./config", config_name="config") @hydra.main(version_base=None, config_path="./config", config_name="config")
def main(cfg: DictConfig) -> None: def main(cfg: DictConfig) -> None:
model=get_model(cfg) model=get_model(cfg)
train_dataloader,test_dataloader,train_dataset,test_dataset=get_dataset(cfg) train_dataloader,test_dataloader,train_dataset,test_dataset=get_adv_dataset(cfg)
classifier=get_model(cfg) classifier=get_model(cfg)
classifier.eval()
g_ema, _=get_stylegan_generator(cfg) g_ema, _=get_stylegan_generator(cfg)
inverter=get_stylegan_inverter(cfg) inverter=get_stylegan_inverter(cfg)
model=GanAttack(g_ema,inverter,images_resize=cfg.optim.images_resize,prompt=cfg.prompt).to(device) model=GanAttack(g_ema,inverter,images_resize=cfg.optim.images_resize,prompt=cfg.prompt).to(device)
prompt=get_prompt(cfg)
num_epochs = cfg.optim.num_epochs num_epochs = cfg.optim.num_epochs
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum) max_loss=nn.MarginRankingLoss(0.1)
clip_loss=CLIPLoss().to(device)
vgg_loss=VggLoss().to(device)
set_requires_grad(model.mlp.parameters())
optimizer = optim.SGD(model.mlp.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum)
start_time = time.time() start_time = time.time()
for epoch in range(num_epochs): for epoch in range(num_epochs):
@ -57,15 +66,19 @@ def main(cfg: DictConfig) -> None:
running_loss = 0 running_loss = 0
running_corrects = 0 running_corrects = 0
for i, (inputs, labels) in enumerate(train_dataloader): for i, (inputs,img_path, labels) in enumerate(train_dataloader):
inputs = inputs.to(device) inputs = inputs.to(device)
labels = labels.to(device) labels = labels.to(device)
_, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path)
optimizer.zero_grad() optimizer.zero_grad()
outputs = model(inputs) adv_refine_images,generated_img,adv_latent_codes=model(inputs,img_path)
_, preds = torch.max(outputs, 1) loss_vgg=vgg_loss(inputs,generated_img)
loss_l1=F.l1_loss(clean_latent_codes,adv_latent_codes)
loss = criterion(outputs, labels) loss_clip=clip_loss(generated_img,prompt)
_, preds = torch.max(classifier(generated_img), 1)
loss_classifier=max_loss(torch.ones_like(criterion(outputs, labels)),criterion(outputs, labels),criterion(outputs, labels))
loss=loss_vgg+cfg.optim.alpha*loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
loss.backward() loss.backward()
optimizer.step() optimizer.step()
@ -78,27 +91,28 @@ def main(cfg: DictConfig) -> None:
model.eval() model.eval()
print('Evaluating!')
with torch.no_grad(): with torch.no_grad():
running_loss = 0. running_loss = 0. #test_dataloader
running_corrects = 0 running_corrects = 0
for inputs, labels in test_dataloader: for i, (inputs,img_path, labels) in enumerate(test_dataloader):
inputs = inputs.to(device) inputs = inputs.to(device)
labels = labels.to(device) labels = labels.to(device)
outputs = model(inputs) adv_refine_images,generated_img,adv_latent_codes=model(inputs,img_path)
outputs = classifier(generated_img)
_, preds = torch.max(outputs, 1) _, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels) loss = criterion(outputs, labels)
running_loss += loss.item() * inputs.size(0) # running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data) running_corrects += torch.sum(preds == labels.data)
epoch_loss = running_loss / len(test_dataset) # epoch_loss = running_loss / len(test_dataset)
epoch_acc = running_corrects / len(test_dataset) * 100. epoch_acc = running_corrects / len(test_dataset) * 100.
print('[Test #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time)) print('[Test #{}] Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_acc, time.time() - start_time))
save_path = '{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset) save_path = '{}/stylegan_{}_{}_{}.pth'.format(cfg.paths.pretrained_models, cfg.classifier.model, cfg.dataset,cfg.prompt)
torch.save(model.state_dict(), save_path) torch.save(model.state_dict(), save_path)

View File

@ -4,6 +4,7 @@ classifier:
lr: 0.01 lr: 0.01
momentum: 0.9 momentum: 0.9
num_epochs: 200 num_epochs: 200
num_workers : 4
paths: paths:
@ -12,6 +13,7 @@ paths:
inverter_cfg: secret inverter_cfg: secret
classifier: checkpoint/ classifier: checkpoint/
stylegan: pretrained_models/stylegan2-ffhq-config-f.pt stylegan: pretrained_models/stylegan2-ffhq-config-f.pt
adv_embedding: pretrained_models
prompt: red lipstick prompt: red lipstick
# available attributes # available attributes
@ -23,4 +25,7 @@ optim:
batch_size: 8 batch_size: 8
num_epochs: 200 num_epochs: 200
num_workers : 4 num_workers : 4
images_resize: 256 images_resize: 256
alpha: 0.1
beta: 1
delta: 1

View File

@ -1,7 +1,10 @@
import torch import torch
import torchvision import torchvision
from torchvision import datasets, models, transforms from torchvision import datasets, models, transforms
from torch.utils.data import Dataset
from PIL import Image
import torch.nn as nn import torch.nn as nn
import pathlib
import os import os
transforms_train = transforms.Compose([ transforms_train = transforms.Compose([
@ -17,6 +20,35 @@ transforms_test = transforms.Compose([
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
]) ])
class ImageDataset(Dataset):
    """Folder-per-class image dataset that also returns each image's path.

    Images are discovered under ``<data_path>/train/<class>/<file>`` when
    ``mode == 'train'`` and under ``<data_path>/test/<class>/<file>`` for any
    other ``mode``. Class indices are always derived from the sorted names of
    the directories below ``<data_path>/train`` so that both splits share one
    label mapping.

    Args:
        data_path: Root directory holding the ``train`` and ``test`` folders.
        mode: ``'train'`` selects the training split; anything else selects
            the test split.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.

    Raises:
        KeyError: If the selected split contains a class directory that does
            not exist under ``train``.
    """

    def __init__(self, data_path, mode, transform=None):
        self.path = data_path
        data_dir = pathlib.Path(data_path)
        self.mode = mode
        self.transform = transform

        split = 'train' if self.mode == 'train' else 'test'
        # Sorted so the sample order is reproducible; glob order is
        # filesystem dependent.
        self.image_path = sorted(str(path) for path in data_dir.glob(split + "/*/*"))

        # Only directories are classes: a stray file next to the class
        # folders must not shift the label indices.
        lable_names = sorted(
            item.name for item in data_dir.glob("train/*") if item.is_dir()
        )
        lable_to_index = {name: index for index, name in enumerate(lable_names)}
        self.image_label = [
            lable_to_index[pathlib.Path(path).parent.name]
            for path in self.image_path
        ]

    def __getitem__(self, index):
        """Return ``(image, image_path, label)`` for the sample at ``index``."""
        image_path = self.image_path[index]
        # The stored path already starts with data_path, so it is opened
        # as-is. Joining it onto self.path again doubled the prefix whenever
        # data_path was relative.
        with Image.open(image_path) as handle:
            # convert() loads the pixels and returns an independent image, so
            # the file handle can be released right away.
            img = handle.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        # NOTE(review): the label has shape (1,), so default collation gives
        # (batch, 1) targets - confirm the training loop squeezes this before
        # passing it to CrossEntropyLoss or comparing it with predictions.
        label = torch.LongTensor([self.image_label[index]])
        return img, image_path, label

    def __len__(self):
        """Return the number of images found for the selected split."""
        return len(self.image_path)
def get_dataset(config): def get_dataset(config):
if config.dataset == 'gender_dataset': if config.dataset == 'gender_dataset':
@ -26,6 +58,23 @@ def get_dataset(config):
train_dataset = datasets.ImageFolder(os.path.join(path, 'train'), transforms_train) train_dataset = datasets.ImageFolder(os.path.join(path, 'train'), transforms_train)
test_dataset = datasets.ImageFolder(os.path.join(path, 'test'), transforms_test) test_dataset = datasets.ImageFolder(os.path.join(path, 'test'), transforms_test)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.optim.batch_size, shuffle=True, num_workers=config.optim.num_workers) train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.classifier.batch_size, shuffle=True, num_workers=config.optim.num_workers)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=config.optim.batch_size, shuffle=False, num_workers=config.optim.num_workers) test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=config.classifier.batch_size, shuffle=False, num_workers=config.optim.num_workers)
return train_dataloader,test_dataloader,train_dataset,test_dataset return train_dataloader,test_dataloader,train_dataset,test_dataset
def get_adv_dataset(config):
    """Build path-aware train/test datasets and their data loaders.

    The dataset root is ``config.paths.gender_dataset`` when
    ``config.dataset == 'gender_dataset'`` and
    ``config.paths.identity_dataset`` otherwise. Batch size and worker count
    are read from ``config.optim``.

    Args:
        config: Configuration object exposing ``dataset``, ``paths`` and
            ``optim.batch_size`` / ``optim.num_workers``.

    Returns:
        Tuple ``(train_dataloader, test_dataloader, train_dataset,
        test_dataset)``; every batch yields ``(images, image_paths, labels)``.
    """
    if config.dataset == 'gender_dataset':
        path = config.paths.gender_dataset
    else:
        path = config.paths.identity_dataset

    train_dataset = ImageDataset(path, 'train', transforms_train)
    test_dataset = ImageDataset(path, 'test', transforms_test)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.optim.batch_size,
        shuffle=True,
        num_workers=config.optim.num_workers,
    )
    # The test loader must iterate the test split; it previously wrapped
    # train_dataset. Shuffling is off, matching get_dataset's test loader.
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.optim.batch_size,
        shuffle=False,
        num_workers=config.optim.num_workers,
    )
    return train_dataloader, test_dataloader, train_dataset, test_dataset

View File

@ -2,8 +2,10 @@ import torch
import torchvision import torchvision
from torchvision import datasets, models, transforms from torchvision import datasets, models, transforms
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
import os import os
import clip import clip
from utils import normalize
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@ -45,4 +47,39 @@ class GanAttack(nn.Module):
im,_=self.generator(x,input_is_latent=True, randomize_noise=False) im,_=self.generator(x,input_is_latent=True, randomize_noise=False)
return img,refine_images,im,x return refine_images,im,x
class CLIPLoss(torch.nn.Module):
    """Image/text mismatch score computed with CLIP ViT-B/32.

    ``forward`` returns ``1 - logits / 100`` using the first output of the
    CLIP model, so a lower value means the image matches the text better.
    """

    def __init__(self):
        super().__init__()
        # Load onto the module-level device (CUDA when available, else CPU)
        # instead of a hard-coded "cuda", which fails on CPU-only machines.
        self.model, self.preprocess = clip.load("ViT-B/32", device=device)
        self.model.eval()
        # Resamples any input resolution to the 224x224 grid fed to CLIP.
        self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))

    def forward(self, image, text):
        """Return ``1 - CLIP(image, text)[0] / 100``.

        Args:
            image: Image batch; it is passed through the project's
                ``normalize`` helper and then pooled to 224x224.
            text: Text input forwarded unchanged to the CLIP model
                (presumably token ids from ``clip.tokenize`` - confirm
                against ``get_prompt``).

        Returns:
            Tensor with the same shape as the CLIP model's first output.
            NOTE(review): this is not reduced to a scalar here; confirm the
            caller reduces it before calling ``backward``.
        """
        image = normalize(image)
        image = self.face_pool(image)
        similarity = 1 - self.model(image, text)[0] / 100
        return similarity
class VggLoss(torch.nn.Module):
    """Cosine similarity between two image batches in VGG-11 feature space.

    The pretrained VGG-11 convolutional stack (plus its average pool) is used
    as a fixed feature extractor; the classifier head is removed so the
    network outputs the flattened feature map instead of class logits.
    """

    def __init__(self):
        super(VggLoss, self).__init__()
        self.model = models.vgg11(pretrained=True)
        # Drop the classifier, not the conv features: with `features` emptied
        # the raw image reached the classifier's first Linear layer and the
        # forward pass failed on a shape mismatch.
        self.model.classifier = nn.Sequential()
        # Freeze the extractor. Gradients still flow through it to the input
        # images, which torch.no_grad() in forward would have prevented.
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad_(False)

    def forward(self, image1, image2):
        """Return the cosine similarity of the two batches' VGG features.

        Args:
            image1: First image batch.
            image2: Second image batch, same shape as ``image1``.

        Returns:
            0-dim tensor: each batch's features are flattened into a single
            vector and compared along that one dimension.
            NOTE(review): this is a similarity (higher = more alike), not a
            distance; confirm the caller does not want ``1 - similarity``
            before adding it to a loss that is minimised.
        """
        feature1 = torch.flatten(self.model(image1))
        feature2 = torch.flatten(self.model(image2))
        # The flattened tensors are 1-D, so the default dim=1 is out of range.
        similarity = F.cosine_similarity(feature1, feature2, dim=0)
        return similarity