# GanAttack/GanAttack.py

import sys
import os
sys.path.append('./pixel2style2pixel')
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from omegaconf import DictConfig, OmegaConf
from data.dataset import get_dataset,get_adv_dataset
from utils import get_model,set_requires_grad,unnormalize
# from GanInverter.inference.two_stage_inference import TwoStageInference
# from GanInverter.models.stylegan2.model import Generator
# from models import GanAttack
from argparse import Namespace
# from pixel2style2pixel.scripts.align_all_parallel import align_face
from pixel2style2pixel.models.psp import pSp
# from pixel2style2pixel.models.stylegan2.model import Generator
from model import GanAttack,CLIPLoss,VggLoss,get_prompt
import torch.nn.functional as F
import time
import hydra
# Run on the first GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# def get_stylegan_generator(cfg):
# # ensure_checkpoint_exists(ckpt_path)
# ckpt_path=cfg.paths.stylegan
# g_ema = Generator(1024, 512, 8)
# g_ema.load_state_dict(torch.load(ckpt_path)["g_ema"], strict=False)
# g_ema.eval()
# g_ema = g_ema.cuda()
# mean_latent = g_ema.mean_latent(4096)
# return g_ema, mean_latent
def get_stylegan_inverter(cfg):
    """Load a pretrained pSp (pixel2style2pixel) GAN inverter.

    Args:
        cfg: Hydra config; ``cfg.paths.inverter_cfg`` is the path to a pSp
            checkpoint that bundles its training options under ``'opts'``.

    Returns:
        A ``pSp`` network in eval mode, moved to the module-level ``device``.
    """
    path = cfg.paths.inverter_cfg
    # Map onto the active device instead of hard-coding 'cuda:0' so the
    # loader also works on CPU-only machines.
    ckpt = torch.load(path, map_location=device)
    opts = ckpt['opts']
    opts['checkpoint_path'] = path
    # Older checkpoints may predate these options; fill in safe defaults.
    opts.setdefault('learn_in_w', False)
    opts.setdefault('output_size', 1024)
    net = pSp(Namespace(**opts))
    net.eval()
    net.to(device)
    return net
@hydra.main(version_base=None, config_path="./config", config_name="config")
def main(cfg: DictConfig) -> None:
    """Train the GanAttack MLP to generate adversarial latent-space edits.

    Loads a frozen victim classifier and a frozen pSp inverter, then
    optimizes only ``model.mlp`` with a weighted sum of VGG perceptual,
    latent-L1, CLIP-prompt, and classifier losses. Evaluates the attack on
    the test set after every epoch and saves the final weights.

    NOTE(review): indentation was reconstructed from a whitespace-mangled
    source; evaluation is placed inside the epoch loop (its log line uses
    ``epoch``) and the checkpoint save after training — confirm against the
    original repository.
    """
    train_dataloader, test_dataloader, train_dataset, test_dataset = get_dataset(cfg)

    # Frozen victim classifier, used both in the attack loss and for eval.
    classifier = get_model(cfg)
    classifier.load_state_dict(torch.load('{}/{}_{}.pth'.format(
        cfg.paths.classifier, cfg.classifier.model, cfg.dataset)))
    classifier.eval()

    prompt = get_prompt(cfg.prompt)
    net = get_stylegan_inverter(cfg)
    model = GanAttack(net, prompt).to(device)

    num_epochs = cfg.optim.num_epochs
    criterion = nn.CrossEntropyLoss()
    max_loss = nn.MarginRankingLoss(0.1)
    clip_loss = CLIPLoss().to(device)
    vgg_loss = VggLoss().to(device)

    # Only the attack MLP is trainable; inverter and classifier stay frozen.
    set_requires_grad(model.mlp.parameters())
    optimizer = optim.SGD(model.mlp.parameters(), lr=cfg.classifier.lr,
                          momentum=cfg.classifier.momentum)

    start_time = time.time()
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.
        running_corrects = 0
        for inputs, labels in train_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            # Clean latent codes from the frozen encoder: the L1 anchor that
            # keeps adversarial codes close to the original image's codes.
            codes = net.encoder(inputs)

            optimizer.zero_grad()
            generated_img, adv_latent_codes = model(inputs)
            loss_vgg = vgg_loss(inputs, generated_img)
            loss_l1 = F.l1_loss(codes, adv_latent_codes)
            loss_clip = clip_loss(generated_img, prompt)

            # BUG FIX: the original loop referenced `outputs` without ever
            # assigning it (NameError on the first batch). Compute the
            # classifier logits once and reuse them for both the predictions
            # and the cross-entropy term.
            outputs = classifier(generated_img)
            _, preds = torch.max(outputs, 1)
            ce = criterion(outputs, labels)
            loss_classifier = max_loss(torch.ones_like(ce), ce, ce)

            loss = (loss_vgg
                    + cfg.optim.alpha * loss_l1
                    + cfg.optim.beta * loss_clip
                    + cfg.optim.delta * loss_classifier)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / len(train_dataset)
        epoch_acc = running_corrects / len(train_dataset) * 100.
        print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(
            epoch, epoch_loss, epoch_acc, time.time() - start_time))

        # Per-epoch evaluation: attack success is tracked via the victim
        # classifier's accuracy on the generated (attacked) images.
        model.eval()
        print('Evaluating!')
        with torch.no_grad():
            running_corrects = 0
            for inputs, labels in test_dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                generated_img, adv_latent_codes = model(inputs)
                outputs = classifier(generated_img)
                _, preds = torch.max(outputs, 1)
                running_corrects += torch.sum(preds == labels.data)
            epoch_acc = running_corrects / len(test_dataset) * 100.
            print('[Test #{}] Acc: {:.4f}% Time: {:.4f}s'.format(
                epoch, epoch_acc, time.time() - start_time))

    save_path = '{}/stylegan_{}_{}_{}.pth'.format(
        cfg.paths.pretrained_models, cfg.classifier.model, cfg.dataset, cfg.prompt)
    torch.save(model.state_dict(), save_path)


if __name__ == "__main__":
    main()