"""Train the GanAttack latent mappers against a fixed, pre-trained classifier.

Run as a script; configuration is supplied by Hydra from ``./config/config``.
"""
import sys
import os
import time

import hydra
import lpips
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from omegaconf import DictConfig, OmegaConf
from torchvision import models

from data.dataset import get_dataset, get_adv_dataset
from model import GanAttack, CLIPLoss
from prompt import get_prompt
from utils import get_model, set_requires_grad, unnormalize

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE: disabled helper retained from the original file for reference.
# def get_stylegan_generator(cfg):
#     # ensure_checkpoint_exists(ckpt_path)
#     ckpt_path = cfg.paths.stylegan
#     g_ema = Generator(1024, 512, 8)
#     g_ema.load_state_dict(torch.load(ckpt_path)["g_ema"], strict=False)
#     g_ema.eval()
#     g_ema = g_ema.cuda()
#     mean_latent = g_ema.mean_latent(4096)
#     return g_ema, mean_latent


def _count_correct(logits, labels):
    """Return how many arg-max predictions in ``logits`` equal ``labels``."""
    return int((logits.argmax(dim=1) == labels).sum().item())


def _evaluate(model, classifier, dataloader):
    """Return the number of generated images the classifier labels correctly.

    Iterates ``dataloader`` (yielding ``(inputs, labels, detail_code)``), feeds
    each batch through ``model`` and classifies the generated image. Runs
    entirely under ``torch.no_grad()``.
    """
    correct = 0
    with torch.no_grad():
        for inputs, labels, detail_code in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            detail_code = detail_code.to(device)

            generated_img, _ = model(inputs, detail_code)
            correct += _count_correct(classifier(generated_img), labels)
    return correct


@hydra.main(version_base=None, config_path="./config", config_name="config")
def main(cfg: DictConfig) -> None:
    """Optimise ``model.mlp`` and ``model.noise_mlp`` of a ``GanAttack`` model.

    The loss is ``LPIPS(inputs, generated) + beta * CLIP(generated, prompt)
    + delta * margin_ranking(clean CE, adversarial CE)``. After every epoch the
    classifier's accuracy on generated test images is printed, and once all
    epochs are done the ``GanAttack`` state dict is saved.

    Args:
        cfg: Hydra configuration. Fields read here: ``paths.classifier``,
            ``paths.pretrained_models``, ``classifier.model``,
            ``classifier.lr``, ``classifier.momentum``, ``dataset``,
            ``prompt``, ``optim.num_epochs``, ``optim.beta``, ``optim.delta``.
    """
    train_dataloader, test_dataloader, train_dataset, test_dataset = get_adv_dataset(cfg)

    # The classifier is only a fixed judge: load it onto the active device,
    # switch to eval mode and freeze it so no gradients are kept for it.
    classifier = get_model(cfg)
    classifier_path = '{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset)
    classifier.load_state_dict(torch.load(classifier_path, map_location=device))
    classifier = classifier.to(device)
    classifier.eval()
    for param in classifier.parameters():
        param.requires_grad_(False)

    prompt = get_prompt(cfg)
    model = GanAttack(cfg, prompt).to(device)

    num_epochs = cfg.optim.num_epochs
    criterion = nn.CrossEntropyLoss()
    max_loss = nn.MarginRankingLoss(0.1)
    clip_loss = CLIPLoss().to(device)
    loss_fn_vgg = lpips.LPIPS(net='vgg').to(device)

    # ``parameters()`` returns generators, which cannot be added with ``+``;
    # materialise them so both sub-networks share one optimizer.
    trainable_params = list(model.mlp.parameters()) + list(model.noise_mlp.parameters())
    for param in trainable_params:
        param.requires_grad = True
    optimizer = optim.SGD(trainable_params, lr=cfg.classifier.lr, momentum=cfg.classifier.momentum)

    # Guard the averages below against an empty dataset.
    num_train = max(len(train_dataset), 1)
    num_test = max(len(test_dataset), 1)

    start_time = time.time()
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_corrects = 0
        for inputs, labels, detail_code in train_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            detail_code = detail_code.to(device)

            optimizer.zero_grad()
            generated_img, _ = model(inputs, detail_code)

            # LPIPS yields one distance per sample; reduce to a scalar so the
            # total loss supports ``backward()``/``item()`` for any batch size.
            loss_vgg = loss_fn_vgg(inputs, generated_img).mean()
            # ``.mean()`` is a no-op if CLIPLoss already returns a scalar.
            loss_clip = clip_loss(generated_img, prompt).mean()

            adv_outputs = classifier(generated_img)
            # The clean pass does not depend on the trained parameters.
            with torch.no_grad():
                clean_outputs = classifier(inputs)
            adv_ce = criterion(adv_outputs, labels)
            clean_ce = criterion(clean_outputs, labels)
            # NOTE(review): with target=1 this hinge is zero once the clean CE
            # exceeds the adversarial CE by the 0.1 margin; argument order is
            # kept from the original -- confirm it is the intended objective.
            loss_classifier = max_loss(clean_ce, adv_ce, torch.ones_like(adv_ce))

            loss = loss_vgg + cfg.optim.beta * loss_clip + cfg.optim.delta * loss_classifier
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += _count_correct(adv_outputs, labels)

        epoch_loss = running_loss / num_train
        epoch_acc = running_corrects / num_train * 100.
        print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))

        model.eval()
        print('Evaluating!')
        test_corrects = _evaluate(model, classifier, test_dataloader)
        epoch_acc = test_corrects / num_test * 100.
        print('[Test ] Acc: {:.4f}% Time: {:.4f}s'.format(epoch_acc, time.time() - start_time))

    save_path = '{}/stylegan_{}_{}_{}.pth'.format(cfg.paths.pretrained_models, cfg.classifier.model, cfg.dataset, cfg.prompt)
    torch.save(model.state_dict(), save_path)


if __name__ == "__main__":
    main()