import os
import sys
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import hydra
import lpips
from omegaconf import DictConfig, OmegaConf
from torchvision import models

from data.dataset import get_dataset, get_adv_dataset
from model import GanAttack, CLIPLoss
from prompt import get_prompt
from utils import get_model, set_requires_grad, unnormalize

# Single training device: first CUDA GPU when available, CPU otherwise.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
|
|
|
# def get_stylegan_generator(cfg):
|
|
|
|
# # ensure_checkpoint_exists(ckpt_path)
|
|
# ckpt_path=cfg.paths.stylegan
|
|
# g_ema = Generator(1024, 512, 8)
|
|
# g_ema.load_state_dict(torch.load(ckpt_path)["g_ema"], strict=False)
|
|
# g_ema.eval()
|
|
# g_ema = g_ema.cuda()
|
|
# mean_latent = g_ema.mean_latent(4096)
|
|
|
|
# return g_ema, mean_latent
|
|
|
|
|
|
|
|
@hydra.main(version_base=None, config_path="./config", config_name="config")
def main(cfg: DictConfig) -> None:
    """Train the GanAttack latent-mapping MLP to generate adversarial images.

    Loads a frozen, pre-trained victim classifier, then optimizes only the
    MLP inside ``GanAttack`` so the generated images (a) stay perceptually
    close to the inputs (LPIPS), (b) match a CLIP text prompt, and
    (c) raise the classifier's loss relative to the clean inputs
    (margin ranking). The trained model's state dict is saved at the end.

    Args:
        cfg: Hydra config; reads ``cfg.paths``, ``cfg.classifier``,
            ``cfg.optim``, ``cfg.dataset`` and ``cfg.prompt``.
    """
    train_dataloader, test_dataloader, train_dataset, test_dataset = get_adv_dataset(cfg)

    # Frozen victim classifier used only to score clean vs. adversarial images.
    classifier = get_model(cfg)
    classifier.load_state_dict(
        torch.load('{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset)))
    classifier.eval()

    prompt = get_prompt(cfg)
    model = GanAttack(cfg, prompt).to(device)

    num_epochs = cfg.optim.num_epochs
    criterion = nn.CrossEntropyLoss()
    max_loss = nn.MarginRankingLoss(0.1)
    clip_loss = CLIPLoss().to(device)
    loss_fn_vgg = lpips.LPIPS(net='vgg').to(device)

    # Only the latent-mapping MLP is optimized; the rest of GanAttack is frozen.
    for p in model.mlp.parameters():
        p.requires_grad = True
    optimizer = optim.SGD(model.mlp.parameters(), lr=cfg.classifier.lr,
                          momentum=cfg.classifier.momentum)
    start_time = time.time()

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_corrects = 0

        for inputs, labels, detail_code in train_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            detail_code = detail_code.to(device)

            optimizer.zero_grad()
            generated_img, adv_latent_codes = model(inputs, detail_code)

            # LPIPS returns a per-sample distance tensor; reduce to a scalar
            # so the combined loss is a scalar and backward() is valid.
            loss_vgg = loss_fn_vgg(inputs, generated_img).mean()
            loss_clip = clip_loss(generated_img, prompt)

            adv_outputs = classifier(generated_img)
            clean_outputs = classifier(inputs)
            adv_preds = criterion(adv_outputs, labels)
            clean_preds = criterion(clean_outputs, labels)
            # BUG FIX: the original passed torch.ones_like(preds) with `preds`
            # undefined (NameError). The ranking target must match the loss
            # values' shape; target=1 asks clean_preds > adv_preds by margin,
            # i.e. pushes the adversarial loss up relative to the clean loss.
            loss_classifier = max_loss(clean_preds, adv_preds, torch.ones_like(adv_preds))

            loss = loss_vgg + cfg.optim.beta * loss_clip + cfg.optim.delta * loss_classifier
            loss.backward()
            optimizer.step()

            # BUG FIX: `preds` was never computed in the original train loop
            # (the argmax line was commented out), so the accuracy update
            # crashed. Use class predictions on the adversarial images.
            _, preds = torch.max(adv_outputs, 1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / len(train_dataset)
        epoch_acc = running_corrects / len(train_dataset) * 100.
        print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(
            epoch, epoch_loss, epoch_acc, time.time() - start_time))

        model.eval()
        print('Evaluating!')
        with torch.no_grad():
            running_corrects = 0
            for inputs, labels, detail_code in test_dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                detail_code = detail_code.to(device)

                generated_img, adv_latent_codes = model(inputs, detail_code)
                outputs = classifier(generated_img)
                # BUG FIX: the original set `preds` to the CrossEntropyLoss
                # scalar and compared it to the labels, which is meaningless.
                # Accuracy needs the argmax class predictions.
                _, preds = torch.max(outputs, 1)
                running_corrects += torch.sum(preds == labels.data)

            epoch_acc = running_corrects / len(test_dataset) * 100.
            print('[Test ] Acc: {:.4f}% Time: {:.4f}s'.format(epoch_acc, time.time() - start_time))

    # Persist the attack model (MLP weights included) for later evaluation.
    save_path = '{}/stylegan_{}_{}_{}.pth'.format(
        cfg.paths.pretrained_models, cfg.classifier.model, cfg.dataset, cfg.prompt)
    torch.save(model.state_dict(), save_path)
# Script entry point: Hydra parses the CLI and injects the composed config.
if __name__ == "__main__":
    main()