more change
This commit is contained in:
parent
98d940cbf9
commit
7f3ffa4807
18
GanAttack.py
18
GanAttack.py
|
|
@ -7,7 +7,8 @@ from torchvision import models
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from data.dataset import get_dataset,get_adv_dataset
|
from data.dataset import get_dataset,get_adv_dataset
|
||||||
from utils import get_model,set_requires_grad,unnormalize
|
from utils import get_model,set_requires_grad,unnormalize
|
||||||
from model import GanAttack,CLIPLoss,VggLoss
|
from model import GanAttack,CLIPLoss
|
||||||
|
import lpips
|
||||||
from prompt import get_prompt
|
from prompt import get_prompt
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import time
|
import time
|
||||||
|
|
@ -45,6 +46,7 @@ def main(cfg: DictConfig) -> None:
|
||||||
prompt=get_prompt(cfg)
|
prompt=get_prompt(cfg)
|
||||||
|
|
||||||
# net=get_stylegan_inverter(cfg)
|
# net=get_stylegan_inverter(cfg)
|
||||||
|
|
||||||
model=GanAttack(cfg,prompt).to(device)
|
model=GanAttack(cfg,prompt).to(device)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -53,6 +55,7 @@ def main(cfg: DictConfig) -> None:
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
max_loss=nn.MarginRankingLoss(0.1)
|
max_loss=nn.MarginRankingLoss(0.1)
|
||||||
clip_loss=CLIPLoss().to(device)
|
clip_loss=CLIPLoss().to(device)
|
||||||
|
loss_fn_vgg = lpips.LPIPS(net='vgg').to(device)
|
||||||
# vgg_loss=VggLoss().to(device)
|
# vgg_loss=VggLoss().to(device)
|
||||||
# summary(model, input_size = (3, 256, 256), batch_size = 5)
|
# summary(model, input_size = (3, 256, 256), batch_size = 5)
|
||||||
# set_requires_grad(model.mlp.parameters())
|
# set_requires_grad(model.mlp.parameters())
|
||||||
|
|
@ -73,12 +76,13 @@ def main(cfg: DictConfig) -> None:
|
||||||
# base_code=base_code.to(device)
|
# base_code=base_code.to(device)
|
||||||
detail_code=detail_code.to(device)
|
detail_code=detail_code.to(device)
|
||||||
# codes = model.net.encoder(inputs)
|
# codes = model.net.encoder(inputs)
|
||||||
generated_img,adv_latent_codes=model(inputs,base_code,detail_code)
|
generated_img,adv_latent_codes=model(inputs,detail_code)
|
||||||
# _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path)
|
# _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path)
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
# generated_img,adv_latent_codes=model(inputs)
|
# generated_img,adv_latent_codes=model(inputs)
|
||||||
# loss_vgg=vgg_loss(inputs,generated_img)
|
# loss_vgg=vgg_loss(inputs,generated_img)
|
||||||
loss_l1=F.l1_loss(base_code,adv_latent_codes)
|
# loss_l1=F.l1_loss(base_code,adv_latent_codes)
|
||||||
|
loss_vgg=loss_fn_vgg(inputs,generated_img)
|
||||||
loss_clip=clip_loss(generated_img,prompt)
|
loss_clip=clip_loss(generated_img,prompt)
|
||||||
adv_outputs = classifier(generated_img)
|
adv_outputs = classifier(generated_img)
|
||||||
clean_outputs=classifier(inputs)
|
clean_outputs=classifier(inputs)
|
||||||
|
|
@ -90,7 +94,7 @@ def main(cfg: DictConfig) -> None:
|
||||||
|
|
||||||
# loss=loss_vgg+cfg.optim.alpha*loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
|
# loss=loss_vgg+cfg.optim.alpha*loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
|
||||||
# loss=loss_vgg+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
|
# loss=loss_vgg+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
|
||||||
loss=loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
|
loss=loss_vgg+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
|
@ -108,13 +112,13 @@ def main(cfg: DictConfig) -> None:
|
||||||
running_loss = 0. #test_dataloader
|
running_loss = 0. #test_dataloader
|
||||||
running_corrects = 0
|
running_corrects = 0
|
||||||
|
|
||||||
for i, (inputs, labels,base_code,detail_code) in enumerate(test_dataloader):
|
for i, (inputs, labels,detail_code) in enumerate(test_dataloader):
|
||||||
inputs = inputs.to(device)
|
inputs = inputs.to(device)
|
||||||
labels = labels.to(device)
|
labels = labels.to(device)
|
||||||
base_code=base_code.to(device)
|
# base_code=base_code.to(device)
|
||||||
detail_code=detail_code.to(device)
|
detail_code=detail_code.to(device)
|
||||||
|
|
||||||
generated_img,adv_latent_codes=model(inputs,base_code,detail_code)
|
generated_img,adv_latent_codes=model(inputs,detail_code)
|
||||||
outputs = classifier(generated_img)
|
outputs = classifier(generated_img)
|
||||||
preds=criterion(outputs ,labels)
|
preds=criterion(outputs ,labels)
|
||||||
|
|
||||||
|
|
|
||||||
36
model.py
36
model.py
|
|
@ -77,11 +77,11 @@ class GanAttack(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, img,basecode,detailcode):
|
def forward(self, img,detailcode):
|
||||||
batch_size=img.shape[0]
|
batch_size=img.shape[0]
|
||||||
prompt=self.prompt.repeat(batch_size,18,1).to(device)
|
prompt=self.prompt.repeat(batch_size,18,1).to(device)
|
||||||
# print(prompt.shape)
|
# print(prompt.shape)
|
||||||
x_prompt=torch.cat([basecode,prompt],dim=2)
|
x_prompt=torch.cat([detailcode,prompt],dim=2)
|
||||||
x_prompt=self.mlp(x_prompt)
|
x_prompt=self.mlp(x_prompt)
|
||||||
x=x_prompt+x
|
x=x_prompt+x
|
||||||
result_images=self.generator.synthesis(detailcode,randomize_noise=False,
|
result_images=self.generator.synthesis(detailcode,randomize_noise=False,
|
||||||
|
|
@ -106,24 +106,24 @@ class CLIPLoss(torch.nn.Module):
|
||||||
return similarity
|
return similarity
|
||||||
|
|
||||||
|
|
||||||
class VggLoss(torch.nn.Module):
|
# class VggLoss(torch.nn.Module):
|
||||||
def __init__(self):
|
# def __init__(self):
|
||||||
super(VggLoss, self).__init__()
|
# super(VggLoss, self).__init__()
|
||||||
self.model=models.vgg11(pretrained=True)
|
# self.model=models.vgg11(pretrained=True)
|
||||||
self.model.features=nn.Sequential()
|
# self.model.features=nn.Sequential()
|
||||||
|
|
||||||
# self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda").view(1,3,1,1)
|
# # self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda").view(1,3,1,1)
|
||||||
# self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cuda").view(1,3,1,1)
|
# # self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cuda").view(1,3,1,1)
|
||||||
|
|
||||||
def forward(self, image1, image2):
|
# def forward(self, image1, image2):
|
||||||
# image=normalize(image)
|
# # image=normalize(image)
|
||||||
with torch.no_grad:
|
# with torch.no_grad:
|
||||||
feature1=self.model(image1)
|
# feature1=self.model(image1)
|
||||||
feature2=self.model(image2)
|
# feature2=self.model(image2)
|
||||||
feature1=torch.flatten(feature1)
|
# feature1=torch.flatten(feature1)
|
||||||
feature2=torch.flatten(feature2)
|
# feature2=torch.flatten(feature2)
|
||||||
similarity = F.cosine_similarity(feature1,feature2)
|
# similarity = F.cosine_similarity(feature1,feature2)
|
||||||
return similarity
|
# return similarity
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path="./config", config_name="config")
|
@hydra.main(version_base=None, config_path="./config", config_name="config")
|
||||||
def test(cfg):
|
def test(cfg):
|
||||||
|
|
|
||||||
|
|
@ -26,9 +26,12 @@ def save_prompt(cfg):
|
||||||
spio.savemat("prompt.mat",{'prompt':prompt})
|
spio.savemat("prompt.mat",{'prompt':prompt})
|
||||||
|
|
||||||
def get_prompt(cfg):
|
def get_prompt(cfg):
|
||||||
data=spio.loadmat("prompt.mat")
|
model, preprocess = clip.load("RN50", device=device)
|
||||||
prompt=data['prompt']
|
prompt=cfg.prompt
|
||||||
return torch.from_numpy(prompt).to(device)
|
text=clip.tokenize(prompt).to(device)
|
||||||
|
with torch.no_grad():
|
||||||
|
prompt = model.encode_text(text)
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue