# GanAttack/untargetedAttack.py
# (file-listing metadata: 142 lines, 4.4 KiB, Python — kept as a comment so the module parses)

import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from omegaconf import DictConfig, OmegaConf
from data.dataset import get_dataset,get_adv_dataset
from utils import get_model,set_requires_grad,unnormalize
from model import UnTargetedGanAttack,CLIPLoss
import lpips
from prompt import get_prompt
import torch.nn.functional as F
import time
import hydra
# Run on the first CUDA GPU when one is available, otherwise fall back to CPU.
# Module-level so both the attack model and the frozen classifier share it.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): a commented-out StyleGAN generator loader previously lived here;
# it referenced a `Generator` class that is not imported in this file, so it was
# dead code and has been removed.
def cw_loss(outputs, labels):
    """Per-sample untargeted Carlini-Wagner (CW) loss.

    Args:
        outputs: ``(N, C)`` classifier logits.
        labels: ``(N,)`` ground-truth class indices.

    Returns:
        ``(N,)`` tensor ``max(logit_true - max_other_logit, 0)``: zero once the
        sample is already misclassified, positive while the true class still
        has the largest logit.
    """
    # Logit of the true class for each sample.
    real = outputs.gather(1, labels.unsqueeze(1)).squeeze(1)
    # BUG FIX: the original computed `other` as max((1 - onehot) * outputs),
    # which merely zeroes the true-class entry. When every other logit is
    # negative, that zero wins the max and the loss is wrong. Masking the true
    # class with -inf excludes it from the max regardless of sign.
    true_mask = F.one_hot(labels, num_classes=outputs.shape[1]).bool()
    other = outputs.masked_fill(true_mask, float("-inf")).max(dim=1)[0]
    # min=-0. in the original is numerically identical to 0.0.
    return torch.clamp(real - other, min=0.0)
@hydra.main(version_base=None, config_path="./config", config_name="config")
def main(cfg: DictConfig) -> None:
    """Train the untargeted GAN attack model and save its weights.

    Loads a frozen pretrained classifier, then optimizes the attack model's
    MLP heads so generated images (a) stay perceptually close to the inputs
    (LPIPS), (b) match the text prompt (CLIP), and (c) fool the classifier
    (CW loss). Reports train accuracy per epoch and attack accuracy on the
    test split, then saves the attack model's state dict.
    """
    train_dataloader, test_dataloader, train_dataset, test_dataset = get_adv_dataset(cfg)

    # Frozen classifier the attack tries to fool (weights loaded from disk).
    classifier = get_model(cfg)
    classifier.load_state_dict(
        torch.load('{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset),
                   map_location=device))
    classifier.eval()

    prompt = get_prompt(cfg)
    model = UnTargetedGanAttack(cfg, prompt).to(device)

    num_epochs = cfg.optim.num_epochs
    clip_loss = CLIPLoss().to(device)
    loss_fn_vgg = lpips.LPIPS(net='vgg').to(device)
    # BUG FIX: generators returned by .parameters() cannot be concatenated
    # with `+`; materialize both into lists first.
    optimizer = optim.SGD(
        list(model.mlp.parameters()) + list(model.noise_mlp.parameters()),
        lr=cfg.classifier.lr, momentum=cfg.classifier.momentum)

    start_time = time.time()
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_corrects = 0
        for i, (inputs, labels, detail_code) in enumerate(train_dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            detail_code = detail_code.to(device)

            generated_img, adv_latent_codes = model(inputs, detail_code)
            optimizer.zero_grad()
            # LPIPS and cw_loss are per-sample; reduce to scalars so the
            # combined loss is a scalar and backward() is well-defined.
            loss_vgg = loss_fn_vgg(inputs, generated_img).mean()
            loss_clip = clip_loss(generated_img, prompt)
            adv_outputs = classifier(generated_img)
            loss_classifier = cw_loss(adv_outputs, labels).mean()
            loss = loss_vgg + cfg.optim.beta * loss_clip + cfg.optim.delta * loss_classifier
            loss.backward()
            optimizer.step()

            # BUG FIX: `preds` was never defined in the original loop
            # (NameError); accuracy needs the argmax prediction.
            preds = adv_outputs.argmax(dim=1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / len(train_dataset)
        epoch_acc = running_corrects / len(train_dataset) * 100.
        print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))

        model.eval()
        print('Evaluating!')
        with torch.no_grad():
            running_corrects = 0
            for i, (inputs, labels, detail_code) in enumerate(test_dataloader):
                inputs = inputs.to(device)
                labels = labels.to(device)
                detail_code = detail_code.to(device)
                generated_img, adv_latent_codes = model(inputs, detail_code)
                outputs = classifier(generated_img)
                # BUG FIX: the original compared a CrossEntropy *loss* value
                # against the labels; accuracy needs argmax predictions.
                preds = outputs.argmax(dim=1)
                running_corrects += torch.sum(preds == labels.data)
            epoch_acc = running_corrects / len(test_dataset) * 100.
            print('[Test ] Acc: {:.4f}% Time: {:.4f}s'.format( epoch_acc, time.time() - start_time))

    save_path = '{}/stylegan_{}_{}_{}.pth'.format(cfg.paths.pretrained_models, cfg.classifier.model, cfg.dataset, cfg.prompt)
    torch.save(model.state_dict(), save_path)


if __name__ == "__main__":
    main()