more change
This commit is contained in:
parent
98d940cbf9
commit
7f3ffa4807
18
GanAttack.py
18
GanAttack.py
|
|
@ -7,7 +7,8 @@ from torchvision import models
|
|||
from omegaconf import DictConfig, OmegaConf
|
||||
from data.dataset import get_dataset,get_adv_dataset
|
||||
from utils import get_model,set_requires_grad,unnormalize
|
||||
from model import GanAttack,CLIPLoss,VggLoss
|
||||
from model import GanAttack,CLIPLoss
|
||||
import lpips
|
||||
from prompt import get_prompt
|
||||
import torch.nn.functional as F
|
||||
import time
|
||||
|
|
@ -45,6 +46,7 @@ def main(cfg: DictConfig) -> None:
|
|||
prompt=get_prompt(cfg)
|
||||
|
||||
# net=get_stylegan_inverter(cfg)
|
||||
|
||||
model=GanAttack(cfg,prompt).to(device)
|
||||
|
||||
|
||||
|
|
@ -53,6 +55,7 @@ def main(cfg: DictConfig) -> None:
|
|||
criterion = nn.CrossEntropyLoss()
|
||||
max_loss=nn.MarginRankingLoss(0.1)
|
||||
clip_loss=CLIPLoss().to(device)
|
||||
loss_fn_vgg = lpips.LPIPS(net='vgg').to(device)
|
||||
# vgg_loss=VggLoss().to(device)
|
||||
# summary(model, input_size = (3, 256, 256), batch_size = 5)
|
||||
# set_requires_grad(model.mlp.parameters())
|
||||
|
|
@ -73,12 +76,13 @@ def main(cfg: DictConfig) -> None:
|
|||
# base_code=base_code.to(device)
|
||||
detail_code=detail_code.to(device)
|
||||
# codes = model.net.encoder(inputs)
|
||||
generated_img,adv_latent_codes=model(inputs,base_code,detail_code)
|
||||
generated_img,adv_latent_codes=model(inputs,detail_code)
|
||||
# _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path)
|
||||
optimizer.zero_grad()
|
||||
# generated_img,adv_latent_codes=model(inputs)
|
||||
# loss_vgg=vgg_loss(inputs,generated_img)
|
||||
loss_l1=F.l1_loss(base_code,adv_latent_codes)
|
||||
# loss_l1=F.l1_loss(base_code,adv_latent_codes)
|
||||
loss_vgg=loss_fn_vgg(inputs,generated_img)
|
||||
loss_clip=clip_loss(generated_img,prompt)
|
||||
adv_outputs = classifier(generated_img)
|
||||
clean_outputs=classifier(inputs)
|
||||
|
|
@ -90,7 +94,7 @@ def main(cfg: DictConfig) -> None:
|
|||
|
||||
# loss=loss_vgg+cfg.optim.alpha*loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
|
||||
# loss=loss_vgg+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
|
||||
loss=loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
|
||||
loss=loss_vgg+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
|
|
@ -108,13 +112,13 @@ def main(cfg: DictConfig) -> None:
|
|||
running_loss = 0. #test_dataloader
|
||||
running_corrects = 0
|
||||
|
||||
for i, (inputs, labels,base_code,detail_code) in enumerate(test_dataloader):
|
||||
for i, (inputs, labels,detail_code) in enumerate(test_dataloader):
|
||||
inputs = inputs.to(device)
|
||||
labels = labels.to(device)
|
||||
base_code=base_code.to(device)
|
||||
# base_code=base_code.to(device)
|
||||
detail_code=detail_code.to(device)
|
||||
|
||||
generated_img,adv_latent_codes=model(inputs,base_code,detail_code)
|
||||
generated_img,adv_latent_codes=model(inputs,detail_code)
|
||||
outputs = classifier(generated_img)
|
||||
preds=criterion(outputs ,labels)
|
||||
|
||||
|
|
|
|||
36
model.py
36
model.py
|
|
@ -77,11 +77,11 @@ class GanAttack(nn.Module):
|
|||
|
||||
|
||||
|
||||
def forward(self, img,basecode,detailcode):
|
||||
def forward(self, img,detailcode):
|
||||
batch_size=img.shape[0]
|
||||
prompt=self.prompt.repeat(batch_size,18,1).to(device)
|
||||
# print(prompt.shape)
|
||||
x_prompt=torch.cat([basecode,prompt],dim=2)
|
||||
x_prompt=torch.cat([detailcode,prompt],dim=2)
|
||||
x_prompt=self.mlp(x_prompt)
|
||||
x=x_prompt+x
|
||||
result_images=self.generator.synthesis(detailcode,randomize_noise=False,
|
||||
|
|
@ -106,24 +106,24 @@ class CLIPLoss(torch.nn.Module):
|
|||
return similarity
|
||||
|
||||
|
||||
class VggLoss(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(VggLoss, self).__init__()
|
||||
self.model=models.vgg11(pretrained=True)
|
||||
self.model.features=nn.Sequential()
|
||||
# class VggLoss(torch.nn.Module):
|
||||
# def __init__(self):
|
||||
# super(VggLoss, self).__init__()
|
||||
# self.model=models.vgg11(pretrained=True)
|
||||
# self.model.features=nn.Sequential()
|
||||
|
||||
# self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda").view(1,3,1,1)
|
||||
# self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cuda").view(1,3,1,1)
|
||||
# # self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda").view(1,3,1,1)
|
||||
# # self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cuda").view(1,3,1,1)
|
||||
|
||||
def forward(self, image1, image2):
|
||||
# image=normalize(image)
|
||||
with torch.no_grad:
|
||||
feature1=self.model(image1)
|
||||
feature2=self.model(image2)
|
||||
feature1=torch.flatten(feature1)
|
||||
feature2=torch.flatten(feature2)
|
||||
similarity = F.cosine_similarity(feature1,feature2)
|
||||
return similarity
|
||||
# def forward(self, image1, image2):
|
||||
# # image=normalize(image)
|
||||
# with torch.no_grad:
|
||||
# feature1=self.model(image1)
|
||||
# feature2=self.model(image2)
|
||||
# feature1=torch.flatten(feature1)
|
||||
# feature2=torch.flatten(feature2)
|
||||
# similarity = F.cosine_similarity(feature1,feature2)
|
||||
# return similarity
|
||||
|
||||
@hydra.main(version_base=None, config_path="./config", config_name="config")
|
||||
def test(cfg):
|
||||
|
|
|
|||
|
|
@ -26,9 +26,12 @@ def save_prompt(cfg):
|
|||
spio.savemat("prompt.mat",{'prompt':prompt})
|
||||
|
||||
def get_prompt(cfg):
|
||||
data=spio.loadmat("prompt.mat")
|
||||
prompt=data['prompt']
|
||||
return torch.from_numpy(prompt).to(device)
|
||||
model, preprocess = clip.load("RN50", device=device)
|
||||
prompt=cfg.prompt
|
||||
text=clip.tokenize(prompt).to(device)
|
||||
with torch.no_grad():
|
||||
prompt = model.encode_text(text)
|
||||
return prompt
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue