# StyleGAN + CLIP adversarial-attack utilities (web-scrape metadata header removed).
import torch
|
|
import torchvision
|
|
from torchvision import datasets, models, transforms
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import os
|
|
import clip
|
|
# from pixel2style2pixel.models.psp import pSp
|
|
from argparse import Namespace
|
|
from utils import normalize
|
|
from stylegan.model import Generator
|
|
import hydra
|
|
from omegaconf import DictConfig, OmegaConf
|
|
import sys
|
|
import os
|
|
import numpy as np
|
|
|
|
# Run on the first CUDA GPU when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def get_prompt(cfg):
    """Encode ``cfg.prompt`` into a CLIP text embedding on ``device``.

    Loads the CLIP RN50 model on every call (presumably only invoked once
    at startup -- TODO confirm), tokenizes the configured prompt string,
    and returns its text embedding with gradients disabled.
    """
    clip_model, _preprocess = clip.load("RN50", device=device)
    tokens = clip.tokenize(cfg.prompt).to(device)
    with torch.no_grad():
        embedding = clip_model.encode_text(tokens)
    return embedding
def get_stylegan_generator(cfg):
    """Build a 1024px StyleGAN generator and load EMA weights from ``cfg``.

    Reads the checkpoint path from ``cfg.paths.stylegan``, restores the
    ``'g_ema'`` state dict, and returns the generator on ``device`` in
    eval mode.
    """
    ckpt = torch.load(cfg.paths.stylegan, map_location=device)
    net = Generator(1024, 512, 8)
    net.load_state_dict(ckpt['g_ema'])
    # .to() and .eval() both return the module itself.
    return net.to(device).eval()
# class GanAttack(nn.Module):
|
|
|
|
# def get_stylegan_inverter(cfg):
|
|
|
|
# # ensure_checkpoint_exists(ckpt_path)
|
|
# path=cfg.paths.inverter_cfg
|
|
# ckpt = torch.load(path, map_location='cuda:0')
|
|
# opts = ckpt['opts']
|
|
# opts['checkpoint_path'] = path
|
|
# if 'learn_in_w' not in opts:
|
|
# opts['learn_in_w'] = False
|
|
# if 'output_size' not in opts:
|
|
# opts['output_size'] = 1024
|
|
|
|
# net = pSp(Namespace(**opts))
|
|
# net.eval()
|
|
# net.cuda()
|
|
# return net
|
|
|
|
class GanAttack(nn.Module):
    """Adversarial generator: perturbs StyleGAN latents and noise maps,
    conditioned on a CLIP text-prompt embedding.

    NOTE(review): this module appears unfinished -- ``forward`` contains
    lines that raise at runtime as written; see the inline notes.
    """

    def __init__(self, cfg, prompt):
        super().__init__()
        # self.net= get_stylegan_inverter(cfg)
        # self.generator.eval()
        # self.inverter=inverter
        # self.images_resize=images_resize
        self.generator = get_stylegan_generator(cfg)
        # prompt: a CLIP text embedding; indexing shape[1] below implies it is
        # at least 2-D, presumably (1, embed_dim) -- TODO confirm with caller.
        self.prompt = prompt
        text_len = self.prompt.shape[1]
        # Maps concat(latent, prompt) along the feature dim to a 512-d offset.
        self.mlp = nn.Sequential(
            nn.Linear(text_len + 512, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 512)
        )
        # Maps a 10-d input to 512-d.
        # NOTE(review): forward() feeds this a *list* of noise tensors
        # (self.noises); nn.Sequential/nn.Linear cannot consume a list, so
        # that call raises TypeError -- confirm the intended per-noise usage.
        self.noise_mlp = nn.Sequential(
            nn.Linear(10, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 512)
        )
        # Per-layer noise tensors from the StyleGAN generator.
        self.noises = self.generator.make_noise()
        for i, j in enumerate(self.noises):
            if i > 9:
                # NOTE(review): Tensor.detach() is NOT in-place; its result is
                # discarded here, so noise maps beyond index 9 are not actually
                # detached. Probably intended: self.noises[i] = j.detach().
                j.detach()

    def forward(self, img, detailcode):
        """Run one attack step.

        img: image batch, used only for its batch size.
        detailcode: latent code concatenated with the prompt; the repeat of
        18 below suggests a W+ latent of shape (batch, 18, 512) -- TODO confirm.
        Returns (result_images, x).
        """
        batch_size = img.shape[0]
        # Broadcast the prompt over the batch and the 18 style layers.
        prompt = self.prompt.repeat(batch_size, 18, 1).to(device)
        # print(prompt.shape)
        x_prompt = torch.cat([detailcode, prompt], dim=2)
        x_prompt = self.mlp(x_prompt)
        # NOTE(review): `x` is used before assignment -- NameError at runtime.
        # Likely intended: x = x_prompt + detailcode.
        x = x_prompt + x
        # NOTE(review): self.noises is a list; see noise_mlp note in __init__.
        adv = self.noise_mlp(self.noises)
        self.noises = adv + self.noises
        # NOTE(review): the generator receives the ORIGINAL detailcode, not the
        # prompt-shifted `x` -- confirm which latent should be rendered.
        result_images, _ = self.generator([detailcode], input_is_latent=True,
                                          randomize_noise=False, noises=self.noises)
        return result_images, x
