import sys
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
from omegaconf import DictConfig, OmegaConf
import hydra
import lpips

from data.dataset import get_dataset, get_adv_dataset
from utils import get_model, set_requires_grad, unnormalize
from model import UnTargetedGanAttack, CLIPLoss
from prompt import get_prompt

# Single shared device for the whole script.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# def get_stylegan_generator(cfg):
#     # ensure_checkpoint_exists(ckpt_path)
#     ckpt_path = cfg.paths.stylegan
#     g_ema = Generator(1024, 512, 8)
#     g_ema.load_state_dict(torch.load(ckpt_path)["g_ema"], strict=False)
#     g_ema.eval()
#     g_ema = g_ema.cuda()
#     mean_latent = g_ema.mean_latent(4096)
#     return g_ema, mean_latent


def cw_loss(outputs, labels):
    """Carlini-Wagner style margin loss for an untargeted attack.

    Args:
        outputs: ``(N, C)`` tensor of class logits.
        labels:  ``(N,)`` tensor of ground-truth class indices.

    Returns:
        ``(N,)`` per-sample tensor ``max(logit_true - max_other_logit, 0)``.
        The loss is zero once a sample is already misclassified, so
        minimizing it drives the true-class logit below the runner-up.
    """
    one_hot_labels = torch.eye(outputs.shape[1]).to(device)[labels]
    # Highest logit among all classes other than the target class.
    # NOTE(review): masking by multiplication zeroes the true-class slot,
    # so if every non-true logit is negative this max is 0, not the real
    # runner-up logit — presumably acceptable for this attack; confirm.
    other = torch.max((1 - one_hot_labels) * outputs, dim=1)[0]
    # Logit of the true class.
    real = torch.max(one_hot_labels * outputs, dim=1)[0]
    # Original used ``min=-0.`` — numerically identical to 0.0; written
    # as 0.0 for clarity.
    return torch.clamp(real - other, min=0.0)
@hydra.main(version_base=None, config_path="./config", config_name="config")
def main(cfg: DictConfig) -> None:
    """Train the untargeted GAN-based attack model.

    Loads a frozen, pretrained victim classifier, then optimizes the
    attack model's MLP heads so that images it generates (a) stay
    perceptually close to the inputs (LPIPS), (b) match the text prompt
    (CLIP), and (c) fool the classifier (CW margin loss). Evaluates the
    classifier's accuracy on the adversarial images each epoch (lower
    accuracy = stronger attack) and saves the attack model at the end.
    """
    train_dataloader, test_dataloader, train_dataset, test_dataset = get_adv_dataset(cfg)

    # Frozen, pretrained victim classifier (eval mode; never optimized here).
    classifier = get_model(cfg)
    classifier.load_state_dict(torch.load('{}/{}_{}.pth'.format(
        cfg.paths.classifier, cfg.classifier.model, cfg.dataset)))
    classifier.eval()

    # g_ema, _ = get_stylegan_generator(cfg)
    prompt = get_prompt(cfg)
    model = UnTargetedGanAttack(cfg, prompt).to(device)

    num_epochs = cfg.optim.num_epochs
    criterion = nn.CrossEntropyLoss()  # NOTE(review): unused after bugfix below; kept for config parity
    clip_loss = CLIPLoss().to(device)
    loss_fn_vgg = lpips.LPIPS(net='vgg').to(device)

    # BUGFIX: generators cannot be concatenated with `+` (TypeError);
    # materialize each .parameters() generator into a list first.
    trainable_params = list(model.mlp.parameters()) + list(model.noise_mlp.parameters())
    optimizer = optim.SGD(trainable_params,
                          lr=cfg.classifier.lr, momentum=cfg.classifier.momentum)

    start_time = time.time()
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.
        running_corrects = 0
        for i, (inputs, labels, detail_code) in enumerate(train_dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            detail_code = detail_code.to(device)

            generated_img, adv_latent_codes = model(inputs, detail_code)

            optimizer.zero_grad()
            # BUGFIX: .mean() reduces LPIPS's (N,1,1,1) and cw_loss's (N,)
            # outputs to scalars so `loss` is a valid backward() target.
            loss_vgg = loss_fn_vgg(inputs, generated_img).mean()
            loss_clip = clip_loss(generated_img, prompt)
            adv_outputs = classifier(generated_img)
            loss_classifier = cw_loss(adv_outputs, labels).mean()
            loss = loss_vgg + cfg.optim.beta * loss_clip + cfg.optim.delta * loss_classifier
            loss.backward()
            optimizer.step()

            # BUGFIX: `preds` was never computed in the original train loop
            # (NameError); predictions are the argmax of the victim's logits.
            preds = adv_outputs.argmax(dim=1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / len(train_dataset)
        # Classifier accuracy on adversarial images: lower means a stronger attack.
        epoch_acc = running_corrects / len(train_dataset) * 100.
        print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(
            epoch, epoch_loss, epoch_acc, time.time() - start_time))

        model.eval()
        print('Evaluating!')
        with torch.no_grad():
            running_corrects = 0
            for i, (inputs, labels, detail_code) in enumerate(test_dataloader):
                inputs = inputs.to(device)
                labels = labels.to(device)
                detail_code = detail_code.to(device)
                generated_img, adv_latent_codes = model(inputs, detail_code)
                outputs = classifier(generated_img)
                # BUGFIX: original assigned CrossEntropyLoss(outputs, labels)
                # to `preds`, then compared that loss value against labels.
                # Predictions are the argmax of the logits.
                preds = outputs.argmax(dim=1)
                running_corrects += torch.sum(preds == labels.data)
            epoch_acc = running_corrects / len(test_dataset) * 100.
            print('[Test ] Acc: {:.4f}% Time: {:.4f}s'.format(
                epoch_acc, time.time() - start_time))

    save_path = '{}/stylegan_{}_{}_{}.pth'.format(
        cfg.paths.pretrained_models, cfg.classifier.model, cfg.dataset, cfg.prompt)
    torch.save(model.state_dict(), save_path)


if __name__ == "__main__":
    main()