# GanAttack/GanAttack.py

import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from omegaconf import DictConfig, OmegaConf
from data.dataset import get_dataset,get_adv_dataset
from utils import get_model,set_requires_grad,unnormalize
# from GanInverter.inference.two_stage_inference import TwoStageInference
# from GanInverter.models.stylegan2.model import Generator
# from models import GanAttack
# from pixel2style2pixel.scripts.align_all_parallel import align_face
# from pixel2style2pixel.models.stylegan2.model import Generator
from model import GanAttack,CLIPLoss,VggLoss
from prompt import get_prompt
import torch.nn.functional as F
import time
# from torchsummary import summary
import hydra
# Prefer the first CUDA device when available; everything below is moved here.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# def get_stylegan_generator(cfg):
# # ensure_checkpoint_exists(ckpt_path)
# ckpt_path=cfg.paths.stylegan
# g_ema = Generator(1024, 512, 8)
# g_ema.load_state_dict(torch.load(ckpt_path)["g_ema"], strict=False)
# g_ema.eval()
# g_ema = g_ema.cuda()
# mean_latent = g_ema.mean_latent(4096)
# return g_ema, mean_latent
@hydra.main(version_base=None, config_path="./config", config_name="config")
def main(cfg: DictConfig) -> None:
    """Train the GanAttack MLP to produce adversarial StyleGAN latent codes.

    The attack model perturbs latent codes so the generated image (a) fools a
    frozen, pretrained classifier, (b) stays close (L1) to the original base
    code, and (c) matches a CLIP text prompt. Only ``model.mlp`` is optimized.
    A checkpoint is written after each epoch's evaluation pass (the file is
    overwritten, so it ends up holding the final epoch's weights).

    Args:
        cfg: Hydra config with ``paths``, ``classifier``, ``optim``, ``dataset``
            and ``prompt`` sections.
    """
    train_dataloader, test_dataloader, train_dataset, test_dataset = get_dataset(cfg)

    # Frozen victim classifier; weights come from a prior training run.
    classifier = get_model(cfg)
    classifier.load_state_dict(torch.load(
        '{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset)))
    classifier.eval()

    prompt = get_prompt(cfg)
    model = GanAttack(cfg, prompt).to(device)

    num_epochs = cfg.optim.num_epochs
    criterion = nn.CrossEntropyLoss()
    # Ranking loss with margin 0.1 between clean and adversarial CE losses.
    max_loss = nn.MarginRankingLoss(0.1)
    clip_loss = CLIPLoss().to(device)

    # Only the attack MLP is trainable; the rest of the pipeline stays frozen.
    for p in model.mlp.parameters():
        p.requires_grad = True
    optimizer = optim.SGD(model.mlp.parameters(),
                          lr=cfg.classifier.lr, momentum=cfg.classifier.momentum)

    start_time = time.time()
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.
        running_corrects = 0
        for inputs, labels, base_code, detail_code in train_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            base_code = base_code.to(device)
            detail_code = detail_code.to(device)

            generated_img, adv_latent_codes = model(inputs, base_code, detail_code)
            optimizer.zero_grad()

            # Keep the adversarial code close to the original base code.
            loss_l1 = F.l1_loss(base_code, adv_latent_codes)
            # Push the generated image toward the text prompt in CLIP space.
            loss_clip = clip_loss(generated_img, prompt)

            adv_outputs = classifier(generated_img)
            clean_outputs = classifier(inputs)
            adv_preds = criterion(adv_outputs, labels)
            clean_preds = criterion(clean_outputs, labels)
            # BUG FIX: the original passed torch.ones_like(preds) where `preds`
            # was undefined (its defining line was commented out) -> NameError.
            loss_classifier = max_loss(clean_preds, adv_preds,
                                       torch.ones_like(adv_preds))

            loss = (loss_l1
                    + cfg.optim.beta * loss_clip
                    + cfg.optim.delta * loss_classifier)
            loss.backward()
            optimizer.step()

            # BUG FIX: `preds` was never computed in the train loop. Track the
            # classifier's accuracy on adversarial images (lower = stronger attack).
            _, preds = torch.max(adv_outputs, 1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / len(train_dataset)
        epoch_acc = running_corrects / len(train_dataset) * 100.
        print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(
            epoch, epoch_loss, epoch_acc, time.time() - start_time))

        model.eval()
        print('Evaluating!')
        with torch.no_grad():
            running_corrects = 0
            for inputs, labels, base_code, detail_code in test_dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                base_code = base_code.to(device)
                detail_code = detail_code.to(device)
                generated_img, adv_latent_codes = model(inputs, base_code, detail_code)
                outputs = classifier(generated_img)
                # BUG FIX: the original compared a CrossEntropyLoss scalar to
                # the labels; compare argmax predictions instead.
                _, preds = torch.max(outputs, 1)
                running_corrects += torch.sum(preds == labels.data)
            epoch_acc = running_corrects / len(test_dataset) * 100.
            print('[Test ] Acc: {:.4f}% Time: {:.4f}s'.format(
                epoch_acc, time.time() - start_time))

        # Overwritten each epoch; final file holds the last epoch's weights.
        save_path = '{}/stylegan_{}_{}_{}.pth'.format(
            cfg.paths.pretrained_models, cfg.classifier.model, cfg.dataset, cfg.prompt)
        torch.save(model.state_dict(), save_path)
# Script entry point: hydra parses CLI overrides and injects the config.
if __name__ == "__main__":
    main()