# GanAttack/targetedAttack.py
import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from omegaconf import DictConfig, OmegaConf
from data.dataset import get_dataset,get_adv_dataset
from utils import get_model,set_requires_grad,unnormalize
from stylegan.model import Generator
from model import CLIPLoss,TargetedGanAttack
import lpips
from prompt import get_prompt
import torch.nn.functional as F
import time
import hydra
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def generate_labels(labels, NUM_CLASSES=307):
    """Draw a random target class for each sample, different from its true label.

    Args:
        labels: 1-D tensor of ground-truth class indices.
        NUM_CLASSES: number of classes; targets are drawn from [0, NUM_CLASSES).

    Returns:
        Tensor of the same shape/dtype/device as ``labels`` holding the
        randomly chosen (mismatched) target classes.
    """
    targets = torch.zeros_like(labels)
    for i in range(len(labels)):
        # torch.randint requires an explicit size argument; the original call
        # omitted it and raised a TypeError. Draw a single scalar per sample
        # and re-draw until it differs from the true label.
        rand_v = torch.randint(0, NUM_CLASSES, (1,)).item()
        while labels[i] == rand_v:
            rand_v = torch.randint(0, NUM_CLASSES, (1,)).item()
        targets[i] = rand_v
    return targets
def cw_loss(outputs, labels):
    """Targeted Carlini-Wagner style loss: per-sample ``max(0, max_{j!=t} z_j - z_t)``.

    The loss is zero once the target class ``t`` holds the highest logit and
    positive otherwise, so minimizing it pushes the classifier toward the
    target class.

    Args:
        outputs: (N, C) logit tensor.
        labels: (N,) target class indices.

    Returns:
        Per-sample loss tensor of shape (N,); reduce (e.g. ``.mean()``) before
        calling ``backward()``.
    """
    # Build the one-hot mask on the logits' own device rather than the
    # module-level `device` global, so the function also works on CPU inputs.
    one_hot_labels = torch.eye(outputs.shape[1], device=outputs.device)[labels]
    # Highest logit among the non-target classes (target entry masked to 0).
    other = torch.max((1 - one_hot_labels) * outputs, dim=1)[0]
    # Logit of the target class.
    real = torch.max(one_hot_labels * outputs, dim=1)[0]
    # `min=-0.` in the original is numerically identical to 0.0.
    return torch.clamp(other - real, min=0.0)
def get_stylegan_generator(cfg):
    """Construct a StyleGAN generator and load the EMA weights named in cfg.

    Reads the checkpoint at ``cfg.paths.stylegan``, restores its ``'g_ema'``
    state dict, and returns the generator on the global ``device`` in eval mode.
    """
    netG = Generator(1024, 512, 8)
    state = torch.load(cfg.paths.stylegan, map_location=device)
    netG.load_state_dict(state['g_ema'])
    netG.to(device)
    netG.eval()
    return netG
@hydra.main(version_base=None, config_path="./config", config_name="config")
def main(cfg: DictConfig) -> None:
    """Train a targeted GAN-based adversarial attack model.

    Loads a frozen pretrained classifier, a StyleGAN generator, and the
    attack model; optimizes the attack so that generated images fool the
    classifier into random target classes while staying perceptually close
    (LPIPS) and semantically aligned (CLIP) to the originals. Saves the
    trained attack model's weights at the end.
    """
    train_dataloader, test_dataloader, train_dataset, test_dataset = get_adv_dataset(cfg)

    # Frozen classifier under attack.
    classifier = get_model(cfg)
    classifier.load_state_dict(torch.load('{}/{}_{}.pth'.format(
        cfg.paths.classifier, cfg.classifier.model, cfg.dataset)))
    classifier.eval()

    # get_stylegan_generator returns a single generator; the original
    # `g_ema, _ = ...` unpacking raised a ValueError.
    g_ema = get_stylegan_generator(cfg)
    prompt = get_prompt(cfg)
    model = TargetedGanAttack(cfg, prompt).to(device)

    num_epochs = cfg.optim.num_epochs
    criterion = nn.CrossEntropyLoss()
    clip_loss = CLIPLoss().to(device)
    loss_fn_vgg = lpips.LPIPS(net='vgg').to(device)
    optimizer = optim.SGD(model.parameters(), lr=cfg.classifier.lr,
                          momentum=cfg.classifier.momentum)
    start_time = time.time()

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0
        running_corrects = 0
        for i, (inputs, labels, detail_code) in enumerate(train_dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            # Random mismatched target class per sample (targeted attack).
            target = generate_labels(labels)
            detail_code = detail_code.to(device)

            generated_img, adv_latent_codes = model(inputs, detail_code, target)
            optimizer.zero_grad()

            # LPIPS returns a per-sample (N,1,1,1) tensor and cw_loss a
            # per-sample (N,) tensor; reduce both so `loss` is a scalar —
            # the original summed unreduced tensors and backward() failed.
            loss_vgg = loss_fn_vgg(inputs, generated_img).mean()
            loss_clip = clip_loss(generated_img, prompt)
            adv_outputs = classifier(generated_img)
            loss_classifier = cw_loss(adv_outputs, target).mean()

            loss = loss_vgg + cfg.optim.beta * loss_clip + cfg.optim.delta * loss_classifier
            loss.backward()
            optimizer.step()

            # `preds` was never computed in the original train loop
            # (NameError); use the classifier's argmax on the adversarial
            # images, compared against the true labels.
            _, preds = torch.max(adv_outputs, 1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / len(train_dataset)
        epoch_acc = running_corrects / len(train_dataset) * 100.
        print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(
            epoch, epoch_loss, epoch_acc, time.time() - start_time))

    model.eval()
    print('Evaluating!')
    with torch.no_grad():
        running_corrects = 0
        for i, (inputs, labels, detail_code) in enumerate(test_dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            detail_code = detail_code.to(device)
            generated_img, adv_latent_codes = model(inputs, detail_code)
            outputs = classifier(generated_img)
            # Original compared a cross-entropy loss value to the labels;
            # accuracy needs the argmax predictions instead.
            _, preds = torch.max(outputs, 1)
            running_corrects += torch.sum(preds == labels.data)
        epoch_acc = running_corrects / len(test_dataset) * 100.
        print('[Test ] Acc: {:.4f}% Time: {:.4f}s'.format(
            epoch_acc, time.time() - start_time))

    save_path = '{}/stylegan_{}_{}_{}.pth'.format(
        cfg.paths.pretrained_models, cfg.classifier.model, cfg.dataset, cfg.prompt)
    torch.save(model.state_dict(), save_path)


if __name__ == "__main__":
    main()