more change

This commit is contained in:
leewlving 2024-01-13 19:29:24 +08:00
parent 98d940cbf9
commit 7f3ffa4807
3 changed files with 35 additions and 28 deletions

View File

@ -7,7 +7,8 @@ from torchvision import models
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from data.dataset import get_dataset,get_adv_dataset from data.dataset import get_dataset,get_adv_dataset
from utils import get_model,set_requires_grad,unnormalize from utils import get_model,set_requires_grad,unnormalize
from model import GanAttack,CLIPLoss,VggLoss from model import GanAttack,CLIPLoss
import lpips
from prompt import get_prompt from prompt import get_prompt
import torch.nn.functional as F import torch.nn.functional as F
import time import time
@ -45,6 +46,7 @@ def main(cfg: DictConfig) -> None:
prompt=get_prompt(cfg) prompt=get_prompt(cfg)
# net=get_stylegan_inverter(cfg) # net=get_stylegan_inverter(cfg)
model=GanAttack(cfg,prompt).to(device) model=GanAttack(cfg,prompt).to(device)
@ -53,6 +55,7 @@ def main(cfg: DictConfig) -> None:
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
max_loss=nn.MarginRankingLoss(0.1) max_loss=nn.MarginRankingLoss(0.1)
clip_loss=CLIPLoss().to(device) clip_loss=CLIPLoss().to(device)
loss_fn_vgg = lpips.LPIPS(net='vgg').to(device)
# vgg_loss=VggLoss().to(device) # vgg_loss=VggLoss().to(device)
# summary(model, input_size = (3, 256, 256), batch_size = 5) # summary(model, input_size = (3, 256, 256), batch_size = 5)
# set_requires_grad(model.mlp.parameters()) # set_requires_grad(model.mlp.parameters())
@ -73,12 +76,13 @@ def main(cfg: DictConfig) -> None:
# base_code=base_code.to(device) # base_code=base_code.to(device)
detail_code=detail_code.to(device) detail_code=detail_code.to(device)
# codes = model.net.encoder(inputs) # codes = model.net.encoder(inputs)
generated_img,adv_latent_codes=model(inputs,base_code,detail_code) generated_img,adv_latent_codes=model(inputs,detail_code)
# _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path) # _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path)
optimizer.zero_grad() optimizer.zero_grad()
# generated_img,adv_latent_codes=model(inputs) # generated_img,adv_latent_codes=model(inputs)
# loss_vgg=vgg_loss(inputs,generated_img) # loss_vgg=vgg_loss(inputs,generated_img)
loss_l1=F.l1_loss(base_code,adv_latent_codes) # loss_l1=F.l1_loss(base_code,adv_latent_codes)
loss_vgg=loss_fn_vgg(inputs,generated_img)
loss_clip=clip_loss(generated_img,prompt) loss_clip=clip_loss(generated_img,prompt)
adv_outputs = classifier(generated_img) adv_outputs = classifier(generated_img)
clean_outputs=classifier(inputs) clean_outputs=classifier(inputs)
@ -90,7 +94,7 @@ def main(cfg: DictConfig) -> None:
# loss=loss_vgg+cfg.optim.alpha*loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier # loss=loss_vgg+cfg.optim.alpha*loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
# loss=loss_vgg+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier # loss=loss_vgg+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
loss=loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier loss=loss_vgg+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
loss.backward() loss.backward()
optimizer.step() optimizer.step()
@ -108,13 +112,13 @@ def main(cfg: DictConfig) -> None:
running_loss = 0. #test_dataloader running_loss = 0. #test_dataloader
running_corrects = 0 running_corrects = 0
for i, (inputs, labels,base_code,detail_code) in enumerate(test_dataloader): for i, (inputs, labels,detail_code) in enumerate(test_dataloader):
inputs = inputs.to(device) inputs = inputs.to(device)
labels = labels.to(device) labels = labels.to(device)
base_code=base_code.to(device) # base_code=base_code.to(device)
detail_code=detail_code.to(device) detail_code=detail_code.to(device)
generated_img,adv_latent_codes=model(inputs,base_code,detail_code) generated_img,adv_latent_codes=model(inputs,detail_code)
outputs = classifier(generated_img) outputs = classifier(generated_img)
preds=criterion(outputs ,labels) preds=criterion(outputs ,labels)

View File

@ -77,11 +77,11 @@ class GanAttack(nn.Module):
def forward(self, img,basecode,detailcode): def forward(self, img,detailcode):
batch_size=img.shape[0] batch_size=img.shape[0]
prompt=self.prompt.repeat(batch_size,18,1).to(device) prompt=self.prompt.repeat(batch_size,18,1).to(device)
# print(prompt.shape) # print(prompt.shape)
x_prompt=torch.cat([basecode,prompt],dim=2) x_prompt=torch.cat([detailcode,prompt],dim=2)
x_prompt=self.mlp(x_prompt) x_prompt=self.mlp(x_prompt)
x=x_prompt+x x=x_prompt+x
result_images=self.generator.synthesis(detailcode,randomize_noise=False, result_images=self.generator.synthesis(detailcode,randomize_noise=False,
@ -106,24 +106,24 @@ class CLIPLoss(torch.nn.Module):
return similarity return similarity
class VggLoss(torch.nn.Module): # class VggLoss(torch.nn.Module):
def __init__(self): # def __init__(self):
super(VggLoss, self).__init__() # super(VggLoss, self).__init__()
self.model=models.vgg11(pretrained=True) # self.model=models.vgg11(pretrained=True)
self.model.features=nn.Sequential() # self.model.features=nn.Sequential()
# self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda").view(1,3,1,1) # # self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda").view(1,3,1,1)
# self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cuda").view(1,3,1,1) # # self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cuda").view(1,3,1,1)
def forward(self, image1, image2): # def forward(self, image1, image2):
# image=normalize(image) # # image=normalize(image)
with torch.no_grad: # with torch.no_grad:
feature1=self.model(image1) # feature1=self.model(image1)
feature2=self.model(image2) # feature2=self.model(image2)
feature1=torch.flatten(feature1) # feature1=torch.flatten(feature1)
feature2=torch.flatten(feature2) # feature2=torch.flatten(feature2)
similarity = F.cosine_similarity(feature1,feature2) # similarity = F.cosine_similarity(feature1,feature2)
return similarity # return similarity
@hydra.main(version_base=None, config_path="./config", config_name="config") @hydra.main(version_base=None, config_path="./config", config_name="config")
def test(cfg): def test(cfg):

View File

@ -26,9 +26,12 @@ def save_prompt(cfg):
spio.savemat("prompt.mat",{'prompt':prompt}) spio.savemat("prompt.mat",{'prompt':prompt})
def get_prompt(cfg): def get_prompt(cfg):
data=spio.loadmat("prompt.mat") model, preprocess = clip.load("RN50", device=device)
prompt=data['prompt'] prompt=cfg.prompt
return torch.from_numpy(prompt).to(device) text=clip.tokenize(prompt).to(device)
with torch.no_grad():
prompt = model.encode_text(text)
return prompt