# GanAttack/targetedAttack.py
#
# 148 lines
# 5.1 KiB
# Python

import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from omegaconf import DictConfig, OmegaConf
from data.dataset import get_dataset,get_adv_dataset
from utils import get_model,set_requires_grad,unnormalize
from stylegan.model import Generator
from model import CLIPLoss,TargetedGanAttack
import lpips
from prompt import get_prompt
import torch.nn.functional as F
import time
import hydra
# Compute device shared by the whole script: the first CUDA GPU when PyTorch
# reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def get_stylegan_generator(cfg, *, size=1024, style_dim=512, n_mlp=8):
    """Build a StyleGAN generator, load its ``g_ema`` weights and return it.

    Args:
        cfg: Configuration object; ``cfg.paths.stylegan`` is the path of the
            checkpoint file to load.
        size: First positional argument forwarded to ``Generator``
            (presumably the output resolution -- confirm against
            ``stylegan.model``).
        style_dim: Second positional argument forwarded to ``Generator``
            (presumably the style vector width).
        n_mlp: Third positional argument forwarded to ``Generator``
            (presumably the mapping-network depth).

    Returns:
        The generator, moved to the module-level ``device`` and switched to
        evaluation mode.

    Raises:
        KeyError: If the checkpoint does not contain a ``'g_ema'`` entry.
    """
    # map_location lets a checkpoint saved on a GPU be opened on a CPU-only
    # machine (and vice versa).
    checkpoint = torch.load(cfg.paths.stylegan, map_location=device)
    if 'g_ema' not in checkpoint:
        raise KeyError(
            "checkpoint '{}' has no 'g_ema' entry".format(cfg.paths.stylegan)
        )

    generator = Generator(size, style_dim, n_mlp)
    generator.load_state_dict(checkpoint['g_ema'])
    generator.to(device)
    generator.eval()
    return generator
@hydra.main(version_base=None, config_path="./config", config_name="config")
def main(cfg: DictConfig) -> None:
    """Train the ``TargetedGanAttack`` model and evaluate it after every epoch.

    The attack model turns each input image (plus its ``detail_code``) into a
    generated image. It is optimised with the sum of three terms:

    * an LPIPS (VGG) perceptual distance between input and generated image,
    * a CLIP loss between the generated image and the configured prompt,
      weighted by ``cfg.optim.beta``,
    * a margin ranking loss on the classifier's cross-entropy, weighted by
      ``cfg.optim.delta``, which asks the generated image's cross-entropy on
      ``labels`` to be at least 0.1 lower than the clean image's.
      NOTE(review): this reads as "labels are the attack's target classes";
      confirm against ``get_adv_dataset``.

    The reported accuracy is the percentage of generated images that the
    classifier assigns to ``labels``. The attack model's weights are written
    to ``cfg.paths.pretrained_models`` at the end of every epoch.

    Args:
        cfg: Hydra configuration. Reads ``paths.classifier``,
            ``paths.pretrained_models``, ``classifier.model``,
            ``classifier.lr``, ``classifier.momentum``, ``dataset``,
            ``prompt``, ``optim.num_epochs``, ``optim.beta`` and
            ``optim.delta``.
    """
    train_dataloader, test_dataloader, train_dataset, test_dataset = get_adv_dataset(cfg)

    # Classifier under attack: loaded once, kept in eval mode and frozen.
    # Gradients still flow *through* it to the generated image, but no
    # gradient buffers are allocated for its own weights.
    classifier = get_model(cfg)
    classifier_path = '{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset)
    classifier.load_state_dict(torch.load(classifier_path, map_location=device))
    classifier.to(device)
    classifier.eval()
    for param in classifier.parameters():
        param.requires_grad_(False)

    prompt = get_prompt(cfg)
    model = TargetedGanAttack(cfg, prompt).to(device)

    num_epochs = cfg.optim.num_epochs
    criterion = nn.CrossEntropyLoss()
    max_loss = nn.MarginRankingLoss(margin=0.1)
    clip_loss = CLIPLoss().to(device)
    # NOTE(review): LPIPS expects images scaled to [-1, 1] unless it is called
    # with normalize=True -- confirm the range of `inputs`/`generated_img`.
    loss_fn_vgg = lpips.LPIPS(net='vgg').to(device)

    optimizer = optim.SGD(model.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum)

    save_path = '{}/stylegan_{}_{}_{}.pth'.format(cfg.paths.pretrained_models, cfg.classifier.model, cfg.dataset, cfg.prompt)
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)

    # Guard the averages against an empty dataset.
    num_train = max(len(train_dataset), 1)
    num_test = max(len(test_dataset), 1)

    start_time = time.time()
    for epoch in range(num_epochs):
        # ---- training ----
        model.train()
        running_loss = 0.0
        running_corrects = 0
        for inputs, labels, detail_code in train_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            detail_code = detail_code.to(device)

            optimizer.zero_grad()
            generated_img, _ = model(inputs, detail_code)

            # LPIPS yields one distance per sample; reduce it to a scalar so
            # the total loss can be back-propagated for any batch size.
            loss_vgg = loss_fn_vgg(inputs, generated_img).mean()
            # Assumes CLIPLoss already returns a scalar -- TODO confirm.
            loss_clip = clip_loss(generated_img, prompt)

            adv_outputs = classifier(generated_img)
            with torch.no_grad():
                # The clean branch is only a reference value for the ranking
                # loss; nothing is optimised through it.
                clean_outputs = classifier(inputs)
            adv_loss = criterion(adv_outputs, labels)
            clean_loss = criterion(clean_outputs, labels)
            # Target +1: rank clean_loss above adv_loss by at least the margin.
            loss_classifier = max_loss(clean_loss, adv_loss, torch.ones_like(adv_loss))

            loss = loss_vgg + cfg.optim.beta * loss_clip + cfg.optim.delta * loss_classifier
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            preds = adv_outputs.argmax(dim=1)
            running_corrects += (preds == labels).sum().item()

        epoch_loss = running_loss / num_train
        epoch_acc = running_corrects / num_train * 100.
        print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))

        # ---- evaluation ----
        model.eval()
        print('Evaluating!')
        running_corrects = 0
        with torch.no_grad():
            for inputs, labels, detail_code in test_dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                detail_code = detail_code.to(device)

                generated_img, _ = model(inputs, detail_code)
                preds = classifier(generated_img).argmax(dim=1)
                running_corrects += (preds == labels).sum().item()

        epoch_acc = running_corrects / num_test * 100.
        print('[Test ] Acc: {:.4f}% Time: {:.4f}s'.format( epoch_acc, time.time() - start_time))

        # Overwritten every epoch so an interrupted run keeps its latest weights.
        torch.save(model.state_dict(), save_path)
# Script entry point: run the training routine only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()