class CLIPLoss(torch.nn.Module):
    """CLIP-based dissimilarity loss between an image batch and tokenized text.

    Lower values mean the image matches the text better.
    """

    def __init__(self):
        super(CLIPLoss, self).__init__()
        # Use the module-level `device` instead of the original hard-coded
        # "cuda", which crashed on CPU-only hosts and was inconsistent with
        # the rest of this file.
        self.model, self.preprocess = clip.load("RN50", device=device)
        self.model.eval()
        # CLIP RN50 expects 224x224 inputs.
        self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))
        # self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda").view(1,3,1,1)
        # self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cuda").view(1,3,1,1)

    def forward(self, image, text):
        """Return 1 - scaled CLIP logit for (image, text).

        image: image batch; normalized and pooled to 224x224 before CLIP.
        text: pre-tokenized CLIP text tensor.
        """
        image = normalize(image)
        image = self.face_pool(image)
        # CLIP returns logits scaled by 100; divide to bring them near [0, 1].
        similarity = 1 - self.model(image, text)[0] / 100
        return similarity
|
# class VggLoss(torch.nn.Module):
|
|
# def __init__(self):
|
|
# super(VggLoss, self).__init__()
|
|
# self.model=models.vgg11(pretrained=True)
|
|
# self.model.features=nn.Sequential()
|
|
|
|
# # self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda").view(1,3,1,1)
|
|
# # self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cuda").view(1,3,1,1)
|
|
|
|
# def forward(self, image1, image2):
|
|
# # image=normalize(image)
|
|
# with torch.no_grad:
|
|
# feature1=self.model(image1)
|
|
# feature2=self.model(image2)
|
|
# feature1=torch.flatten(feature1)
|
|
# feature2=torch.flatten(feature2)
|
|
# similarity = F.cosine_similarity(feature1,feature2)
|
|
# return similarity
|
|
|
|
@hydra.main(version_base=None, config_path="./config", config_name="config")
def test(cfg):
    """Smoke test: build GanAttack and run one forward pass on random inputs.

    Prints the shapes of the generated images and the shifted latent.
    """
    prompt = torch.randn([1, 1024]).to(device)
    model = GanAttack(cfg, prompt).to(device)
    data = torch.randn([2, 3, 256, 256]).to(device)
    # Bug fix: GanAttack.forward takes (img, detailcode); the original called
    # model(data) with one argument, which raises TypeError.
    # detailcode presumably is a W+ latent of shape (batch, 18, 512)
    # -- matches the repeat(batch, 18, 1) in forward; TODO confirm.
    detailcode = torch.randn([2, 18, 512]).to(device)
    result_images, x = model(data, detailcode)
    print(result_images.shape)
    print(x.shape)
# Entry point: run the hydra-configured smoke test when executed as a script.
if __name__ == "__main__":
    test()