From 7f3ffa480720028e19d334322e73eae49e8b53a7 Mon Sep 17 00:00:00 2001
From: leewlving
Date: Sat, 13 Jan 2024 19:29:24 +0800
Subject: [PATCH] more change

---
 GanAttack.py | 18 +++++++++++-------
 model.py     | 36 ++++++++++++++++++------------------
 prompt.py    |  9 ++++++---
 3 files changed, 35 insertions(+), 28 deletions(-)

diff --git a/GanAttack.py b/GanAttack.py
index f51df0f..b86dd42 100644
--- a/GanAttack.py
+++ b/GanAttack.py
@@ -7,7 +7,8 @@ from torchvision import models
 from omegaconf import DictConfig, OmegaConf
 from data.dataset import get_dataset,get_adv_dataset
 from utils import get_model,set_requires_grad,unnormalize
-from model import GanAttack,CLIPLoss,VggLoss
+from model import GanAttack,CLIPLoss
+import lpips
 from prompt import get_prompt
 import torch.nn.functional as F
 import time
@@ -45,6 +46,7 @@ def main(cfg: DictConfig) -> None:
     prompt=get_prompt(cfg)
     # net=get_stylegan_inverter(cfg)
 
+    model=GanAttack(cfg,prompt).to(device)
 
 
 
@@ -53,6 +55,7 @@ def main(cfg: DictConfig) -> None:
     criterion = nn.CrossEntropyLoss()
     max_loss=nn.MarginRankingLoss(0.1)
     clip_loss=CLIPLoss().to(device)
+    loss_fn_vgg = lpips.LPIPS(net='vgg').to(device)
     # vgg_loss=VggLoss().to(device)
     # summary(model, input_size = (3, 256, 256), batch_size = 5)
     # set_requires_grad(model.mlp.parameters())
@@ -73,12 +76,13 @@ def main(cfg: DictConfig) -> None:
         # base_code=base_code.to(device)
         detail_code=detail_code.to(device)
         # codes = model.net.encoder(inputs)
-        generated_img,adv_latent_codes=model(inputs,base_code,detail_code)
+        generated_img,adv_latent_codes=model(inputs,detail_code)
         # _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path)
         optimizer.zero_grad()
         # generated_img,adv_latent_codes=model(inputs)
         # loss_vgg=vgg_loss(inputs,generated_img)
-        loss_l1=F.l1_loss(base_code,adv_latent_codes)
+        # loss_l1=F.l1_loss(base_code,adv_latent_codes)
+        loss_vgg=loss_fn_vgg(inputs,generated_img)
         loss_clip=clip_loss(generated_img,prompt)
         adv_outputs = classifier(generated_img)
         clean_outputs=classifier(inputs)
@@ -90,7 +94,7 @@ def main(cfg: DictConfig) -> None:
 
         # loss=loss_vgg+cfg.optim.alpha*loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
         # loss=loss_vgg+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
-        loss=loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
+        loss=loss_vgg+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
 
         loss.backward()
         optimizer.step()
@@ -108,13 +112,13 @@ def main(cfg: DictConfig) -> None:
     running_loss = 0.
     #test_dataloader
     running_corrects = 0
-    for i, (inputs, labels,base_code,detail_code) in enumerate(test_dataloader):
+    for i, (inputs, labels,detail_code) in enumerate(test_dataloader):
         inputs = inputs.to(device)
         labels = labels.to(device)
-        base_code=base_code.to(device)
+        # base_code=base_code.to(device)
         detail_code=detail_code.to(device)
 
-        generated_img,adv_latent_codes=model(inputs,base_code,detail_code)
+        generated_img,adv_latent_codes=model(inputs,detail_code)
         outputs = classifier(generated_img)
         preds=criterion(outputs ,labels)
 
diff --git a/model.py b/model.py
index 02ae0f8..2149fd9 100644
--- a/model.py
+++ b/model.py
@@ -77,11 +77,11 @@ class GanAttack(nn.Module):
 
 
 
-    def forward(self, img,basecode,detailcode):
+    def forward(self, img,detailcode):
         batch_size=img.shape[0]
         prompt=self.prompt.repeat(batch_size,18,1).to(device)
         # print(prompt.shape)
-        x_prompt=torch.cat([basecode,prompt],dim=2)
+        x_prompt=torch.cat([detailcode,prompt],dim=2)
         x_prompt=self.mlp(x_prompt)
         x=x_prompt+x
         result_images=self.generator.synthesis(detailcode,randomize_noise=False,
@@ -106,24 +106,24 @@ class CLIPLoss(torch.nn.Module):
         return similarity
 
 
-class VggLoss(torch.nn.Module):
-    def __init__(self):
-        super(VggLoss, self).__init__()
-        self.model=models.vgg11(pretrained=True)
-        self.model.features=nn.Sequential()
+# class VggLoss(torch.nn.Module):
+#     def __init__(self):
+#         super(VggLoss, self).__init__()
+#         self.model=models.vgg11(pretrained=True)
+#         self.model.features=nn.Sequential()
 
-        # self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda").view(1,3,1,1)
-        # self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cuda").view(1,3,1,1)
+#     # self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda").view(1,3,1,1)
+#     # self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cuda").view(1,3,1,1)
 
-    def forward(self, image1, image2):
-        # image=normalize(image)
-        with torch.no_grad:
-            feature1=self.model(image1)
-            feature2=self.model(image2)
-            feature1=torch.flatten(feature1)
-            feature2=torch.flatten(feature2)
-            similarity = F.cosine_similarity(feature1,feature2)
-            return similarity
+#     def forward(self, image1, image2):
+#         # image=normalize(image)
+#         with torch.no_grad:
+#             feature1=self.model(image1)
+#             feature2=self.model(image2)
+#             feature1=torch.flatten(feature1)
+#             feature2=torch.flatten(feature2)
+#             similarity = F.cosine_similarity(feature1,feature2)
+#             return similarity
 
 @hydra.main(version_base=None, config_path="./config", config_name="config")
 def test(cfg):
diff --git a/prompt.py b/prompt.py
index f69e1a5..2d7b164 100644
--- a/prompt.py
+++ b/prompt.py
@@ -26,9 +26,12 @@ def save_prompt(cfg):
     spio.savemat("prompt.mat",{'prompt':prompt})
 
 def get_prompt(cfg):
-    data=spio.loadmat("prompt.mat")
-    prompt=data['prompt']
-    return torch.from_numpy(prompt).to(device)
+    model, preprocess = clip.load("RN50", device=device)
+    prompt=cfg.prompt
+    text=clip.tokenize(prompt).to(device)
+    with torch.no_grad():
+        prompt = model.encode_text(text)
+    return prompt
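
Note on the lpips term introduced above: a minimal standalone sketch, assuming
images already scaled to [-1, 1] (the range lpips expects by default; pass
normalize=True for inputs in [0, 1]). lpips.LPIPS returns one distance per
image with shape [N, 1, 1, 1], so it is usually reduced with .mean() before
being summed with scalar terms such as the CLIP and classifier losses. The
tensor names below are placeholders, not the variables from GanAttack.py.

    import torch
    import lpips

    loss_fn_vgg = lpips.LPIPS(net='vgg')        # VGG-backed perceptual metric
    clean = torch.rand(4, 3, 256, 256) * 2 - 1  # stand-in clean batch, scaled to [-1, 1]
    adv   = torch.rand(4, 3, 256, 256) * 2 - 1  # stand-in generated batch, same scaling
    dist = loss_fn_vgg(clean, adv)              # per-image distances, shape [4, 1, 1, 1]
    loss_vgg = dist.mean()                      # scalar, safe to combine with other loss terms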