import sys
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
from omegaconf import DictConfig, OmegaConf
import hydra
import lpips

from data.dataset import get_dataset, get_adv_dataset
from utils import get_model, set_requires_grad, unnormalize
from stylegan.model import Generator
from model import CLIPLoss, TargetedGanAttack
from prompt import get_prompt

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def generate_labels(labels, NUM_CLASSES=307):
    """Draw a random attack-target class for every sample, guaranteed to
    differ from the true label.

    Args:
        labels: 1-D LongTensor of ground-truth class indices.
        NUM_CLASSES: total number of classes in the dataset.

    Returns:
        LongTensor with the same shape as ``labels``; ``targets[i] != labels[i]``.
    """
    targets = torch.zeros_like(labels)
    for i in range(len(labels)):
        # BUG FIX: torch.randint requires a `size` argument — the original
        # torch.randint(0, NUM_CLASSES-1) raised a TypeError.  The high
        # bound of randint is exclusive, so use NUM_CLASSES here; the
        # original NUM_CLASSES-1 could never select the last class.
        rand_v = torch.randint(0, NUM_CLASSES, (1,)).item()
        while labels[i] == rand_v:
            # Resample until the target differs from the true label.
            rand_v = torch.randint(0, NUM_CLASSES, (1,)).item()
        targets[i] = rand_v
    return targets


def cw_loss(outputs, labels):
    """Carlini-Wagner style margin loss for a *targeted* attack.

    Pushes the logit of the target class (``labels``) above the largest
    non-target logit; once the target class wins, the loss is zero.

    Args:
        outputs: (N, C) classifier logits.
        labels: (N,) target class indices.

    Returns:
        (N,) per-sample non-negative margins.
    """
    one_hot_labels = torch.eye(outputs.shape[1]).to(device)[labels]
    # Largest logit among the non-target classes.
    other = torch.max((1 - one_hot_labels) * outputs, dim=1)[0]
    # Logit of the target class itself.
    real = torch.max(one_hot_labels * outputs, dim=1)[0]
    # min=-0. in the original is numerically identical to 0.0.
    return torch.clamp(other - real, min=0.0)
def get_stylegan_generator(cfg):
    """Load a pretrained StyleGAN generator (the ``g_ema`` weights) onto
    ``device`` in eval mode.

    Args:
        cfg: hydra config; ``cfg.paths.stylegan`` is the checkpoint path.

    Returns:
        The frozen ``Generator`` instance.
    """
    generator = Generator(1024, 512, 8)
    checkpoint = torch.load(cfg.paths.stylegan, map_location=device)
    generator.load_state_dict(checkpoint['g_ema'])
    generator.to(device)
    generator.eval()
    return generator


@hydra.main(version_base=None, config_path="./config", config_name="config")
def main(cfg: DictConfig) -> None:
    """Train a targeted StyleGAN-based adversarial attack model.

    Each epoch perturbs latent codes so the StyleGAN reconstruction fools a
    frozen victim classifier toward a random target class, while LPIPS and
    CLIP losses keep the image perceptually close / prompt-consistent.
    """
    train_dataloader, test_dataloader, train_dataset, test_dataset = get_adv_dataset(cfg)

    # Frozen victim classifier the attack tries to fool.
    classifier = get_model(cfg)
    classifier.load_state_dict(torch.load('{}/{}_{}.pth'.format(
        cfg.paths.classifier, cfg.classifier.model, cfg.dataset)))
    classifier.eval()

    g_ema = get_stylegan_generator(cfg)
    prompt = get_prompt(cfg)
    # NOTE(review): the original also built an unused `get_model(cfg)` model
    # here and immediately overwrote it; the dead construction was removed.
    model = TargetedGanAttack(cfg, prompt).to(device)

    num_epochs = cfg.optim.num_epochs
    criterion = nn.CrossEntropyLoss()
    max_loss = nn.MarginRankingLoss(0.1)
    clip_loss = CLIPLoss().to(device)
    loss_fn_vgg = lpips.LPIPS(net='vgg').to(device)

    model.train()
    optimizer = optim.SGD(model.parameters(), lr=cfg.classifier.lr,
                          momentum=cfg.classifier.momentum)
    start_time = time.time()

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0
        running_corrects = 0
        for i, (inputs, labels, detail_code) in enumerate(train_dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            target = generate_labels(labels)
            noise = g_ema.make_noise()
            detail_code = detail_code.to(device)

            style, noisy = model(inputs, detail_code, target)
            # Bound the latent perturbation.  BUG FIX: the upper clamp bound
            # used the misspelled config key `cfg.spsilon`, which failed at
            # runtime; both bounds now use cfg.epsilon.
            detail = torch.clamp(style, -cfg.epsilon, cfg.epsilon) + detail_code
            noisy = [torch.clamp(single, -cfg.noise_epsilon, cfg.noise_epsilon)
                     for single in noisy]
            noise[:8] = noisy
            generated_img = g_ema(detail, input_is_latent=True, noise=noise,
                                  randomize_noise=False)

            optimizer.zero_grad()
            loss_vgg = loss_fn_vgg(inputs, generated_img)
            loss_clip = clip_loss(generated_img, prompt)
            adv_outputs = classifier(generated_img)
            loss_classifier = cw_loss(adv_outputs, target)
            # BUG FIX: LPIPS and cw_loss return per-sample tensors; summing
            # them broadcasts into a non-scalar that .backward() rejects.
            # Reduce every term to a scalar (mean is a no-op on scalars).
            loss = (loss_vgg.mean()
                    + cfg.optim.beta * loss_clip.mean()
                    + cfg.optim.delta * loss_classifier.mean())
            loss.backward()
            optimizer.step()

            # BUG FIX: `preds` was never defined in the original train loop
            # (NameError); accuracy is the classifier's argmax on the
            # adversarial images against the true labels.
            _, preds = torch.max(adv_outputs, 1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / len(train_dataset)
        epoch_acc = running_corrects / len(train_dataset) * 100.
        print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(
            epoch, epoch_loss, epoch_acc, time.time() - start_time))

        model.eval()
        print('Evaluating!')
        with torch.no_grad():
            running_corrects = 0
            for i, (inputs, labels, detail_code) in enumerate(test_dataloader):
                inputs = inputs.to(device)
                labels = labels.to(device)
                detail_code = detail_code.to(device)
                # NOTE(review): eval calls model with two args and unpacks two
                # returns, while training passes three and unpacks
                # (style, noisy) — confirm TargetedGanAttack's eval-mode
                # forward signature actually supports this.
                generated_img, adv_latent_codes = model(inputs, detail_code)
                outputs = classifier(generated_img)
                # BUG FIX: the original compared a CrossEntropyLoss value
                # elementwise with the labels; use argmax predictions instead.
                _, preds = torch.max(outputs, 1)
                running_corrects += torch.sum(preds == labels.data)
            epoch_acc = running_corrects / len(test_dataset) * 100.
            print('[Test ] Acc: {:.4f}% Time: {:.4f}s'.format(
                epoch_acc, time.time() - start_time))

    save_path = '{}/stylegan_{}_{}_{}.pth'.format(
        cfg.paths.pretrained_models, cfg.classifier.model, cfg.dataset, cfg.prompt)
    torch.save(model.state_dict(), save_path)


if __name__ == "__main__":
    main()